Repository: Azure/azure-linux-extensions Branch: master Commit: 7572dd33197a Files: 991 Total size: 6.8 MB Directory structure: gitextract_55atar1w/ ├── .gitattributes ├── .gitignore ├── .gitmodules ├── .vscode/ │ └── launch.json ├── AzureEnhancedMonitor/ │ ├── README.md │ ├── bin/ │ │ ├── pack.sh │ │ └── setup.sh │ ├── clib/ │ │ ├── .gitignore │ │ ├── Makefile │ │ ├── include/ │ │ │ └── azureperf.h │ │ ├── src/ │ │ │ ├── apmetric.c │ │ │ └── azureperf.c │ │ └── test/ │ │ ├── cases/ │ │ │ └── positive_case │ │ ├── codegen.py │ │ ├── counter_names │ │ └── runtest.c │ ├── ext/ │ │ ├── .gitignore │ │ ├── HandlerManifest.json │ │ ├── aem.py │ │ ├── handler.py │ │ ├── installer.py │ │ ├── references │ │ └── test/ │ │ ├── env.py │ │ ├── storage_metrics │ │ ├── test_aem.py │ │ └── test_installer.py │ ├── hvinfo/ │ │ ├── .gitignore │ │ ├── Makefile │ │ └── src/ │ │ └── hvinfo.c │ └── nodejs/ │ ├── package.json │ └── setaem.js ├── AzureMonitorAgent/ │ ├── .gitignore │ ├── HandlerManifest.json │ ├── README.md │ ├── agent.py │ ├── agent.version │ ├── ama_tst/ │ │ ├── AMA-Troubleshooting-Tool.md │ │ ├── __init__.py │ │ ├── ama_troubleshooter.sh │ │ └── modules/ │ │ ├── __init__.py │ │ ├── connect/ │ │ │ ├── __init__.py │ │ │ ├── check_endpts.py │ │ │ ├── check_imds.py │ │ │ └── connect.py │ │ ├── custom_logs/ │ │ │ ├── __init__.py │ │ │ ├── check_clconf.py │ │ │ └── custom_logs.py │ │ ├── error_codes.py │ │ ├── errors.py │ │ ├── general_health/ │ │ │ ├── __init__.py │ │ │ ├── check_status.py │ │ │ └── general_health.py │ │ ├── helpers.py │ │ ├── high_cpu_mem/ │ │ │ ├── __init__.py │ │ │ ├── check_logrot.py │ │ │ ├── check_usage.py │ │ │ └── high_cpu_mem.py │ │ ├── install/ │ │ │ ├── __init__.py │ │ │ ├── check_ama.py │ │ │ ├── check_os.py │ │ │ ├── check_pkgs.py │ │ │ ├── install.py │ │ │ └── supported_distros.py │ │ ├── logcollector.py │ │ ├── main.py │ │ ├── metrics_troubleshooter/ │ │ │ ├── __init__.py │ │ │ └── metrics_troubleshooter.py │ │ └── syslog_tst/ │ │ ├── 
__init__.py │ │ ├── check_conf.py │ │ ├── check_rsysng.py │ │ └── syslog.py │ ├── apply_version.sh │ ├── azuremonitoragentextension.logrotate │ ├── manifest.xml │ ├── packaging.sh │ ├── references │ ├── services/ │ │ ├── metrics-extension-cmv1.service │ │ ├── metrics-extension-cmv2.service │ │ └── metrics-extension-otlp.service │ ├── shim.sh │ └── update_version.sh ├── CODEOWNERS ├── Common/ │ ├── WALinuxAgent-2.0.14/ │ │ └── waagent │ ├── WALinuxAgent-2.0.16/ │ │ └── waagent │ ├── libpsutil/ │ │ ├── py2.6-glibc-2.12-pre/ │ │ │ └── psutil/ │ │ │ ├── __init__.py │ │ │ ├── _common.py │ │ │ ├── _compat.py │ │ │ ├── _psbsd.py │ │ │ ├── _pslinux.py │ │ │ ├── _psosx.py │ │ │ ├── _psposix.py │ │ │ ├── _pssunos.py │ │ │ └── _pswindows.py │ │ └── py2.7-glibc-2.12+/ │ │ └── psutil/ │ │ ├── __init__.py │ │ ├── _common.py │ │ ├── _compat.py │ │ ├── _psbsd.py │ │ ├── _pslinux.py │ │ ├── _psosx.py │ │ ├── _psposix.py │ │ ├── _pssunos.py │ │ └── _pswindows.py │ └── waagentloader.py ├── CustomScript/ │ ├── CHANGELOG.md │ ├── HandlerManifest.json │ ├── README.md │ ├── azure/ │ │ ├── __init__.py │ │ ├── azure.pyproj │ │ ├── http/ │ │ │ ├── __init__.py │ │ │ ├── batchclient.py │ │ │ ├── httpclient.py │ │ │ └── winhttp.py │ │ ├── servicebus/ │ │ │ ├── __init__.py │ │ │ └── servicebusservice.py │ │ ├── servicemanagement/ │ │ │ ├── __init__.py │ │ │ ├── servicebusmanagementservice.py │ │ │ ├── servicemanagementclient.py │ │ │ ├── servicemanagementservice.py │ │ │ ├── sqldatabasemanagementservice.py │ │ │ └── websitemanagementservice.py │ │ └── storage/ │ │ ├── __init__.py │ │ ├── blobservice.py │ │ ├── cloudstorageaccount.py │ │ ├── queueservice.py │ │ ├── sharedaccesssignature.py │ │ ├── storageclient.py │ │ └── tableservice.py │ ├── customscript.py │ ├── manifest.xml │ ├── references │ ├── shim.sh │ └── test/ │ ├── HandlerEnvironment.json │ ├── MockUtil.py │ ├── create_test_blob.py │ ├── env.py │ ├── run_all.sh │ ├── test_blob_download.py │ ├── test_file_download.py │ ├── 
test_preprocess_file.py │ ├── test_uri_utils.py │ └── timeout.sh ├── DSC/ │ ├── HandlerManifest.json │ ├── Makefile │ ├── README.md │ ├── azure/ │ │ ├── __init__.py │ │ ├── azure.pyproj │ │ ├── http/ │ │ │ ├── __init__.py │ │ │ ├── batchclient.py │ │ │ ├── httpclient.py │ │ │ └── winhttp.py │ │ ├── servicebus/ │ │ │ ├── __init__.py │ │ │ └── servicebusservice.py │ │ ├── servicemanagement/ │ │ │ ├── __init__.py │ │ │ ├── servicebusmanagementservice.py │ │ │ ├── servicemanagementclient.py │ │ │ ├── servicemanagementservice.py │ │ │ ├── sqldatabasemanagementservice.py │ │ │ └── websitemanagementservice.py │ │ └── storage/ │ │ ├── __init__.py │ │ ├── blobservice.py │ │ ├── cloudstorageaccount.py │ │ ├── queueservice.py │ │ ├── sharedaccesssignature.py │ │ ├── storageclient.py │ │ └── tableservice.py │ ├── curlhttpclient.py │ ├── dsc.py │ ├── extension_shim.sh │ ├── httpclient.py │ ├── httpclientfactory.py │ ├── manifest.xml │ ├── packages/ │ │ ├── dsc-1.2.4-0.ssl_100.x64.deb │ │ ├── dsc-1.2.4-0.ssl_100.x64.rpm │ │ ├── dsc-1.2.4-0.ssl_110.x64.deb │ │ ├── dsc-1.2.4-0.ssl_110.x64.rpm │ │ ├── omi-1.7.3-0.ssl_100.ulinux.s.x64.deb │ │ ├── omi-1.7.3-0.ssl_100.ulinux.s.x64.rpm │ │ ├── omi-1.7.3-0.ssl_100.ulinux.x64.deb │ │ ├── omi-1.7.3-0.ssl_100.ulinux.x64.rpm │ │ ├── omi-1.7.3-0.ssl_110.ulinux.s.x64.deb │ │ ├── omi-1.7.3-0.ssl_110.ulinux.s.x64.rpm │ │ ├── omi-1.7.3-0.ssl_110.ulinux.x64.deb │ │ └── omi-1.7.3-0.ssl_110.ulinux.x64.rpm │ ├── serializerfactory.py │ ├── subprocessfactory.py │ ├── test/ │ │ ├── MockUtil.py │ │ ├── env.py │ │ ├── mof/ │ │ │ ├── azureautomation.df.meta.mof │ │ │ ├── dscnode.nxFile.meta.mof │ │ │ ├── dscnode.nxFile.meta.push.mof │ │ │ └── localhost.nxFile.mof │ │ ├── status/ │ │ │ └── 0.status │ │ ├── test_apply_meta_mof.py │ │ ├── test_apply_mof.py │ │ ├── test_compare_pkg_version.py │ │ ├── test_download_file.py │ │ ├── test_node_extension_properties.py │ │ ├── test_register.py │ │ └── test_status_update.py │ ├── urllib2httpclient.py │ └── 
urllib3httpclient.py ├── Diagnostic/ │ ├── ChangeLogs │ ├── DistroSpecific.py │ ├── HandlerManifest.json │ ├── Makefile │ ├── Providers/ │ │ ├── Builtin.py │ │ └── __init__.py │ ├── README.md │ ├── Utils/ │ │ ├── LadDiagnosticUtil.py │ │ ├── ProviderUtil.py │ │ ├── XmlUtil.py │ │ ├── __init__.py │ │ ├── imds_util.py │ │ ├── lad_exceptions.py │ │ ├── lad_ext_settings.py │ │ ├── lad_logging_config.py │ │ ├── mdsd_xml_templates.py │ │ ├── misc_helpers.py │ │ └── omsagent_util.py │ ├── __init__.py │ ├── decrypt_protected_settings.sh │ ├── diagnostic.py │ ├── lad_config_all.py │ ├── lad_mdsd.te │ ├── license.txt │ ├── manifest.xml │ ├── mdsd/ │ │ ├── CMakeLists.txt │ │ ├── Dockerfile │ │ ├── LICENSE.txt │ │ ├── README.md │ │ ├── SampleConfig-LAD-SAS.xml │ │ ├── azure.list │ │ ├── buildcmake.sh │ │ ├── lad-mdsd/ │ │ │ ├── Makefile.in.version │ │ │ ├── README.txt │ │ │ ├── changelog │ │ │ ├── copyright │ │ │ ├── deb/ │ │ │ │ ├── Makefile │ │ │ │ └── control │ │ │ └── rpm/ │ │ │ └── Makefile │ │ ├── mdscommands/ │ │ │ ├── BinaryWriter.hh │ │ │ ├── BodyOnlyXmlParser.cc │ │ │ ├── BodyOnlyXmlParser.hh │ │ │ ├── CMakeLists.txt │ │ │ ├── CmdListXmlParser.cc │ │ │ ├── CmdListXmlParser.hh │ │ │ ├── CmdXmlCommon.cc │ │ │ ├── CmdXmlCommon.hh │ │ │ ├── CmdXmlElement.cc │ │ │ ├── CmdXmlElement.hh │ │ │ ├── CmdXmlParser.cc │ │ │ ├── CmdXmlParser.hh │ │ │ ├── ConfigUpdateCmd.cc │ │ │ ├── ConfigUpdateCmd.hh │ │ │ ├── DirectoryIter.cc │ │ │ ├── DirectoryIter.hh │ │ │ ├── EventData.cc │ │ │ ├── EventData.hh │ │ │ ├── EventEntry.cc │ │ │ ├── EventEntry.hh │ │ │ ├── EventHubCmd.cc │ │ │ ├── EventHubCmd.hh │ │ │ ├── EventHubPublisher.cc │ │ │ ├── EventHubPublisher.hh │ │ │ ├── EventHubType.cc │ │ │ ├── EventHubType.hh │ │ │ ├── EventHubUploader.cc │ │ │ ├── EventHubUploader.hh │ │ │ ├── EventHubUploaderId.cc │ │ │ ├── EventHubUploaderId.hh │ │ │ ├── EventHubUploaderMgr.cc │ │ │ ├── EventHubUploaderMgr.hh │ │ │ ├── EventPersistMgr.cc │ │ │ ├── EventPersistMgr.hh │ │ │ ├── MdsBlobReader.cc │ │ 
│ ├── MdsBlobReader.hh │ │ │ ├── MdsCmdLogger.hh │ │ │ ├── MdsException.cc │ │ │ ├── MdsException.hh │ │ │ ├── PersistFiles.cc │ │ │ ├── PersistFiles.hh │ │ │ ├── PublisherStatus.cc │ │ │ ├── PublisherStatus.hh │ │ │ └── commands.xsd │ │ ├── mdsd/ │ │ │ ├── Batch.cc │ │ │ ├── Batch.hh │ │ │ ├── CMakeLists.txt │ │ │ ├── CanonicalEntity.cc │ │ │ ├── CanonicalEntity.hh │ │ │ ├── CfgContext.cc │ │ │ ├── CfgContext.hh │ │ │ ├── CfgCtxAccounts.cc │ │ │ ├── CfgCtxAccounts.hh │ │ │ ├── CfgCtxDerived.cc │ │ │ ├── CfgCtxDerived.hh │ │ │ ├── CfgCtxEnvelope.cc │ │ │ ├── CfgCtxEnvelope.hh │ │ │ ├── CfgCtxError.cc │ │ │ ├── CfgCtxError.hh │ │ │ ├── CfgCtxEtw.cc │ │ │ ├── CfgCtxEtw.hh │ │ │ ├── CfgCtxEventAnnotations.cc │ │ │ ├── CfgCtxEventAnnotations.hh │ │ │ ├── CfgCtxEvents.cc │ │ │ ├── CfgCtxEvents.hh │ │ │ ├── CfgCtxExtensions.cc │ │ │ ├── CfgCtxExtensions.hh │ │ │ ├── CfgCtxHeartBeats.cc │ │ │ ├── CfgCtxHeartBeats.hh │ │ │ ├── CfgCtxImports.cc │ │ │ ├── CfgCtxImports.hh │ │ │ ├── CfgCtxManagement.cc │ │ │ ├── CfgCtxManagement.hh │ │ │ ├── CfgCtxMdsdEvents.cc │ │ │ ├── CfgCtxMdsdEvents.hh │ │ │ ├── CfgCtxMonMgmt.cc │ │ │ ├── CfgCtxMonMgmt.hh │ │ │ ├── CfgCtxOMI.cc │ │ │ ├── CfgCtxOMI.hh │ │ │ ├── CfgCtxParser.cc │ │ │ ├── CfgCtxParser.hh │ │ │ ├── CfgCtxRoot.cc │ │ │ ├── CfgCtxRoot.hh │ │ │ ├── CfgCtxSchemas.cc │ │ │ ├── CfgCtxSchemas.hh │ │ │ ├── CfgCtxSources.cc │ │ │ ├── CfgCtxSources.hh │ │ │ ├── CfgCtxSvcBusAccts.cc │ │ │ ├── CfgCtxSvcBusAccts.hh │ │ │ ├── CfgEventAnnotationType.hh │ │ │ ├── CfgOboDirectConfig.hh │ │ │ ├── CmdLineConverter.cc │ │ │ ├── CmdLineConverter.hh │ │ │ ├── ConfigParser.cc │ │ │ ├── ConfigParser.hh │ │ │ ├── Constants.cc │ │ │ ├── Constants.hh │ │ │ ├── Credentials.cc │ │ │ ├── Credentials.hh │ │ │ ├── DaemonConf.cc │ │ │ ├── DaemonConf.hh │ │ │ ├── DerivedEvent.cc │ │ │ ├── DerivedEvent.hh │ │ │ ├── Engine.cc │ │ │ ├── Engine.hh │ │ │ ├── EtwEvent.cc │ │ │ ├── EtwEvent.hh │ │ │ ├── EventJSON.cc │ │ │ ├── EventJSON.hh │ │ │ ├── ExtensionMgmt.cc 
│ │ │ ├── ExtensionMgmt.hh │ │ │ ├── FileSink.cc │ │ │ ├── FileSink.hh │ │ │ ├── IMdsSink.cc │ │ │ ├── IMdsSink.hh │ │ │ ├── ITask.cc │ │ │ ├── ITask.hh │ │ │ ├── IdentityColumns.hh │ │ │ ├── LADQuery.cc │ │ │ ├── LADQuery.hh │ │ │ ├── LinuxMdsConfig.xsd │ │ │ ├── Listener.cc │ │ │ ├── Listener.hh │ │ │ ├── LocalSink.cc │ │ │ ├── LocalSink.hh │ │ │ ├── MdsBlobOutputter.hh │ │ │ ├── MdsEntityName.cc │ │ │ ├── MdsEntityName.hh │ │ │ ├── MdsSchemaMetadata.cc │ │ │ ├── MdsSchemaMetadata.hh │ │ │ ├── MdsValue.cc │ │ │ ├── MdsValue.hh │ │ │ ├── MdsdConfig.cc │ │ │ ├── MdsdConfig.hh │ │ │ ├── MdsdExtension.hh │ │ │ ├── MdsdMetrics.cc │ │ │ ├── MdsdMetrics.hh │ │ │ ├── Memcheck.cc │ │ │ ├── OMIQuery.cc │ │ │ ├── OMIQuery.hh │ │ │ ├── OmiTask.cc │ │ │ ├── OmiTask.hh │ │ │ ├── PipeStages.cc │ │ │ ├── PipeStages.hh │ │ │ ├── Pipeline.cc │ │ │ ├── Pipeline.hh │ │ │ ├── PoolMgmt.hh │ │ │ ├── Priority.cc │ │ │ ├── Priority.hh │ │ │ ├── ProtocolHandlerBase.cc │ │ │ ├── ProtocolHandlerBase.hh │ │ │ ├── ProtocolHandlerBond.cc │ │ │ ├── ProtocolHandlerBond.hh │ │ │ ├── ProtocolHandlerJSON.cc │ │ │ ├── ProtocolHandlerJSON.hh │ │ │ ├── ProtocolListener.cc │ │ │ ├── ProtocolListener.hh │ │ │ ├── ProtocolListenerBond.cc │ │ │ ├── ProtocolListenerBond.hh │ │ │ ├── ProtocolListenerDynamicJSON.cc │ │ │ ├── ProtocolListenerDynamicJSON.hh │ │ │ ├── ProtocolListenerJSON.cc │ │ │ ├── ProtocolListenerJSON.hh │ │ │ ├── ProtocolListenerMgr.cc │ │ │ ├── ProtocolListenerMgr.hh │ │ │ ├── ProtocolListenerTcpJSON.cc │ │ │ ├── ProtocolListenerTcpJSON.hh │ │ │ ├── RowIndex.cc │ │ │ ├── RowIndex.hh │ │ │ ├── SaxParserBase.cc │ │ │ ├── SaxParserBase.hh │ │ │ ├── SchemaCache.cc │ │ │ ├── SchemaCache.hh │ │ │ ├── Signals.c │ │ │ ├── StoreType.cc │ │ │ ├── StoreType.hh │ │ │ ├── StreamListener.cc │ │ │ ├── StreamListener.hh │ │ │ ├── Subscription.cc │ │ │ ├── Subscription.hh │ │ │ ├── TableColumn.cc │ │ │ ├── TableColumn.hh │ │ │ ├── TableSchema.cc │ │ │ ├── TableSchema.hh │ │ │ ├── TermHandler.cc │ │ │ ├── 
Version.cc │ │ │ ├── Version.hh │ │ │ ├── XJsonBlobBlockCountsMgr.cc │ │ │ ├── XJsonBlobBlockCountsMgr.hh │ │ │ ├── XJsonBlobRequest.cc │ │ │ ├── XJsonBlobRequest.hh │ │ │ ├── XJsonBlobSink.cc │ │ │ ├── XJsonBlobSink.hh │ │ │ ├── XTableConst.cc │ │ │ ├── XTableConst.hh │ │ │ ├── XTableHelper.cc │ │ │ ├── XTableHelper.hh │ │ │ ├── XTableRequest.cc │ │ │ ├── XTableRequest.hh │ │ │ ├── XTableSink.cc │ │ │ ├── XTableSink.hh │ │ │ ├── cJSON.c │ │ │ ├── cJSON.h │ │ │ ├── cryptutil.cc │ │ │ ├── cryptutil.hh │ │ │ ├── fdelt_chk.c │ │ │ ├── mdsautokey.h │ │ │ ├── mdsd.cc │ │ │ └── wrap_memcpy.c │ │ ├── mdsd.8 │ │ ├── mdsdcfg/ │ │ │ ├── CMakeLists.txt │ │ │ ├── EventPubCfg.cc │ │ │ ├── EventPubCfg.hh │ │ │ ├── EventSinkCfgInfo.hh │ │ │ ├── EventType.hh │ │ │ ├── MdsdEventCfg.cc │ │ │ └── MdsdEventCfg.hh │ │ ├── mdsdinput/ │ │ │ ├── CMakeLists.txt │ │ │ ├── MdsdInputMessageBuilder.cpp │ │ │ ├── MdsdInputMessageBuilder.h │ │ │ ├── MdsdInputMessageDecoder.h │ │ │ ├── MdsdInputMessageIO.cpp │ │ │ ├── MdsdInputMessageIO.h │ │ │ ├── MdsdInputSchemaCache.cpp │ │ │ ├── MdsdInputSchemaCache.h │ │ │ ├── mdsd_input.bond │ │ │ ├── mdsd_input_apply.cpp │ │ │ ├── mdsd_input_apply.h │ │ │ ├── mdsd_input_reflection.h │ │ │ ├── mdsd_input_types.cpp │ │ │ └── mdsd_input_types.h │ │ ├── mdsdlog/ │ │ │ ├── CMakeLists.txt │ │ │ ├── Logger.cc │ │ │ ├── Logger.hh │ │ │ ├── Trace.cc │ │ │ └── Trace.hh │ │ ├── mdsdutil/ │ │ │ ├── AzureUtility.cc │ │ │ ├── AzureUtility.hh │ │ │ ├── CMakeLists.txt │ │ │ ├── Crypto.cc │ │ │ ├── Crypto.hh │ │ │ ├── HttpProxySetup.cc │ │ │ ├── HttpProxySetup.hh │ │ │ ├── MdsTime.cc │ │ │ ├── MdsTime.hh │ │ │ ├── OpensslCert.cc │ │ │ ├── OpensslCert.hh │ │ │ ├── OpensslCertStore.cc │ │ │ ├── OpensslCertStore.hh │ │ │ ├── Utility.cc │ │ │ └── Utility.hh │ │ ├── mdsrest/ │ │ │ ├── CMakeLists.txt │ │ │ ├── GcsJsonData.cc │ │ │ ├── GcsJsonData.hh │ │ │ ├── GcsJsonParser.cc │ │ │ ├── GcsJsonParser.hh │ │ │ ├── GcsServiceInfo.cc │ │ │ ├── GcsServiceInfo.hh │ │ │ ├── GcsUtil.cc 
│ │ │ ├── GcsUtil.hh │ │ │ ├── MdsConst.hh │ │ │ ├── MdsRest.cc │ │ │ ├── MdsRest.hh │ │ │ └── MdsRestException.hh │ │ └── parseglibc.py │ ├── mocks/ │ │ ├── Readme.txt │ │ ├── __init__.py │ │ ├── crypt.py │ │ ├── fcntl.py │ │ └── pwd.py │ ├── run_unittests.sh │ ├── services/ │ │ ├── mdsd-lde.service │ │ ├── metrics-extension.service │ │ └── metrics-sourcer.service │ ├── shim.sh │ ├── tests/ │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── lad_2_3_compatible_portal_pub_settings.json │ │ ├── test_LadDiagnosticUtil.py │ │ ├── test_builtin.py │ │ ├── test_commonActions.py │ │ ├── test_lad_config_all.py │ │ ├── test_lad_ext_settings.py │ │ ├── test_lad_logging_config.py │ │ ├── var_lib_waagent/ │ │ │ └── lad_dir/ │ │ │ └── config/ │ │ │ ├── lad_settings_logging.json │ │ │ └── lad_settings_metric.json │ │ └── watchertests.py │ ├── virtual-machines-linux-diagnostic-extension-v3.md │ └── watcherutil.py ├── LAD-AMA-Common/ │ ├── metrics_ext_utils/ │ │ ├── __init__.py │ │ ├── metrics_common_utils.py │ │ ├── metrics_constants.py │ │ └── metrics_ext_handler.py │ └── telegraf_utils/ │ ├── __init__.py │ ├── telegraf_config_handler.py │ └── telegraf_name_map.py ├── LICENSE.txt ├── Makefile ├── OSPatching/ │ ├── HandlerManifest.json │ ├── README.md │ ├── azure/ │ │ ├── __init__.py │ │ ├── azure.pyproj │ │ ├── http/ │ │ │ ├── __init__.py │ │ │ ├── batchclient.py │ │ │ ├── httpclient.py │ │ │ └── winhttp.py │ │ ├── servicebus/ │ │ │ ├── __init__.py │ │ │ └── servicebusservice.py │ │ ├── servicemanagement/ │ │ │ ├── __init__.py │ │ │ ├── servicebusmanagementservice.py │ │ │ ├── servicemanagementclient.py │ │ │ ├── servicemanagementservice.py │ │ │ ├── sqldatabasemanagementservice.py │ │ │ └── websitemanagementservice.py │ │ └── storage/ │ │ ├── __init__.py │ │ ├── blobservice.py │ │ ├── cloudstorageaccount.py │ │ ├── queueservice.py │ │ ├── sharedaccesssignature.py │ │ ├── storageclient.py │ │ └── tableservice.py │ ├── check.py │ ├── handler.py │ ├── manifest.xml │ ├── oneoff/ │ │ └── 
__init__.py │ ├── patch/ │ │ ├── AbstractPatching.py │ │ ├── ConfigOptions.py │ │ ├── OraclePatching.py │ │ ├── SuSEPatching.py │ │ ├── UbuntuPatching.py │ │ ├── __init__.py │ │ ├── centosPatching.py │ │ └── redhatPatching.py │ ├── references │ ├── scheduled/ │ │ ├── __init__.py │ │ └── history │ └── test/ │ ├── FakePatching.py │ ├── FakePatching2.py │ ├── FakePatching3.py │ ├── HandlerEnvironment.json │ ├── README.txt │ ├── check.py │ ├── config/ │ │ └── 0.settings │ ├── default.settings │ ├── handler.py │ ├── oneoff/ │ │ └── __init__.py │ ├── prepare_settings.py │ ├── scheduled/ │ │ ├── __init__.py │ │ └── history │ ├── test.crt │ ├── test.prv │ ├── test_handler_1.py │ ├── test_handler_2.py │ └── test_handler_3.py ├── OmsAgent/ │ ├── .gitignore │ ├── HandlerManifest.json │ ├── ImportGPGkey.sh │ ├── README.md │ ├── apply_version.sh │ ├── extension-test/ │ │ ├── README.md │ │ ├── oms_extension_tests.py │ │ ├── omsfiles/ │ │ │ ├── apache_access.log │ │ │ ├── custom.log │ │ │ ├── customlog.conf │ │ │ ├── error.log │ │ │ ├── mysql-slow.log │ │ │ ├── mysql.log │ │ │ ├── oms_extension_run_script.py │ │ │ ├── perf.conf │ │ │ └── rsyslog-oms.conf │ │ ├── parameters.json │ │ └── verify_e2e.py │ ├── keys/ │ │ ├── dscgpgkey.asc │ │ └── msgpgkey.asc │ ├── manifest.xml │ ├── omsagent.py │ ├── omsagent.version │ ├── omsagent_shim.sh │ ├── packaging.sh │ ├── references │ ├── test/ │ │ ├── MockUtil.py │ │ ├── env.py │ │ └── test_install.py │ ├── update_version.sh │ └── watcherutil.py ├── RDMAUpdate/ │ ├── MANIFEST.in │ ├── RDMAUpdate.pyproj │ ├── README.txt │ ├── enableit.js │ ├── main/ │ │ ├── CommandExecuter.py │ │ ├── Common.py │ │ ├── CronUtil.py │ │ ├── RDMALogger.py │ │ ├── RdmaException.py │ │ ├── SecondStageMarkConfig.py │ │ ├── Utils/ │ │ │ ├── HandlerUtil.py │ │ │ ├── WAAgentUtil.py │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── handle.py │ │ └── patch/ │ │ ├── AbstractPatching.py │ │ ├── OraclePatching.py │ │ ├── SuSEPatching.py │ │ ├── UbuntuPatching.py │ │ ├── 
__init__.py │ │ ├── centosPatching.py │ │ └── redhatPatching.py │ ├── references │ ├── setup.py │ ├── test/ │ │ └── update_rdma_driver.py │ └── test.ps1 ├── README.md ├── SECURITY.md ├── SampleExtension/ │ ├── HandlerManifest.json │ ├── disable.py │ ├── enable.py │ ├── install.py │ ├── references │ ├── uninstall.py │ └── update.py ├── TestHandlerLinux/ │ ├── HandlerManifest.json │ ├── bin/ │ │ ├── #heartbeat.py# │ │ ├── disable.py │ │ ├── enable.py │ │ ├── heartbeat.py │ │ ├── service.py │ │ └── update.py │ ├── installer/ │ │ ├── install.py │ │ └── uninstall.py │ ├── manifest.xml │ ├── references │ └── resources/ │ ├── HandlerUtil.py │ └── mypydoc.py ├── Utils/ │ ├── HandlerUtil.py │ ├── LogUtil.py │ ├── ScriptUtil.py │ ├── WAAgentUtil.py │ ├── __init__.py │ ├── constants.py │ ├── crypt_fallback.py │ ├── distroutils.py │ ├── extensionutils.py │ ├── handlerutil2.py │ ├── logger.py │ ├── ovfutils.py │ └── test/ │ ├── MockUtil.py │ ├── env.py │ ├── mock.sh │ ├── mock_sshd_config │ ├── non_latin_characters.txt │ ├── ovf-env-empty.xml │ ├── ovf-env.xml │ ├── place_vmaccess_on_local_machine.sh │ ├── test_encode.py │ ├── test_extensionutils_code_injection.py │ ├── test_logutil.py │ ├── test_null_protected_settings.py │ ├── test_ovf_utils.py │ ├── test_redacted_settings.py │ └── test_scriptutil.py ├── VMAccess/ │ ├── CHANGELOG.md │ ├── HandlerManifest.json │ ├── README.md │ ├── extension_noop.sh │ ├── extension_shim.sh │ ├── manifest.xml │ ├── references │ ├── resources/ │ │ ├── SuSE_default │ │ ├── Ubuntu_default │ │ ├── centos_default │ │ ├── debian_default │ │ ├── default │ │ ├── fedora_default │ │ └── redhat_default │ ├── test/ │ │ ├── env.py │ │ ├── test_iptable_rules.py │ │ ├── test_reset_account.py │ │ └── test_reset_sshd_config.py │ └── vmaccess.py ├── VMBackup/ │ ├── .gitignore │ ├── HandlerManifest.json │ ├── MANIFEST.in │ ├── README.txt │ ├── VMBackup.pyproj │ ├── debughelper/ │ │ ├── README.md │ │ ├── checkMounts.go │ │ ├── go.mod │ │ ├── go.sum │ │ ├── main.go 
│ │ └── run.go │ ├── main/ │ │ ├── ExtensionErrorCodeHelper.py │ │ ├── HttpUtil.py │ │ ├── IaaSExtensionSnapshotService/ │ │ │ ├── README.md │ │ │ ├── SnapshotServiceConstants.py │ │ │ ├── SnapshotServiceContracts.py │ │ │ ├── __init__.py │ │ │ └── service_metadata.json │ │ ├── LogSeverity.json │ │ ├── MachineIdentity.py │ │ ├── PluginHost.py │ │ ├── ScriptRunner.py │ │ ├── Utils/ │ │ │ ├── DiskUtil.py │ │ │ ├── Event.py │ │ │ ├── EventLoggerUtil.py │ │ │ ├── HandlerUtil.py │ │ │ ├── HostSnapshotObjects.py │ │ │ ├── LogHelper.py │ │ │ ├── ResourceDiskUtil.py │ │ │ ├── SizeCalculation.py │ │ │ ├── Status.py │ │ │ ├── StringHelper.py │ │ │ ├── WAAgentUtil.py │ │ │ ├── __init__.py │ │ │ └── dhcpUtils.py │ │ ├── VMSnapshotPluginHost.conf │ │ ├── WaagentLib.py │ │ ├── __init__.py │ │ ├── backuplogger.py │ │ ├── blobwriter.py │ │ ├── common.py │ │ ├── dhcpHandler.py │ │ ├── freezesnapshotter.py │ │ ├── fsfreezer.py │ │ ├── guestsnapshotter.py │ │ ├── handle.py │ │ ├── handle.sh │ │ ├── handle_host_daemon.py │ │ ├── handle_host_daemon.sh │ │ ├── hostsnapshotter.py │ │ ├── mounts.py │ │ ├── parameterparser.py │ │ ├── patch/ │ │ │ ├── AbstractPatching.py │ │ │ ├── DefaultPatching.py │ │ │ ├── FreeBSDPatching.py │ │ │ ├── KaliPatching.py │ │ │ ├── NSBSDPatching.py │ │ │ ├── SuSEPatching.py │ │ │ ├── UbuntuPatching.py │ │ │ ├── __init__.py │ │ │ ├── centosPatching.py │ │ │ ├── debianPatching.py │ │ │ ├── oraclePatching.py │ │ │ └── redhatPatching.py │ │ ├── safefreeze/ │ │ │ ├── Makefile │ │ │ └── src/ │ │ │ └── safefreeze.c │ │ ├── safefreezeArm64/ │ │ │ ├── Makefile │ │ │ ├── bin/ │ │ │ │ └── safefreeze │ │ │ └── src/ │ │ │ └── safefreeze.c │ │ ├── taskidentity.py │ │ ├── tempPlugin/ │ │ │ ├── VMSnapshotScriptPluginConfig.json │ │ │ ├── postScript.sh │ │ │ ├── preScript.sh │ │ │ └── vmbackup.conf │ │ └── workloadPatch/ │ │ ├── CustomScripts/ │ │ │ └── customscript.sql │ │ ├── DefaultScripts/ │ │ │ ├── logbackup.sql │ │ │ ├── postMysqlMaster.sql │ │ │ ├── postMysqlSlave.sql 
│ │ │ ├── postOracleMaster.sql │ │ │ ├── postPostgresMaster.sql │ │ │ ├── preMysqlMaster.sql │ │ │ ├── preMysqlSlave.sql │ │ │ ├── preOracleMaster.sql │ │ │ ├── prePostgresMaster.sql │ │ │ └── timeoutDaemon.sh │ │ ├── LogBackupPatch.py │ │ ├── WorkloadPatch.py │ │ ├── WorkloadUtils/ │ │ │ ├── OracleLogBackup.py │ │ │ ├── OracleLogRestore.py │ │ │ └── workload.conf │ │ └── __init__.py │ ├── manifest.xml │ ├── mkstub.py │ ├── references │ ├── setup.py │ └── test/ │ ├── handle.py │ └── install_python2.6.sh ├── VMEncryption/ │ ├── .vscode/ │ │ └── settings.json │ ├── MANIFEST.in │ ├── ReleaseNotes.txt │ ├── Test-AzureRmVMDiskEncryptionExtension.ps1 │ ├── Test-AzureRmVMDiskEncryptionExtensionDiskFormat.ps1 │ ├── UpgradeLog.htm │ ├── VMEncryption.pyproj │ ├── extension_shim.sh │ ├── lint_output.txt │ ├── main/ │ │ ├── BackupLogger.py │ │ ├── BekUtil.py │ │ ├── CommandExecutor.py │ │ ├── Common.py │ │ ├── ConfigUtil.py │ │ ├── DecryptionMarkConfig.py │ │ ├── DiskUtil.py │ │ ├── EncryptionConfig.py │ │ ├── EncryptionEnvironment.py │ │ ├── EncryptionMarkConfig.py │ │ ├── ExtensionParameter.py │ │ ├── HttpUtil.py │ │ ├── KeyVaultUtil.py │ │ ├── MachineIdentity.py │ │ ├── OnGoingItemConfig.py │ │ ├── ProcessLock.py │ │ ├── ResourceDiskUtil.py │ │ ├── SupportedOS.json │ │ ├── TokenUtil.py │ │ ├── TransactionalCopyTask.py │ │ ├── Utils/ │ │ │ ├── HandlerUtil.py │ │ │ ├── WAAgentUtil.py │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── check_util.py │ │ ├── handle.py │ │ ├── oscrypto/ │ │ │ ├── 91ade/ │ │ │ │ ├── 50-udev-ade.rules │ │ │ │ ├── cryptroot-ask-ade.sh │ │ │ │ ├── module-setup.sh │ │ │ │ └── parse-crypt-ade.sh │ │ │ ├── OSEncryptionState.py │ │ │ ├── OSEncryptionStateMachine.py │ │ │ ├── __init__.py │ │ │ ├── centos_68/ │ │ │ │ ├── CentOS68EncryptionStateMachine.py │ │ │ │ ├── __init__.py │ │ │ │ ├── encryptpatches/ │ │ │ │ │ └── centos_68_dracut.patch │ │ │ │ └── encryptstates/ │ │ │ │ ├── EncryptBlockDeviceState.py │ │ │ │ ├── PatchBootSystemState.py │ │ │ │ ├── 
PrereqState.py │ │ │ │ ├── SelinuxState.py │ │ │ │ ├── SplitRootPartitionState.py │ │ │ │ ├── StripdownState.py │ │ │ │ ├── UnmountOldrootState.py │ │ │ │ └── __init__.py │ │ │ ├── rhel_68/ │ │ │ │ ├── RHEL68EncryptionStateMachine.py │ │ │ │ ├── __init__.py │ │ │ │ ├── encryptpatches/ │ │ │ │ │ └── rhel_68_dracut.patch │ │ │ │ └── encryptstates/ │ │ │ │ ├── EncryptBlockDeviceState.py │ │ │ │ ├── PatchBootSystemState.py │ │ │ │ ├── PrereqState.py │ │ │ │ ├── SelinuxState.py │ │ │ │ ├── StripdownState.py │ │ │ │ ├── UnmountOldrootState.py │ │ │ │ └── __init__.py │ │ │ ├── rhel_72/ │ │ │ │ ├── RHEL72EncryptionStateMachine.py │ │ │ │ ├── __init__.py │ │ │ │ └── encryptstates/ │ │ │ │ ├── EncryptBlockDeviceState.py │ │ │ │ ├── PatchBootSystemState.py │ │ │ │ ├── PrereqState.py │ │ │ │ ├── SelinuxState.py │ │ │ │ ├── StripdownState.py │ │ │ │ ├── UnmountOldrootState.py │ │ │ │ └── __init__.py │ │ │ ├── rhel_72_lvm/ │ │ │ │ ├── RHEL72LVMEncryptionStateMachine.py │ │ │ │ ├── __init__.py │ │ │ │ └── encryptstates/ │ │ │ │ ├── EncryptBlockDeviceState.py │ │ │ │ ├── PatchBootSystemState.py │ │ │ │ ├── PrereqState.py │ │ │ │ ├── SelinuxState.py │ │ │ │ ├── StripdownState.py │ │ │ │ ├── UnmountOldrootState.py │ │ │ │ └── __init__.py │ │ │ ├── ubuntu_1404/ │ │ │ │ ├── Ubuntu1404EncryptionStateMachine.py │ │ │ │ ├── __init__.py │ │ │ │ ├── encryptpatches/ │ │ │ │ │ └── ubuntu_1404_initramfs.patch │ │ │ │ ├── encryptscripts/ │ │ │ │ │ ├── azure_crypt_key.sh │ │ │ │ │ └── inject_luks_header.sh │ │ │ │ └── encryptstates/ │ │ │ │ ├── EncryptBlockDeviceState.py │ │ │ │ ├── PatchBootSystemState.py │ │ │ │ ├── PrereqState.py │ │ │ │ ├── SelinuxState.py │ │ │ │ ├── SplitRootPartitionState.py │ │ │ │ ├── StripdownState.py │ │ │ │ ├── UnmountOldrootState.py │ │ │ │ └── __init__.py │ │ │ └── ubuntu_1604/ │ │ │ ├── Ubuntu1604EncryptionStateMachine.py │ │ │ ├── __init__.py │ │ │ ├── encryptpatches/ │ │ │ │ └── ubuntu_1604_initramfs.patch │ │ │ ├── encryptscripts/ │ │ │ │ ├── 
azure_crypt_key.sh │ │ │ │ └── inject_luks_header.sh │ │ │ └── encryptstates/ │ │ │ ├── EncryptBlockDeviceState.py │ │ │ ├── PatchBootSystemState.py │ │ │ ├── PrereqState.py │ │ │ ├── SelinuxState.py │ │ │ ├── SplitRootPartitionState.py │ │ │ ├── StripdownState.py │ │ │ ├── UnmountOldrootState.py │ │ │ └── __init__.py │ │ └── patch/ │ │ ├── AbstractPatching.py │ │ ├── SuSEPatching.py │ │ ├── UbuntuPatching.py │ │ ├── __init__.py │ │ ├── centosPatching.py │ │ ├── debianPatching.py │ │ ├── oraclePatching.py │ │ └── redhatPatching.py │ ├── references │ ├── requirements.txt │ ├── setup.py │ └── test/ │ ├── __init__.py │ ├── console_logger.py │ ├── test_check_util.py │ ├── test_command_executor.py │ ├── test_disk_util.py │ ├── test_handler_util.py │ ├── test_resource_disk_util.py │ └── test_utils.py ├── docs/ │ ├── advanced-topics.md │ ├── contribution-guide.md │ ├── design-details.md │ ├── document.md │ ├── handler-registration.md │ ├── overview.md │ ├── sample-extension.md │ ├── test.md │ └── utils.md ├── go.mod ├── go.sum ├── registration-scripts/ │ ├── api/ │ │ ├── add-extension.sh │ │ ├── check-request-status.sh │ │ ├── del-extension.sh │ │ ├── get-extension.sh │ │ ├── get-subscription.sh │ │ ├── list-extension.sh │ │ ├── params │ │ └── update-extension.sh │ ├── bin/ │ │ ├── add.sh │ │ ├── blob/ │ │ │ ├── list.sh │ │ │ └── upload.sh │ │ ├── check.sh │ │ ├── del.sh │ │ ├── get.sh │ │ ├── list.sh │ │ ├── subscription.sh │ │ └── update.sh │ ├── create_zip.sh │ ├── mooncake/ │ │ └── sample-extension-1.0.xml │ └── public/ │ └── sample-extension-1.0.xml ├── script/ │ ├── 0.settings │ ├── HandlerEnvironment.json │ ├── create_zip.sh │ ├── mkstub.sh │ ├── ovf-env.xml │ ├── set_env.sh │ ├── test.crt │ └── test.prv └── ui-extension-packages/ ├── microsoft.custom-script-linux/ │ ├── Artifacts/ │ │ ├── CreateUiDefinition.json │ │ └── MainTemplate.json │ ├── Manifest.json │ ├── Strings/ │ │ └── resources.resjson │ └── UiDefinition.json └── microsoft.custom-script-linux-arm/ ├── 
Artifacts/ │ ├── CreateUiDefinition.json │ └── MainTemplate.json ├── Manifest.json ├── Strings/ │ └── resources.resjson └── UiDefinition.json ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitattributes ================================================ ############################################################################### # Set default behavior to automatically normalize line endings. ############################################################################### * text=auto ############################################################################### # Set default behavior for command prompt diff. # # This is needed for earlier builds of msysgit that do not have it on by # default for csharp files. # Note: This is only used by command line ############################################################################### #*.cs diff=csharp ############################################################################### # Set the merge driver for project and solution files # # Merging from the command prompt will add diff markers to the files if there # are conflicts (Merging from VS is not affected by the settings below, in VS # the diff markers are never inserted). Diff markers may cause the following # file extensions to fail to load in VS. An alternative would be to treat # these files as binary and thus will always conflict and require user # intervention with every merge. 
To do so, just uncomment the entries below ############################################################################### #*.sln merge=binary #*.csproj merge=binary #*.vbproj merge=binary #*.vcxproj merge=binary #*.vcproj merge=binary #*.dbproj merge=binary #*.fsproj merge=binary #*.lsproj merge=binary #*.wixproj merge=binary #*.modelproj merge=binary #*.sqlproj merge=binary #*.wwaproj merge=binary ############################################################################### # behavior for image files # # image files are treated as binary by default. ############################################################################### #*.jpg binary #*.png binary #*.gif binary ############################################################################### # diff behavior for common document formats # # Convert binary document formats to text before diffing them. This feature # is only available from the command line. Turn it on by uncommenting the # entries below. ############################################################################### #*.doc diff=astextplain #*.DOC diff=astextplain #*.docx diff=astextplain #*.DOCX diff=astextplain #*.dot diff=astextplain #*.DOT diff=astextplain #*.pdf diff=astextplain #*.PDF diff=astextplain #*.rtf diff=astextplain #*.RTF diff=astextplain ================================================ FILE: .gitignore ================================================ DSC/DSC.zip compiled / optimized / DLL files __pycache__/ *.py[cod] */.py[cod] # C extensions *.so # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ eggs/ lib/ lib64/ parts/ sdist/ var/ *.egg-info/ .installed.cfg *.egg # Editor *~ # PyCharm .idea/ .idea_modules/ # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .cache nosetests.xml coverage.xml # Translations *.mo *.pot # Django stuff: *.log !OmsAgent/extension-test/omsfiles/*.log oms*.zip # Sphinx documentation docs/_build/ # PyBuilder target/* # mac osx specific files .DS_Store ### VirtualEnv template # Virtualenv # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ .Python pyvenv.cfg .venv pip-selfcheck.json # virtualenv venv/ ENV/ # dotenv .env # pyenv .python-version # VMBackup package ignors VMBackup/dist VMBackup/dist/* VMBackup/azure-sdk VMBackup/azure-sdk/* VMBackup/main/azure/* VMBackup/MANIFEST #VMBackup/*.pyproj VMBackup/*.pyproj.user VMBackup/*.suo VMBackup/main/safefreeze/bin/* # CustomScript ignors CustomScript/test/download/0/stdout CustomScript/test/download/0/errout *node_modules/ # VMEncryption ignores VMEncryption/main/azure/* # Common Common/psutil/build/* Common/psutil/dist/* Common/psutil/psutil.egg-info/* VMBackup/.vs/VMBackup/v14/.suo RDMAUpdate/RDMAUpdate.pyproj.user *.sln RDMAUpdate/VMBackup.pyproj.user VMBackup/.vs/config/applicationhost.config RDMAUpdate/.vs/VMBackup/v14/.suo # Handler Registration ignores *.pem RDMAUpdate/.vs/RDMAUpdate/v14/.suo # Visual Studio directory .vs/ # Ignore HandlerManifest updates VMEncryption/HandlerManifest.json VMEncryption/AzureDiskEncryptionForLinux*.xml VMEncryption/ADEForLinux*.xml VMEncryption/MANIFEST ================================================ FILE: .gitmodules ================================================ [submodule "Common/azure-sdk-for-python"] path = Common/azure-sdk-for-python url = https://github.com/Azure/azure-sdk-for-python.git [submodule "Common/psutil"] path = Common/psutil url = https://github.com/giampaolo/psutil.git [submodule "VMEncryption/transitions"] path = VMEncryption/transitions url = https://github.com/tyarkoni/transitions.git ================================================ FILE: 
.vscode/launch.json
================================================
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: test_encode.py",
            "type": "python",
            "request": "launch",
            "program": "test_encode.py",
            "console": "integratedTerminal",
            "justMyCode": true,
            "cwd": "${workspaceFolder}/Utils/test",
            "env" : {
                "PYTHONPATH": "${workspaceFolder}"
            }
        }
    ]
}

================================================
FILE: AzureEnhancedMonitor/README.md
================================================
# How to enable Azure Enhanced Monitoring on Linux VM

These are instructions describing how to enable Azure Enhanced Monitoring (AEM) on an Azure Linux VM.

## Install Azure CLI

First of all, you need to install [Azure CLI][azure-cli]

**NOTE** This feature is currently in development. You need to install it from GitHub by running the following command.

```
npm -g install git+https://github.com/yuezh/azure-xplat-cli.git#dev
```

## Configure Azure Enhanced Monitoring

1. Login with your Azure account

   ```
   azure login
   ```

2. Switch to Azure resource management (ARM) mode

   ```
   azure config mode arm
   ```

3. Enable Azure Enhanced Monitoring

   ```
   azure vm enable-aem
   ```

4. Verify that Azure Enhanced Monitoring is active on the Azure Linux VM. Check if the file /var/lib/AzureEnhancedMonitor/PerfCounters exists.
If exists, display information collected by AEM with: ``` cat /var/lib/AzureEnhancedMonitor/PerfCounters ``` Then you will get output like: ``` 2;cpu;Current Hw Frequency;;0;2194.659;MHz;60;1444036656;saplnxmon; 2;cpu;Max Hw Frequency;;0;2194.659;MHz;0;1444036656;saplnxmon; … … ``` [azure-cli]: https://azure.microsoft.com/en-us/documentation/articles/xplat-cli/ ================================================ FILE: AzureEnhancedMonitor/bin/pack.sh ================================================ #!/bin/bash proj_name="aem" proj_version="1.0" proj_full_name="$proj_name-$proj_version" script=$(dirname $0) root=$script/.. cd $root root=`pwd` build_dir=$root/build target_dir=$build_dir/$proj_full_name mkdir -p $build_dir mkdir -p $target_dir cd $root/clib make clean cp -r $root/nodejs $build_dir cd $build_dir/nodejs npm pack cp -r $root/clib $target_dir cp $build_dir/nodejs/*.tgz $target_dir cp $root/bin/setup.sh $target_dir chmod +x $root/bin/setup.sh #install.sh is a self-extracting script. #The begin of this file is sh script while the end is a tar echo "#!/bin/bash" > $build_dir/install.sh echo "#" >> $build_dir/install.sh echo "#Auto-generated. Do NOT edit this file." 
>> $build_dir/install.sh echo "#" >> $build_dir/install.sh echo "root=\$(dirname \$0)" >> $build_dir/install.sh echo "cd \$root" >> $build_dir/install.sh echo "root=\`pwd\`" >> $build_dir/install.sh echo "if [ -d $proj_full_name ]; then" >> $build_dir/install.sh echo " echo \"[INFO]Remove old package...\"" >> $build_dir/install.sh echo " rm $proj_full_name -rf" >> $build_dir/install.sh echo "fi" >> $build_dir/install.sh echo "echo \"[INFO]Unpacking...\"" >> $build_dir/install.sh echo "sed -e '1,/^exit$/d' "\$0" | tar xzf -" >> $build_dir/install.sh echo "$proj_full_name/setup.sh" >> $build_dir/install.sh echo "exit" >> $build_dir/install.sh cd $build_dir tar czf - $proj_full_name >> $build_dir/install.sh chmod +x $build_dir/install.sh cp -r $root/clib $build_dir cd $build_dir tar czf clib.tar.gz clib/ ================================================ FILE: AzureEnhancedMonitor/bin/setup.sh ================================================ #!/bin/bash install_log=`pwd`/install.log root=$(dirname $0) cd $root root=`pwd` if [[ $EUID -ne 0 ]]; then echo "[ERROR]This script must be run as root" 1>&2 exit 1 fi function install_nodejs_tarball() { version="v0.10.37" node_version="node-$version-linux-x64" src="$root/$node_version" target="/usr/local" echo "[INFO]Installing nodejs from http://nodejs.org/dist/$version/${node_version}.tar.gz" if [ -f ${src}.tar.gz ]; then rm ${src}.tar.gz -f fi if [ -d ${src} ]; then rm ${src} -rf fi wget http://nodejs.org/dist/$version/${node_version}.tar.gz 1>>$install_log 2>&1 tar -zxf ${node_version}.tar.gz 1>>$install_log 2>&1 echo "[INFO]Install nodejs to $target" if [ -f $target/bin/node ]; then rm $target/bin/node -f fi cp $src/bin/node $target/bin/node echo "[INFO]Create link to $target/bin/node" if [ -f /usr/bin/node ]; then rm /usr/bin/node -f fi ln -s $target/bin/node /usr/bin/node echo "[INFO]Install npm" curl -sL https://www.npmjs.org/install.sh | sh 1>>$install_log 2>&1 } function install_nodejs() { echo "[INFO]Installing nodejs 
and npm" if [ "$(type apt-get 2>/dev/null)" != "" ] ; then curl -sL https://deb.nodesource.com/setup | bash - 1>>$install_log 2>&1 apt-get -y install nodejs 1>>$install_log 2>&1 elif [ "$(type yum 2>/dev/null)" != "" ] ; then curl -sL https://rpm.nodesource.com/setup | bash - 1>>$install_log 2>&1 yum -y install nodejs 1>>$install_log 2>&1 else install_nodejs_tarball fi if [ ! $? ]; then echo "[ERROR]Install nodejs and npm failed. See $install_log." exit 1 fi } echo "[INFO]Checking dependency..." echo "" > $install_log if [ "$(type node 2>/dev/null)" == "" ]; then install_nodejs fi echo "[INFO] nodejs version: $(node --version)" if [ "$(type npm 2>/dev/null)" == "" ]; then install_nodejs fi echo "[INFO] npm version: $(npm -version)" if [ "$(type azure 2> /dev/null)" == "" ]; then echo "[INFO]Installing azure-cli" npm install -g azure-cli 1>>$install_log 2>&1 if [ ! $? ]; then echo "[ERROR]Install azure-cli failed. See $install_log." exit 1 fi fi echo "[INFO] azure-cli version: $(azure --version)" npm_pkg="azure-linux-tools-1.0.0.tgz" echo "[INFO]Installing Azure Enhanced Monitor tools..." if [ -f ./$npm_pkg ]; then npm install -g ./$npm_pkg 1>>$install_log 2>&1 if [ ! $? ]; then echo "[ERROR]Install Azure Enhanced Monitor tools failed. See $install_log." exit 1 fi else echo "[ERROR] Couldn't find npm package $npm_pkg" exit 1 fi echo "[INFO]Finished." 
================================================ FILE: AzureEnhancedMonitor/clib/.gitignore ================================================ bin/* ================================================ FILE: AzureEnhancedMonitor/clib/Makefile ================================================ CC := gcc SRCDIR := src LIBDIR := lib INCDIR := include BUILDDIR := build TARGET := $(LIBDIR)/libazureperf.so SRCEXT := c SOURCES := $(shell find $(SRCDIR) -type f -name *.$(SRCEXT)) OBJECTS := $(patsubst $(SRCDIR)/%,$(BUILDDIR)/%,$(SOURCES:.$(SRCEXT)=.o)) CFLAGS := -g -fPIC LDFLAGS := -shared INC := -I $(INCDIR) LIB := -L $(LIBDIR) all : $(TARGET) $(TARGET): $(OBJECTS) @echo "Linking..." $(CC) $^ $(LDFLAGS) -o $(TARGET) $(LIB) $(BUILDDIR)/%.o: $(SRCDIR)/%.$(SRCEXT) @mkdir -p $(BUILDDIR) @echo "Compiling..." $(CC) $(CFLAGS) $(INC) -c -o $@ $< clean: @echo "Cleaning..." $(RM) -r $(BUILDDIR) $(TARGET) test: $(OBJECTS) @echo "Run test" $(CC) test/runtest.c $^ $(INC) -L $(LIBDIR) -lazureperf -o bin/runtest bin/runtest install: mkdir -p /usr/lib/azureperf cp $(TARGET) /usr/lib/azureperf echo "/usr/lib/azureperf" > /etc/ld.so.conf.d/azureperf.conf ldconfig cp $(INCDIR)/azureperf.h /usr/include .PHONY: clean test ================================================ FILE: AzureEnhancedMonitor/clib/include/azureperf.h ================================================ // // Copyright 2014 Microsoft Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// #ifndef AZURE_PERF #define AZURE_PERF /*All the strings are utf-8 encoded*/ /*The max buf size for all string*/ #define STR_BUF_MAX (256) #define TYPE_NAME_MAX (64) #define PROPERTY_NAME_MAX (128) #define INSTANCE_NAME_MAX (256) #define STRING_VALUE_MAX (256) #define UNIT_NAME_MAX (64) #define MACHINE_NAME_MAX (128) #define PERF_COUNT_MAX (128) #define PERF_COUNTER_TYPE_INVALID (0) #define PERF_COUNTER_TYPE_INT (1) #define PERF_COUNTER_TYPE_DOUBLE (2) #define PERF_COUNTER_TYPE_LARGE (3) #define PERF_COUNTER_TYPE_STRING (4) #define AP_ERR_PC_NOT_FOUND (-1) #define AP_ERR_PC_BUF_OVERFLOW (-2) #define AP_ERR_INVALID_COUNTER_TYPE (-11) #define AP_ERR_INVALID_TYPE_NAME (-12) #define AP_ERR_INVALID_PROPERTY_NAME (-13) #define AP_ERR_INVALID_INSTANCE_NAME (-14) #define AP_ERR_INVALID_IS_EMPTY_FLAG (-15) #define AP_ERR_INVALID_VALUE (-15) #define AP_ERR_INVALID_UNIT_NAME (-16) #define AP_ERR_INVALID_REFRESH_INTERVAL (-17) #define AP_ERR_INVALID_TIMESTAMP (-18) #define AP_ERR_INVALID_MACHINE_NAME (-19) typedef struct { int counter_typer; char type_name[TYPE_NAME_MAX]; char property_name[PROPERTY_NAME_MAX]; char instance_name[STRING_VALUE_MAX]; int is_empty; union { int val_int; long long val_large; double val_double; char val_str[STRING_VALUE_MAX]; }; char unit_name[UNIT_NAME_MAX]; unsigned int refresh_interval; long long timestamp; char machine_name[MACHINE_NAME_MAX]; } perf_counter; typedef struct { perf_counter buf[PERF_COUNT_MAX]; int len; int err; char *ap_file; } ap_handler; ap_handler* ap_open(); extern void ap_close(ap_handler* handler); extern void ap_refresh(ap_handler* handler); extern int ap_metric_all(ap_handler *handler, perf_counter *pc, size_t size); //config\Cloud Provider extern int ap_metric_config_cloud_provider(ap_handler *handler, perf_counter *pc, size_t size); //config\CPU Over-Provisioning extern int ap_metric_config_cpu_over_provisioning(ap_handler *handler, perf_counter *pc, size_t size); //config\Memory Over-Provisioning extern int 
ap_metric_config_memory_over_provisioning(ap_handler *handler, perf_counter *pc, size_t size); //config\Data Provider Version extern int ap_metric_config_data_provider_version(ap_handler *handler, perf_counter *pc, size_t size); //config\Data Sources extern int ap_metric_config_data_sources(ap_handler *handler, perf_counter *pc, size_t size); //config\Instance Type extern int ap_metric_config_instance_type(ap_handler *handler, perf_counter *pc, size_t size); //config\Virtualization Solution extern int ap_metric_config_virtualization_solution(ap_handler *handler, perf_counter *pc, size_t size); //config\Virtualization Solution Version extern int ap_metric_config_virtualization_solution_version(ap_handler *handler, perf_counter *pc, size_t size); //cpu\Current Hw Frequency extern int ap_metric_cpu_current_hw_frequency(ap_handler *handler, perf_counter *pc, size_t size); //cpu\Max Hw Frequency extern int ap_metric_cpu_max_hw_frequency(ap_handler *handler, perf_counter *pc, size_t size); //cpu\Current VM Processing Power extern int ap_metric_cpu_current_vm_processing_power(ap_handler *handler, perf_counter *pc, size_t size); //cpu\Guaranteed VM Processing Power extern int ap_metric_cpu_guaranteed_vm_processing_power(ap_handler *handler, perf_counter *pc, size_t size); //cpu\Max. VM Processing Power extern int ap_metric_cpu_max_vm_processing_power(ap_handler *handler, perf_counter *pc, size_t size); //cpu\Number of Cores per CPU extern int ap_metric_cpu_number_of_cores_per_cpu(ap_handler *handler, perf_counter *pc, size_t size); //cpu\Number of Threads per Core extern int ap_metric_cpu_number_of_threads_per_core(ap_handler *handler, perf_counter *pc, size_t size); //cpu\Phys. 
Processing Power per vCPU extern int ap_metric_cpu_phys_processing_power_per_vcpu(ap_handler *handler, perf_counter *pc, size_t size); //cpu\Processor Type extern int ap_metric_cpu_processor_type(ap_handler *handler, perf_counter *pc, size_t size); #endif ================================================ FILE: AzureEnhancedMonitor/clib/src/apmetric.c ================================================ // // Copyright 2014 Microsoft Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // // This file is auto-generated, don't modify it directly. 
// #include #include int ap_metric_config_cloud_provider(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "config", "Cloud Provider", size); } int ap_metric_config_cpu_over_provisioning(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "config", "CPU Over-Provisioning", size); } int ap_metric_config_memory_over_provisioning(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "config", "Memory Over-Provisioning", size); } int ap_metric_config_data_provider_version(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "config", "Data Provider Version", size); } int ap_metric_config_data_sources(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "config", "Data Sources", size); } int ap_metric_config_instance_type(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "config", "Instance Type", size); } int ap_metric_config_virtualization_solution(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "config", "Virtualization Solution", size); } int ap_metric_config_virtualization_solution_version(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "config", "Virtualization Solution Version", size); } int ap_metric_cpu_current_hw_frequency(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "cpu", "Current Hw Frequency", size); } int ap_metric_cpu_max_hw_frequency(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "cpu", "Max Hw Frequency", 
size); } int ap_metric_cpu_current_vm_processing_power(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "cpu", "Current VM Processing Power", size); } int ap_metric_cpu_guaranteed_vm_processing_power(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "cpu", "Guaranteed VM Processing Power", size); } int ap_metric_cpu_max_vm_processing_power(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "cpu", "Max. VM Processing Power", size); } int ap_metric_cpu_number_of_cores_per_cpu(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "cpu", "Number of Cores per CPU", size); } int ap_metric_cpu_number_of_threads_per_core(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "cpu", "Number of Threads per Core", size); } int ap_metric_cpu_phys_processing_power_per_vcpu(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "cpu", "Phys. 
Processing Power per vCPU", size); } int ap_metric_cpu_processor_type(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "cpu", "Processor Type", size); } int ap_metric_cpu_reference_compute_unit(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "cpu", "Reference Compute Unit", size); } int ap_metric_cpu_vcpu_mapping(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "cpu", "vCPU Mapping", size); } int ap_metric_cpu_vm_processing_power_consumption(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "cpu", "VM Processing Power Consumption", size); } int ap_metric_memory_current_memory_assigned(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "memory", "Current Memory assigned", size); } int ap_metric_memory_guaranteed_memory_assigned(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "memory", "Guaranteed Memory assigned", size); } int ap_metric_memory_max_memory_assigned(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "memory", "Max Memory assigned", size); } int ap_metric_memory_vm_memory_consumption(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "memory", "VM Memory Consumption", size); } int ap_metric_network_adapter_id(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "network", "Adapter Id", size); } int ap_metric_network_mapping(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "network", "Mapping", size); } 
int ap_metric_network_min_network_bandwidth(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "network", "Minimum Network Bandwidth", size); } int ap_metric_network_max_network_bandwidth(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "network", "Maximum Network Bandwidth", size); } int ap_metric_network_network_read_bytes(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "network", "Network Read Bytes", size); } int ap_metric_network_network_write_bytes(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "network", "Network Write Bytes", size); } int ap_metric_network_packets_retransmitted(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "network", "Packets Retransmitted", size); } int ap_metric_config_last_hardware_change(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "config", "Last Hardware Change", size); } int ap_metric_storage_phys_disc_to_storage_mapping(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "storage", "Phys. 
Disc to Storage Mapping", size); } int ap_metric_storage_storage_id(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "storage", "Storage ID", size); } int ap_metric_storage_read_bytes(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "storage", "Storage Read Bytes", size); } int ap_metric_storage_read_ops(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "storage", "Storage Read Ops", size); } int ap_metric_storage_read_op_latency_e2e(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "storage", "Storage Read Op Latency E2E msec", size); } int ap_metric_storage_read_op_latency_server(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "storage", "Storage Read Op Latency Server msec", size); } int ap_metric_storage_read_throughput_e2e(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "storage", "Storage Read Throughput E2E MB/sec", size); } int ap_metric_storage_write_bytes(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "storage", "Storage Write Bytes", size); } int ap_metric_storage_write_ops(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "storage", "Storage Write Ops", size); } int ap_metric_storage_write_op_latency_e2e(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "storage", "Storage Write Op Latency E2E msec", size); } int ap_metric_storage_write_op_latency_server(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return 
get_metric(handler, pc, "storage", "Storage Write Op Latency Server msec", size); } int ap_metric_storage_write_throughput_e2e(ap_handler *handler, perf_counter *pc, size_t size) { if(handler->err) { return 0; } return get_metric(handler, pc, "storage", "Storage Write Throughput E2E MB/sec", size); } ================================================ FILE: AzureEnhancedMonitor/clib/src/azureperf.c ================================================ // // Copyright 2014 Microsoft Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // #include #include #include #include #include #define INTMIN(X, Y) (((X) < (Y)) ? (X) : (Y)) #define INTMAX(X, Y) (((X) > (Y)) ? 
(X) : (Y)) #define MATCH_SUCCESS (1) #define MATCH_FAILED (0) #define MATCH_EOF (-1) #define STRICT_MATCH (1) #define NON_STRICT_MATCH (0) static char FIELD_SEPRATOR = ';'; static char DEFAULT_AP_FILE[] = "/var/lib/AzureEnhancedMonitor/PerfCounters"; ap_handler* ap_open() { ap_handler *handler = malloc(sizeof(ap_handler)); handler->ap_file = DEFAULT_AP_FILE; memset(handler, 0, sizeof(ap_handler)); return handler; } void ap_close(ap_handler *handler) { free(handler); } int read_sperator(FILE *fp, int strict) { int c; c = fgetc(fp); //In non-strict mode, Read and discard chars until EOF or FIELD_SEPRATOR while(strict == NON_STRICT_MATCH && c != EOF && c != FIELD_SEPRATOR) { c = fgetc(fp); } if(c == EOF) { return MATCH_EOF; } if(c != FIELD_SEPRATOR) { return MATCH_FAILED; } else { return MATCH_SUCCESS; } } int read_int(FILE *fp, int *val) { int ret = EOF; ret = fscanf(fp, "%d", val); if(ret == EOF) { return MATCH_EOF; } if(ret != 1) { return MATCH_FAILED; } else { return read_sperator(fp, STRICT_MATCH); } } int read_int64(FILE *fp, long long *val) { int ret = EOF; ret = fscanf(fp, "%Ld", val); if(ret == EOF) { return MATCH_EOF; } if(ret != 1) { return MATCH_FAILED; } else { return read_sperator(fp, STRICT_MATCH); } } int read_double(FILE *fp, double *val) { int ret = EOF; ret = fscanf(fp, "%lf", val); if(ret == EOF) { return MATCH_EOF; } if(ret != 1) { return MATCH_FAILED; } else { return read_sperator(fp, STRICT_MATCH); } } int read_str(FILE *fp, char* str, int max_size) { char buf[STR_BUF_MAX]; int c = EOF; int i = 0; if(max_size > STR_BUF_MAX) { return MATCH_FAILED; } memset(buf, 0, STR_BUF_MAX); for(; i < max_size - 1; i++) { c = fgetc(fp); if(c == EOF) { return MATCH_EOF; } if(c == FIELD_SEPRATOR) { break; } buf[i] = c; } strncpy(str, buf, i); if(c == FIELD_SEPRATOR) { return MATCH_SUCCESS; } else//Reaches buf max, discard the rest part of string { return read_sperator(fp, NON_STRICT_MATCH); } } void set_handler_err(ap_handler *handler, int err) { handler->err = 
err; } int read_pc_from_file(ap_handler* handler, FILE *fp) { int ret = MATCH_FAILED; perf_counter *pc; if(handler->len == PERF_COUNT_MAX) { handler->err = AP_ERR_PC_BUF_OVERFLOW; goto EXIT; } pc = &handler->buf[handler->len]; ret = read_int(fp, &pc->counter_typer); if(ret == MATCH_EOF) { goto EXIT; } if(ret != MATCH_SUCCESS) { set_handler_err(handler, AP_ERR_INVALID_COUNTER_TYPE); goto EXIT; } ret = read_str(fp, pc->type_name, TYPE_NAME_MAX); if(ret != MATCH_SUCCESS) { set_handler_err(handler, AP_ERR_INVALID_TYPE_NAME); goto EXIT; } ret = read_str(fp, pc->property_name, PROPERTY_NAME_MAX); if(ret != MATCH_SUCCESS) { set_handler_err(handler, AP_ERR_INVALID_PROPERTY_NAME); goto EXIT; } ret = read_str(fp, pc->instance_name, INSTANCE_NAME_MAX); if(ret != MATCH_SUCCESS) { set_handler_err(handler, AP_ERR_INVALID_INSTANCE_NAME); goto EXIT; } ret = read_int(fp, &pc->is_empty); if(ret != MATCH_SUCCESS) { set_handler_err(handler, AP_ERR_INVALID_IS_EMPTY_FLAG); goto EXIT; } if(!pc->is_empty) { switch(pc->counter_typer) { case PERF_COUNTER_TYPE_INT: ret = read_int(fp, &pc->val_int); break; case PERF_COUNTER_TYPE_LARGE: ret = read_int64(fp, &pc->val_large); break; case PERF_COUNTER_TYPE_DOUBLE: ret = read_double(fp, &pc->val_double); break; case PERF_COUNTER_TYPE_STRING: ret = read_str(fp, pc->val_str, STRING_VALUE_MAX); break; } if(ret != MATCH_SUCCESS) { set_handler_err(handler, AP_ERR_INVALID_VALUE); goto EXIT; } } else { ret = read_sperator(fp, NON_STRICT_MATCH); if(ret != MATCH_SUCCESS) { set_handler_err(handler, AP_ERR_INVALID_VALUE); goto EXIT; } } ret = read_str(fp, pc->unit_name, UNIT_NAME_MAX); if(ret != MATCH_SUCCESS) { set_handler_err(handler, AP_ERR_INVALID_UNIT_NAME); goto EXIT; } ret = read_int(fp, &pc->refresh_interval); if(ret != MATCH_SUCCESS) { set_handler_err(handler, AP_ERR_INVALID_REFRESH_INTERVAL); goto EXIT; } ret = read_int64(fp, &pc->timestamp); if(ret != MATCH_SUCCESS) { set_handler_err(handler, AP_ERR_INVALID_TIMESTAMP); goto EXIT; } ret = 
read_str(fp, pc->machine_name, MACHINE_NAME_MAX); if(ret != MATCH_SUCCESS) { set_handler_err(handler, AP_ERR_INVALID_MACHINE_NAME); goto EXIT; } handler->len++; //Discard line end if exits. fscanf(fp, "\n"); EXIT: return ret; } void ap_refresh(ap_handler *handler) { FILE *fp = 0; perf_counter *next = 0; //Reset handler memset(handler->buf, 0, sizeof(perf_counter) * PERF_COUNT_MAX); handler->len = 0; errno = 0; fp = fopen(handler->ap_file, "r"); if(errno || 0 == fp){ handler->err = errno; goto EXIT; } while(read_pc_from_file(handler, fp) != EOF) { if(handler->err != 0) { goto EXIT; } } EXIT: if(fp) { fclose(fp); } } int ap_metric_all(ap_handler *handler, perf_counter *all, size_t size) { int size_to_cp = 0; if(handler->err) { return; } size_to_cp = INTMIN(handler->len, size); if(size_to_cp > 0) { memcpy(all, handler->buf, sizeof(perf_counter) * size_to_cp); } return size_to_cp; } int get_metric(ap_handler *handler, perf_counter *pc, const char *type_name, const char* property_name, size_t size) { int i = 0; int found = 0; for(;i < handler->len && found < size; i++) { if(0 == strcmp(handler->buf[i].type_name, type_name) && 0 == strcmp(handler->buf[i].property_name, property_name)) { memcpy(pc + found, &handler->buf[i], sizeof(perf_counter)); found++; } } if(!found) { handler->err = AP_ERR_PC_NOT_FOUND; } return found; } ================================================ FILE: AzureEnhancedMonitor/clib/test/cases/positive_case ================================================ 2;cpu;Current Hw Frequency;;0;2194.507;MHz;60;1423450780;aem-suse11sp3; 2;cpu;Max Hw Frequency;;0;2194.507;MHz;0;1423450780;aem-suse11sp3; 1;cpu;Current VM Processing Power;;0;1;compute unit;0;1423450780;aem-suse11sp3; 1;cpu;Guaranteed VM Processing Power;;0;1;compute unit;0;1423450780;aem-suse11sp3; 1;cpu;Max. 
VM Processing Power;;0;1;compute unit;0;1423450780;aem-suse11sp3; 1;cpu;Number of Cores per CPU;;0;1;none;0;1423450780;aem-suse11sp3; 1;cpu;Number of Threads per Core;;0;1;none;0;1423450780;aem-suse11sp3; 2;cpu;Phys. Processing Power per vCPU;;0;1.0;none;0;1423450780;aem-suse11sp3; 4;cpu;Processor Type;;0;Intel(R) Xeon(R) CPU E5-2660 0 @ 2.20GHz, GenuineIntel;none;0;1423450780;aem-suse11sp3; 4;cpu;Reference Compute Unit;;0;Intel(R) Xeon(R) CPU E5-2660 0 @ 2.20GHz, GenuineIntel;none;0;1423450780;aem-suse11sp3; 4;cpu;vCPU Mapping;;0;core;none;0;1423450780;aem-suse11sp3; 2;cpu;VM Processing Power Consumption;;0;1.0;%;60;1423450480;aem-suse11sp3; 1;memory;Current Memory assigned;;0;1681;MB;0;1423450780;aem-suse11sp3; 1;memory;Guaranteed Memory assigned;;0;1681;MB;0;1423450780;aem-suse11sp3; 1;memory;Max Memory assigned;;0;1681;MB;0;1423450780;aem-suse11sp3; 2;memory;VM Memory Consumption;;0;10.0;%;60;1423450480;aem-suse11sp3; 4;network;Adapter Id;eth0;0;eth0;none;0;1423450780;aem-suse11sp3; 4;network;Mapping;eth0;0;00-0d-3a-20-7c-81;none;0;1423450780;aem-suse11sp3; 1;network;Minimum Network Bandwidth;eth0;0;1000;Mbit/s;0;1423450780;aem-suse11sp3; 1;network;Maximum Network Bandwidth;eth0;0;1000;Mbit/s;0;1423450780;aem-suse11sp3; 3;network;Network Read Bytes;;0;60676750;byte/s;0;1423450780;aem-suse11sp3; 3;network;Network Write Bytes;;0;11596695;byte/s;0;1423450780;aem-suse11sp3; 1;network;Packets Retransmitted;;0;279;packets/min;0;1423450780;aem-suse11sp3; 3;config;Last Hardware Change;;0;1423449729;posixtime;0;1423450780;aem-suse11sp3; 4;storage;Phys. Disc to Storage Mapping;/dev/sdb;0;not mapped to vhd;none;0;1423450780;aem-suse11sp3; 4;storage;Phys. 
Disc to Storage Mapping;/dev/sda;0;portalvhdsz0msmsvh2cnqj aem-suse11sp3-aem-suse11sp3-0-201502071338440211;none;0;1423450780;aem-suse11sp3; 4;storage;Storage ID;portalvhdsz0msmsvh2cnqj;0;portalvhdsz0msmsvh2cnqj;none;0;1423450781;aem-suse11sp3; 3;storage;Storage Read Bytes;portalvhdsz0msmsvh2cnqj;0;424198985;byte;60;1423450781;aem-suse11sp3; 1;storage;Storage Read Ops;portalvhdsz0msmsvh2cnqj;0;2183;none;60;1423450781;aem-suse11sp3; 2;storage;Storage Read Op Latency E2E msec;portalvhdsz0msmsvh2cnqj;0;64.7292721223;ms;60;1423450781;aem-suse11sp3; 2;storage;Storage Read Op Latency Server msec;portalvhdsz0msmsvh2cnqj;0;20.0522214489;ms;60;1423450781;aem-suse11sp3; 2;storage;Storage Read Throughput E2E MB/sec;portalvhdsz0msmsvh2cnqj;0;6.742461284;MB/s;60;1423450781;aem-suse11sp3; 3;storage;Storage Write Bytes;portalvhdsz0msmsvh2cnqj;0;208673771;byte;60;1423450781;aem-suse11sp3; 1;storage;Storage Write Ops;portalvhdsz0msmsvh2cnqj;0;3860;none;60;1423450781;aem-suse11sp3; 2;storage;Storage Write Op Latency E2E msec;portalvhdsz0msmsvh2cnqj;0;14.3150263047;ms;60;1423450781;aem-suse11sp3; 2;storage;Storage Write Op Latency Server msec;portalvhdsz0msmsvh2cnqj;0;14.0740937373;ms;60;1423450781;aem-suse11sp3; 2;storage;Storage Write Throughput E2E MB/sec;portalvhdsz0msmsvh2cnqj;0;3.31678026517;MB/s;60;1423450781;aem-suse11sp3; 4;config;Cloud Provider;;0;Microsoft Azure;none;0;1423450781;aem-suse11sp3; 4;config;CPU Over-Provisioning;;0;no;none;0;1423450781;aem-suse11sp3; 4;config;Memory Over-Provisioning;;0;no;none;0;1423450781;aem-suse11sp3; 4;config;Data Provider Version;;0;1.0.0;none;0;1423450781;aem-suse11sp3; 4;config;Data Sources;;0;lad;none;0;1423450781;aem-suse11sp3; 4;config;Instance Type;;0;Small;none;0;1423450781;aem-suse11sp3; 4;config;Virtualization Solution;;0;Microsoft Hv;none;0;1423450781;aem-suse11sp3; 4;config;Virtualization Solution Version;;0;6.3;none;0;1423450781;aem-suse11sp3; ================================================ FILE: 
AzureEnhancedMonitor/clib/test/codegen.py ================================================ #!/usr/bin/env python # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import re import os code_start="""\ // // Copyright 2014 Microsoft Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // // This file is auto-generated, don't modify it directly. 
// #include #include """ code_tmpl="""\ int ap_metric_{0}_{1}(ap_handler *handler, perf_counter *pc, size_t size) {{ if(handler->err) {{ return 0; }} return get_metric(handler, pc, "{2}", "{3}", size); }} """ head_tmpl="""\ //{0}\{1} extern int ap_metric_{2}_{3}(ap_handler *handler, perf_counter *pc, size_t size); """ test_root = os.path.dirname(os.path.abspath(__file__)) if __name__ == "__main__": with open(os.path.join(test_root, "counter_names"), "r") as file_in, \ open(os.path.join(test_root, "../src/apmetric.c"), "w") as file_out, \ open(os.path.join(test_root, "../build/metric_def"), "w") as head_out: lines = file_in.read().split("\n") file_out.write(code_start) for line in lines: match = re.match("([^;]*);([^;]*);([^;]*)", line) if match is not None: type_name = match.group(1) prop_name = match.group(2) short_name = match.group(3) short_name = short_name.lower() short_name = short_name.replace(" ", "_") short_name = short_name.replace("-", "_") code_snippet = code_tmpl.format(type_name.lower(), short_name, type_name, prop_name) file_out.write(code_snippet) head_snippet = head_tmpl.format(type_name, prop_name, type_name.lower(), short_name) head_out.write(head_snippet) print("printf(\">>>>ap_metric_{0}_{1}\\n\");".format(type_name, short_name)) print("ap_metric_{0}_{1}(handler, &pc, 1);".format(type_name, short_name)) print("print_counter(&pc);") ================================================ FILE: AzureEnhancedMonitor/clib/test/counter_names ================================================ config;Cloud Provider;Cloud Provider config;CPU Over-Provisioning;CPU Over-Provisioning config;Memory Over-Provisioning;Memory Over-Provisioning config;Data Provider Version;Data Provider Version config;Data Sources;Data Sources config;Instance Type;Instance Type config;Virtualization Solution;Virtualization Solution config;Virtualization Solution Version;Virtualization Solution Version cpu;Current Hw Frequency;Current Hw Frequency cpu;Max Hw Frequency;Max Hw Frequency 
cpu;Current VM Processing Power;Current VM Processing Power cpu;Guaranteed VM Processing Power;Guaranteed VM Processing Power cpu;Max. VM Processing Power;Max VM Processing Power cpu;Number of Cores per CPU;Number of Cores per CPU cpu;Number of Threads per Core;Number of Threads per Core cpu;Phys. Processing Power per vCPU;Phys Processing Power per vCPU cpu;Processor Type;Processor Type cpu;Reference Compute Unit;Reference Compute Unit cpu;vCPU Mapping;vCPU Mapping cpu;VM Processing Power Consumption;VM Processing Power Consumption memory;Current Memory assigned;Current Memory assigned memory;Guaranteed Memory assigned;Guaranteed Memory assigned memory;Max Memory assigned;Max Memory assigned memory;VM Memory Consumption;VM Memory Consumption network;Adapter Id;Adapter Id network;Mapping;Mapping network;Minimum Network Bandwidth;Min Network Bandwidth network;Maximum Network Bandwidth;Max Network Bandwidth network;Network Read Bytes;Network Read Bytes network;Network Write Bytes;Network Write Bytes network;Packets Retransmitted;Packets Retransmitted config;Last Hardware Change;Last Hardware Change storage;Phys. 
Disc to Storage Mapping;Phys Disc to Storage Mapping storage;Storage ID;Storage ID storage;Storage Read Bytes;Read Bytes storage;Storage Read Ops;Read Ops storage;Storage Read Op Latency E2E msec;Read Op Latency E2E storage;Storage Read Op Latency Server msec;Read Op Latency Server storage;Storage Read Throughput E2E MB/sec;Read Throughput E2E storage;Storage Write Bytes;Write Bytes storage;Storage Write Ops;Write Ops storage;Storage Write Op Latency E2E msec;Write Op Latency E2E storage;Storage Write Op Latency Server msec;Write Op Latency Server storage;Storage Write Throughput E2E MB/sec;Write Throughput E2E ================================================ FILE: AzureEnhancedMonitor/clib/test/runtest.c ================================================ // // Copyright 2014 Microsoft Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// #include #include #include static const char default_input[] = "./test/cases/positive_case"; int main(int argc, char ** argv) { char* ap_file = (char*) default_input; if(argc == 2) { ap_file = argv[1]; } printf("Parsing perf counters from: %s\n", ap_file); run_test(ap_file); } void print_counter(perf_counter *pc) { printf("%-7s | %-24.24s | %-15.15s | ", pc->type_name, pc->property_name, pc->instance_name); switch(pc->counter_typer) { case PERF_COUNTER_TYPE_INT: printf("%-30d", pc->val_int); break; case PERF_COUNTER_TYPE_LARGE: printf("%-30Ld", pc->val_large); break; case PERF_COUNTER_TYPE_DOUBLE: printf("%-30lf", pc->val_double); break; case PERF_COUNTER_TYPE_STRING: default: printf("%-30.30s", pc->val_str); break; } printf(" |\n"); } int run_test(char* ap_file) { int ret = 0; ap_handler *handler = 0; int i = 0; perf_counter pc; handler = ap_open(); handler->ap_file = ap_file; ap_refresh(handler); if(handler->err) { ret = handler->err; printf("Error code:%d\n", handler->err); goto EXIT; } printf("Found counters:%d\n", handler->len); for(; i < handler->len; i++) { pc = handler->buf[i]; print_counter(&pc); memset(&pc, 0 , sizeof(perf_counter)); } printf(">>>>ap_metric_config_cloud_provider\n"); ap_metric_config_cloud_provider(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_config_cpu_over_provisioning\n"); ap_metric_config_cpu_over_provisioning(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_config_memory_over_provisioning\n"); ap_metric_config_memory_over_provisioning(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_config_data_provider_version\n"); ap_metric_config_data_provider_version(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_config_data_sources\n"); ap_metric_config_data_sources(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_config_instance_type\n"); ap_metric_config_instance_type(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_config_virtualization_solution\n"); 
ap_metric_config_virtualization_solution(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_config_virtualization_solution_version\n"); ap_metric_config_virtualization_solution_version(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_cpu_current_hw_frequency\n"); ap_metric_cpu_current_hw_frequency(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_cpu_max_hw_frequency\n"); ap_metric_cpu_max_hw_frequency(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_cpu_current_vm_processing_power\n"); ap_metric_cpu_current_vm_processing_power(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_cpu_guaranteed_vm_processing_power\n"); ap_metric_cpu_guaranteed_vm_processing_power(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_cpu_max_vm_processing_power\n"); ap_metric_cpu_max_vm_processing_power(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_cpu_number_of_cores_per_cpu\n"); ap_metric_cpu_number_of_cores_per_cpu(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_cpu_number_of_threads_per_core\n"); ap_metric_cpu_number_of_threads_per_core(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_cpu_phys_processing_power_per_vcpu\n"); ap_metric_cpu_phys_processing_power_per_vcpu(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_cpu_processor_type\n"); ap_metric_cpu_processor_type(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_cpu_reference_compute_unit\n"); ap_metric_cpu_reference_compute_unit(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_cpu_vcpu_mapping\n"); ap_metric_cpu_vcpu_mapping(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_cpu_vm_processing_power_consumption\n"); ap_metric_cpu_vm_processing_power_consumption(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_memory_current_memory_assigned\n"); ap_metric_memory_current_memory_assigned(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_memory_guaranteed_memory_assigned\n"); 
ap_metric_memory_guaranteed_memory_assigned(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_memory_max_memory_assigned\n"); ap_metric_memory_max_memory_assigned(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_memory_vm_memory_consumption\n"); ap_metric_memory_vm_memory_consumption(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_network_adapter_id\n"); ap_metric_network_adapter_id(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_network_mapping\n"); ap_metric_network_mapping(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_network_min_network_bandwidth\n"); ap_metric_network_min_network_bandwidth(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_network_max_network_bandwidth\n"); ap_metric_network_max_network_bandwidth(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_network_network_read_bytes\n"); ap_metric_network_network_read_bytes(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_network_network_write_bytes\n"); ap_metric_network_network_write_bytes(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_network_packets_retransmitted\n"); ap_metric_network_packets_retransmitted(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_config_last_hardware_change\n"); ap_metric_config_last_hardware_change(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_storage_phys_disc_to_storage_mapping\n"); ap_metric_storage_phys_disc_to_storage_mapping(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_storage_storage_id\n"); ap_metric_storage_storage_id(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_storage_read_bytes\n"); ap_metric_storage_read_bytes(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_storage_read_ops\n"); ap_metric_storage_read_ops(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_storage_read_op_latency_e2e\n"); ap_metric_storage_read_op_latency_e2e(handler, &pc, 1); print_counter(&pc); 
printf(">>>>ap_metric_storage_read_op_latency_server\n"); ap_metric_storage_read_op_latency_server(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_storage_read_throughput_e2e\n"); ap_metric_storage_read_throughput_e2e(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_storage_write_bytes\n"); ap_metric_storage_write_bytes(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_storage_write_ops\n"); ap_metric_storage_write_ops(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_storage_write_op_latency_e2e\n"); ap_metric_storage_write_op_latency_e2e(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_storage_write_op_latency_server\n"); ap_metric_storage_write_op_latency_server(handler, &pc, 1); print_counter(&pc); printf(">>>>ap_metric_storage_write_throughput_e2e\n"); ap_metric_storage_write_throughput_e2e(handler, &pc, 1); print_counter(&pc); EXIT: ap_close(handler); return ret; } ================================================ FILE: AzureEnhancedMonitor/ext/.gitignore ================================================ bin/* .ropeproject/ ================================================ FILE: AzureEnhancedMonitor/ext/HandlerManifest.json ================================================ [{ "name": "AzureEnhancedMonitor", "version": 1.0, "handlerManifest": { "installCommand": "installer.py", "uninstallCommand": "handler.py uninstall", "updateCommand": "handler.py update", "enableCommand": "handler.py enable", "disableCommand": "handler.py disable", "rebootAfterInstall": false, "reportHeartbeat": false } }] ================================================ FILE: AzureEnhancedMonitor/ext/aem.py ================================================ # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import socket
import traceback
import time
import datetime
import psutil
import urlparse
import xml.dom.minidom as minidom

from azure.storage import TableService, Entity
from Utils.WAAgentUtil import waagent, AddExtensionEvent

# Error markers persisted to disk via updateLatestErrorRecord() and later
# reported through the "Error" config perf counter.
FAILED_TO_RETRIEVE_MDS_DATA="(03100)Failed to retrieve mds data"
FAILED_TO_RETRIEVE_LOCAL_DATA="(03101)Failed to retrieve local data"
FAILED_TO_RETRIEVE_STORAGE_DATA="(03102)Failed to retrieve storage data"
FAILED_TO_SERIALIZE_PERF_COUNTERS="(03103)Failed to serialize perf counters"

def timedelta_total_seconds(delta):
    """Return the number of seconds covered by *delta*.

    Python 2.6's datetime.timedelta lacks total_seconds(); in that case
    compute it from days/seconds (sub-second precision is dropped on
    that fallback path).
    """
    if not hasattr(datetime.timedelta, 'total_seconds'):
        return delta.days * 86400 + delta.seconds
    else:
        return delta.total_seconds()

def get_host_base_from_uri(blob_uri):
    """Return the storage host suffix of *blob_uri* (the part starting at
    the first '.', e.g. ".blob.core.windows.net"), or None when the URI
    carries no usable host.
    """
    uri = urlparse.urlparse(blob_uri)
    netloc = uri.netloc
    # urlparse yields '' (never None) for a missing host, so also reject
    # empty strings; and guard against a host without any '.', where
    # find() would return -1 and the slice would silently yield just the
    # last character instead of a host suffix.
    if not netloc or '.' not in netloc:
        return None
    return netloc[netloc.find('.'):]

MonitoringIntervalInMinute = 1 #One minute
MonitoringInterval = 60 * MonitoringIntervalInMinute
#It takes some time before the performance data reaches the azure table.
AzureTableDelayInMinute = 5 #Five minute AzureTableDelay = 60 * AzureTableDelayInMinute AzureEnhancedMonitorVersion = "2.0.0" LibDir = "/var/lib/AzureEnhancedMonitor" LatestErrorRecord = "LatestErrorRecord" def clearLastErrorRecord(): errFile = os.path.join(LibDir, LatestErrorRecord) if os.path.exists(errFile) and os.path.isfile(errFile): os.remove(errFile) def getLatestErrorRecord(): errFile=os.path.join(LibDir, LatestErrorRecord) if os.path.exists(errFile) and os.path.isfile(errFile): with open(errFile, 'r') as f: return f.read() return "0" def updateLatestErrorRecord(s): errFile = os.path.join(LibDir, LatestErrorRecord) maxRetry = 3 for i in range(0, maxRetry): try: with open(errFile, "w+") as F: F.write(s.encode("utf8")) return except IOError: time.sleep(1) waagent.Error(("Failed to serialize latest error record to file:" "{0}").format(errFile)) AddExtensionEvent(message="failed to write latest error record") raise def easyHash(s): """ MDSD used the following hash algorithm to cal a first part of partition key """ strHash = 0 multiplier = 37 for c in s: strHash = strHash * multiplier + ord(c) #Only keep the last 64bit, since the mod base is 100 strHash = strHash % (1<<64) return strHash % 100 #Assume eventVolume is Large Epoch = datetime.datetime(1, 1, 1) tickInOneSecond = 1000 * 10000 # 1s = 1000 * 10000 ticks def getMDSTimestamp(unixTimestamp): unixTime = datetime.datetime.utcfromtimestamp(unixTimestamp) startTimestamp = int(timedelta_total_seconds(unixTime - Epoch)) return startTimestamp * tickInOneSecond def getIdentity(): identity = socket.gethostname() return identity def getMDSPartitionKey(identity, timestamp): hashVal = easyHash(identity) return "{0:0>19d}___{1:0>19d}".format(hashVal, timestamp) def getAzureDiagnosticKeyRange(): #Round down by MonitoringInterval endTime = (int(time.time()) / MonitoringInterval) * MonitoringInterval endTime = endTime - AzureTableDelay startTime = endTime - MonitoringInterval identity = getIdentity() startKey = 
getMDSPartitionKey(identity, getMDSTimestamp(startTime)) endKey = getMDSPartitionKey(identity, getMDSTimestamp(endTime)) return startKey, endKey def getAzureDiagnosticCPUData(accountName, accountKey, hostBase, startKey, endKey, deploymentId): try: waagent.Log("Retrieve diagnostic data(CPU).") table = "LinuxCpuVer2v0" tableService = TableService(account_name = accountName, account_key = accountKey, host_base = hostBase) ofilter = ("PartitionKey ge '{0}' and PartitionKey lt '{1}' " "and DeploymentId eq '{2}'").format(startKey, endKey, deploymentId) oselect = ("PercentProcessorTime,DeploymentId") data = tableService.query_entities(table, ofilter, oselect, 1) if data is None or len(data) == 0: return None cpuPercent = float(data[0].PercentProcessorTime) return cpuPercent except Exception as e: waagent.Error((u"Failed to retrieve diagnostic data(CPU): {0} {1}" "").format(e, traceback.format_exc())) updateLatestErrorRecord(FAILED_TO_RETRIEVE_MDS_DATA) AddExtensionEvent(message=FAILED_TO_RETRIEVE_MDS_DATA) return None def getAzureDiagnosticMemoryData(accountName, accountKey, hostBase, startKey, endKey, deploymentId): try: waagent.Log("Retrieve diagnostic data: Memory") table = "LinuxMemoryVer2v0" tableService = TableService(account_name = accountName, account_key = accountKey, host_base = hostBase) ofilter = ("PartitionKey ge '{0}' and PartitionKey lt '{1}' " "and DeploymentId eq '{2}'").format(startKey, endKey, deploymentId) oselect = ("PercentAvailableMemory,DeploymentId") data = tableService.query_entities(table, ofilter, oselect, 1) if data is None or len(data) == 0: return None memoryPercent = 100 - float(data[0].PercentAvailableMemory) return memoryPercent except Exception as e: waagent.Error((u"Failed to retrieve diagnostic data(Memory): {0} {1}" "").format(e, traceback.format_exc())) updateLatestErrorRecord(FAILED_TO_RETRIEVE_MDS_DATA) AddExtensionEvent(message=FAILED_TO_RETRIEVE_MDS_DATA) return None class AzureDiagnosticData(object): def __init__(self, config): 
self.config = config accountName = config.getLADName() accountKey = config.getLADKey() hostBase = config.getLADHostBase() hostname = socket.gethostname() deploymentId = config.getVmDeploymentId() startKey, endKey = getAzureDiagnosticKeyRange() self.cpuPercent = getAzureDiagnosticCPUData(accountName, accountKey, hostBase, startKey, endKey, deploymentId) self.memoryPercent = getAzureDiagnosticMemoryData(accountName, accountKey, hostBase, startKey, endKey, deploymentId) def getCPUPercent(self): return self.cpuPercent def getMemoryPercent(self): return self.memoryPercent class AzureDiagnosticMetric(object): def __init__(self, config): self.config = config self.linux = LinuxMetric(self.config) self.azure = AzureDiagnosticData(self.config) self.timestamp = int(time.time()) - AzureTableDelay def getTimestamp(self): return self.timestamp def getCurrHwFrequency(self): return self.linux.getCurrHwFrequency() def getMaxHwFrequency(self): return self.linux.getMaxHwFrequency() def getCurrVMProcessingPower(self): return self.linux.getCurrVMProcessingPower() def getGuaranteedVMProcessingPower(self): return self.linux.getGuaranteedVMProcessingPower() def getMaxVMProcessingPower(self): return self.linux.getMaxVMProcessingPower() def getNumOfCoresPerCPU(self): return self.linux.getNumOfCoresPerCPU() def getNumOfThreadsPerCore(self): return self.linux.getNumOfThreadsPerCore() def getPhysProcessingPowerPerVCPU(self): return self.linux.getPhysProcessingPowerPerVCPU() def getProcessorType(self): return self.linux.getProcessorType() def getReferenceComputeUnit(self): return self.linux.getReferenceComputeUnit() def getVCPUMapping(self): return self.linux.getVCPUMapping() def getVMProcessingPowerConsumption(self): return self.azure.getCPUPercent() def getCurrMemAssigned(self): return self.linux.getCurrMemAssigned() def getGuaranteedMemAssigned(self): return self.linux.getGuaranteedMemAssigned() def getMaxMemAssigned(self): return self.linux.getMaxMemAssigned() def getVMMemConsumption(self): 
return self.azure.getMemoryPercent() def getNetworkAdapterIds(self): return self.linux.getNetworkAdapterIds() def getNetworkAdapterMapping(self, adapterId): return self.linux.getNetworkAdapterMapping(adapterId) def getMaxNetworkBandwidth(self, adapterId): return self.linux.getMaxNetworkBandwidth(adapterId) def getMinNetworkBandwidth(self, adapterId): return self.linux.getMinNetworkBandwidth(adapterId) def getNetworkReadBytes(self, adapterId): return self.linux.getNetworkReadBytes(adapterId) def getNetworkWriteBytes(self, adapterId): return self.linux.getNetworkWriteBytes(adapterId) def getNetworkPacketRetransmitted(self): return self.linux.getNetworkPacketRetransmitted() def getLastHardwareChange(self): return self.linux.getLastHardwareChange() class CPUInfo(object): @staticmethod def getCPUInfo(): cpuinfo = waagent.GetFileContents("/proc/cpuinfo") ret, lscpu = waagent.RunGetOutput("lscpu") return CPUInfo(cpuinfo, lscpu) def __init__(self, cpuinfo, lscpu): self.cpuinfo = cpuinfo self.lscpu = lscpu self.cores = 1; self.coresPerCpu = 1; self.threadsPerCore = 1; coresMatch = re.search("CPU(s):\s+(\d+)", self.lscpu) if coresMatch: self.cores = int(coresMatch.group(1)) coresPerCpuMatch = re.search("Core(s) per socket:\s+(\d+)", self.lscpu) if coresPerCpuMatch: self.coresPerCpu = int(coresPerCpuMatch.group(1)) threadsPerCoreMatch = re.search("Core(s) per socket:\s+(\d+)", self.lscpu) if threadsPerCoreMatch: self.threadsPerCore = int(threadsPerCoreMatch.group(1)) model = re.search("model name\s+:\s+(.*)\s", self.cpuinfo) vendorId = re.search("vendor_id\s+:\s+(.*)\s", self.cpuinfo) if model and vendorId: self.processorType = "{0}, {1}".format(model.group(1), vendorId.group(1)) else: self.processorType = None freqMatch = re.search("CPU MHz:\s+(.*)\s", self.lscpu) if freqMatch: self.frequency = float(freqMatch.group(1)) else: self.frequency = None ht = re.match("flags\s.*\sht\s", self.cpuinfo) self.isHTon = ht is not None def getNumOfCoresPerCPU(self): return 
self.coresPerCpu def getNumOfCores(self): return self.cores def getNumOfThreadsPerCore(self): return self.threadsPerCore def getProcessorType(self): return self.processorType def getFrequency(self): return self.frequency def isHyperThreadingOn(self): return self.isHTon def getCPUPercent(self): return psutil.cpu_percent() class MemoryInfo(object): def __init__(self): self.memInfo = psutil.virtual_memory() def getMemSize(self): return self.memInfo[0] / 1024 / 1024 #MB def getMemPercent(self): return self.memInfo[2] #% def getMacAddress(adapterId): nicAddrPath = os.path.join("/sys/class/net", adapterId, "address") mac = waagent.GetFileContents(nicAddrPath) mac = mac.strip() mac = mac.replace(":", "-") return mac def sameList(l1, l2): if l1 is None or l2 is None: return l1 == l2 if len(l1) != len(l2): return False for i in range(0, len(l1)): if l1[i] != l2[i]: return False return True class NetworkInfo(object): def __init__(self): self.nics = psutil.net_io_counters(pernic=True) self.nicNames = [] for nicName, stat in self.nics.iteritems(): if nicName != 'lo': self.nicNames.append(nicName) def getAdapterIds(self): return self.nicNames def getNetworkReadBytes(self, adapterId): net = psutil.net_io_counters(pernic=True) if net[adapterId] != None: bytes_recv1 = net[adapterId][1] time1 = time.time() time.sleep(0.2) net = psutil.net_io_counters(pernic=True) bytes_recv2 = net[adapterId][1] time2 = time.time() interval = (time2 - time1) return (bytes_recv2 - bytes_recv1) / interval else: return 0 def getNetworkWriteBytes(self, adapterId): net = psutil.net_io_counters(pernic=True) if net[adapterId] != None: bytes_sent1 = net[adapterId][0] time1 = time.time() time.sleep(0.2) net = psutil.net_io_counters(pernic=True) bytes_sent2 = net[adapterId][0] time2 = time.time() interval = (time2 - time1) return (bytes_sent2 - bytes_sent1) / interval else: return 0 def getNetstat(self): retCode, output = waagent.RunGetOutput("netstat -s", chk_err=False) return output def 
getNetworkPacketRetransmitted(self): netstat = self.getNetstat() match = re.search("(\d+)\s*segments retransmited", netstat) if match != None: return int(match.group(1)) else: waagent.Error("Failed to parse netstat output: {0}".format(netstat)) updateLatestErrorRecord(FAILED_TO_RETRIEVE_LOCAL_DATA) AddExtensionEvent(message=FAILED_TO_RETRIEVE_LOCAL_DATA) return None HwInfoFile = os.path.join(LibDir, "HwInfo") class HardwareChangeInfo(object): def __init__(self, networkInfo): self.networkInfo = networkInfo def getHwInfo(self): if not os.path.isfile(HwInfoFile): return None, None hwInfo = waagent.GetFileContents(HwInfoFile).split("\n") return int(hwInfo[0]), hwInfo[1:] def setHwInfo(self, timestamp, hwInfo): content = str(timestamp) content = content + "\n" + "\n".join(hwInfo) waagent.SetFileContents(HwInfoFile, content) def getLastHardwareChange(self): oldTime, oldMacs = self.getHwInfo() newMacs = map(lambda x : getMacAddress(x), self.networkInfo.getAdapterIds()) newTime = int(time.time()) newMacs.sort() if oldMacs is None or not sameList(newMacs, oldMacs): #Hardware changed if newTime < oldTime: waagent.Warn(("Hardware change detected. 
But the old timestamp " "is greater than now, {0}>{1}.").format(oldTime, newTime)) self.setHwInfo(newTime, newMacs) return newTime else: return oldTime class LinuxMetric(object): def __init__(self, config): self.config = config #CPU self.cpuInfo = CPUInfo.getCPUInfo() #Memory self.memInfo = MemoryInfo() #Network self.networkInfo = NetworkInfo() #Detect hardware change self.hwChangeInfo = HardwareChangeInfo(self.networkInfo) self.timestamp = int(time.time()) def getTimestamp(self): return self.timestamp def getCurrHwFrequency(self): return self.cpuInfo.getFrequency() def getMaxHwFrequency(self): return self.getCurrHwFrequency() def getCurrVMProcessingPower(self): if self.config.isCpuOverCommitted(): return None else: return self.cpuInfo.getNumOfCores() def getGuaranteedVMProcessingPower(self): return self.getCurrVMProcessingPower() def getMaxVMProcessingPower(self): return self.getCurrVMProcessingPower() def getNumOfCoresPerCPU(self): return self.cpuInfo.getNumOfCoresPerCPU() def getNumOfThreadsPerCore(self): return self.cpuInfo.getNumOfThreadsPerCore() def getPhysProcessingPowerPerVCPU(self): return 1 / float(self.getNumOfThreadsPerCore()) def getProcessorType(self): return self.cpuInfo.getProcessorType() def getReferenceComputeUnit(self): return self.getProcessorType() def getVCPUMapping(self): return "thread" if self.cpuInfo.isHyperThreadingOn() else "core" def getVMProcessingPowerConsumption(self): return self.memInfo.getMemPercent() def getCurrMemAssigned(self): if self.config.isMemoryOverCommitted(): return None else: return self.memInfo.getMemSize() def getGuaranteedMemAssigned(self): return self.getCurrMemAssigned() def getMaxMemAssigned(self): return self.getCurrMemAssigned() def getVMMemConsumption(self): return self.memInfo.getMemPercent() def getNetworkAdapterIds(self): return self.networkInfo.getAdapterIds() def getNetworkAdapterMapping(self, adapterId): return getMacAddress(adapterId) def getMaxNetworkBandwidth(self, adapterId): return 1000 #Mbit/s def 
getMinNetworkBandwidth(self, adapterId): return 1000 #Mbit/s def getNetworkReadBytes(self, adapterId): return self.networkInfo.getNetworkReadBytes(adapterId) def getNetworkWriteBytes(self, adapterId): return self.networkInfo.getNetworkWriteBytes(adapterId) def getNetworkPacketRetransmitted(self): return self.networkInfo.getNetworkPacketRetransmitted() def getLastHardwareChange(self): return self.hwChangeInfo.getLastHardwareChange() class VMDataSource(object): def __init__(self, config): self.config = config def collect(self): counters = [] if self.config.isLADEnabled(): metrics = AzureDiagnosticMetric(self.config) else: metrics = LinuxMetric(self.config) #CPU counters.append(self.createCounterCurrHwFrequency(metrics)) counters.append(self.createCounterMaxHwFrequency(metrics)) counters.append(self.createCounterCurrVMProcessingPower(metrics)) counters.append(self.createCounterGuaranteedVMProcessingPower(metrics)) counters.append(self.createCounterMaxVMProcessingPower(metrics)) counters.append(self.createCounterNumOfCoresPerCPU(metrics)) counters.append(self.createCounterNumOfThreadsPerCore(metrics)) counters.append(self.createCounterPhysProcessingPowerPerVCPU(metrics)) counters.append(self.createCounterProcessorType(metrics)) counters.append(self.createCounterReferenceComputeUnit(metrics)) counters.append(self.createCounterVCPUMapping(metrics)) counters.append(self.createCounterVMProcessingPowerConsumption(metrics)) #Memory counters.append(self.createCounterCurrMemAssigned(metrics)) counters.append(self.createCounterGuaranteedMemAssigned(metrics)) counters.append(self.createCounterMaxMemAssigned(metrics)) counters.append(self.createCounterVMMemConsumption(metrics)) #Network adapterIds = metrics.getNetworkAdapterIds() for adapterId in adapterIds: if adapterId.startswith('eth'): counters.append(self.createCounterAdapterId(adapterId)) counters.append(self.createCounterNetworkMapping(metrics, adapterId)) counters.append(self.createCounterMinNetworkBandwidth(metrics, 
#
# VMDataSource counter factories.  These are methods of VMDataSource in the
# original file (the enclosing class statement lies outside this extract).
# Each helper wraps a single metric reading in a PerfCounter record; the
# counterType/category/name/unit values form the wire contract consumed by
# the SAP enhanced-monitoring agent and are reproduced exactly.
#

def createCounterLastHardwareChange(self, metrics):
    """Posix timestamp of the most recent detected hardware change."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_LARGE,
                       category="config",
                       name="Last Hardware Change",
                       value=metrics.getLastHardwareChange(),
                       unit="posixtime")

def createCounterError(self):
    """Most recent provider error record (empty record when healthy)."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_LARGE,
                       category="config",
                       name="Error",
                       value=getLatestErrorRecord())

def createCounterCurrHwFrequency(self, metrics):
    """Current host CPU frequency in MHz, refreshed every minute."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_DOUBLE,
                       category="cpu",
                       name="Current Hw Frequency",
                       value=metrics.getCurrHwFrequency(),
                       unit="MHz",
                       refreshInterval=60)

def createCounterMaxHwFrequency(self, metrics):
    """Maximum host CPU frequency in MHz."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_DOUBLE,
                       category="cpu",
                       name="Max Hw Frequency",
                       value=metrics.getMaxHwFrequency(),
                       unit="MHz")

def createCounterCurrVMProcessingPower(self, metrics):
    """Compute units currently assigned to the VM."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                       category="cpu",
                       name="Current VM Processing Power",
                       value=metrics.getCurrVMProcessingPower(),
                       unit="compute unit")

def createCounterMaxVMProcessingPower(self, metrics):
    """Maximum compute units the VM may receive."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                       category="cpu",
                       name="Max. VM Processing Power",
                       value=metrics.getMaxVMProcessingPower(),
                       unit="compute unit")

def createCounterGuaranteedVMProcessingPower(self, metrics):
    """Compute units guaranteed to the VM."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                       category="cpu",
                       name="Guaranteed VM Processing Power",
                       value=metrics.getGuaranteedVMProcessingPower(),
                       unit="compute unit")

def createCounterNumOfCoresPerCPU(self, metrics):
    """Physical core count per CPU package."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                       category="cpu",
                       name="Number of Cores per CPU",
                       value=metrics.getNumOfCoresPerCPU())

def createCounterNumOfThreadsPerCore(self, metrics):
    """Hardware thread count per core."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                       category="cpu",
                       name="Number of Threads per Core",
                       value=metrics.getNumOfThreadsPerCore())

def createCounterPhysProcessingPowerPerVCPU(self, metrics):
    """Physical processing power backing one vCPU."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_DOUBLE,
                       category="cpu",
                       name="Phys. Processing Power per vCPU",
                       value=metrics.getPhysProcessingPowerPerVCPU())

def createCounterProcessorType(self, metrics):
    """Host processor model string."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                       category="cpu",
                       name="Processor Type",
                       value=metrics.getProcessorType())

def createCounterReferenceComputeUnit(self, metrics):
    """Reference compute-unit definition string."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                       category="cpu",
                       name="Reference Compute Unit",
                       value=metrics.getReferenceComputeUnit())

def createCounterVCPUMapping(self, metrics):
    """vCPU to physical CPU mapping description."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                       category="cpu",
                       name="vCPU Mapping",
                       value=metrics.getVCPUMapping())

def createCounterVMProcessingPowerConsumption(self, metrics):
    """VM CPU consumption in percent, stamped with the sample time."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_DOUBLE,
                       category="cpu",
                       name="VM Processing Power Consumption",
                       value=metrics.getVMProcessingPowerConsumption(),
                       unit="%",
                       timestamp=metrics.getTimestamp(),
                       refreshInterval=60)

def createCounterCurrMemAssigned(self, metrics):
    """Memory currently assigned to the VM, in MB."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                       category="memory",
                       name="Current Memory assigned",
                       value=metrics.getCurrMemAssigned(),
                       unit="MB")

def createCounterMaxMemAssigned(self, metrics):
    """Maximum memory assignable to the VM, in MB."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                       category="memory",
                       name="Max Memory assigned",
                       value=metrics.getMaxMemAssigned(),
                       unit="MB")

def createCounterGuaranteedMemAssigned(self, metrics):
    """Memory guaranteed to the VM, in MB."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                       category="memory",
                       name="Guaranteed Memory assigned",
                       value=metrics.getGuaranteedMemAssigned(),
                       unit="MB")

def createCounterVMMemConsumption(self, metrics):
    """VM memory consumption in percent, stamped with the sample time."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_DOUBLE,
                       category="memory",
                       name="VM Memory Consumption",
                       value=metrics.getVMMemConsumption(),
                       unit="%",
                       timestamp=metrics.getTimestamp(),
                       refreshInterval=60)

def createCounterAdapterId(self, adapterId):
    """Identity counter for one network adapter (instance == value)."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                       category="network",
                       name="Adapter Id",
                       instance=adapterId,
                       value=adapterId)

def createCounterNetworkMapping(self, metrics, adapterId):
    """Adapter-to-platform-network mapping for one adapter."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                       category="network",
                       name="Mapping",
                       instance=adapterId,
                       value=metrics.getNetworkAdapterMapping(adapterId))

def createCounterMaxNetworkBandwidth(self, metrics, adapterId):
    """Maximum bandwidth of one adapter in Mbit/s."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                       category="network",
                       name="VM Maximum Network Bandwidth",
                       instance=adapterId,
                       value=metrics.getMaxNetworkBandwidth(adapterId),
                       unit="Mbit/s")

def createCounterMinNetworkBandwidth(self, metrics, adapterId):
    """Minimum (guaranteed) bandwidth of one adapter in Mbit/s."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                       category="network",
                       name="VM Minimum Network Bandwidth",
                       instance=adapterId,
                       value=metrics.getMinNetworkBandwidth(adapterId),
                       unit="Mbit/s")

def createCounterNetworkReadBytes(self, metrics, adapterId):
    """Bytes read per second on one adapter."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_LARGE,
                       category="network",
                       name="Network Read Bytes",
                       instance=adapterId,
                       value=metrics.getNetworkReadBytes(adapterId),
                       unit="byte/s")
# --- remaining VMDataSource counter factories (class statement outside
# --- this extract) ------------------------------------------------------

def createCounterNetworkWriteBytes(self, metrics, adapterId):
    """Bytes written per second on one network adapter."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_LARGE,
                       category="network",
                       name="Network Write Bytes",
                       instance=adapterId,
                       value=metrics.getNetworkWriteBytes(adapterId),
                       unit="byte/s")

def createCounterNetworkPacketRetransmitted(self, metrics):
    """TCP segments retransmitted per minute, VM wide."""
    return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                       category="network",
                       name="Packets Retransmitted",
                       value=metrics.getNetworkPacketRetransmitted(),
                       unit="packets/min")


def getStorageTimestamp(unixTimestamp):
    """Render *unixTimestamp* in the minute-resolution UTC key format
    (yyyymmddThhmm) used as PartitionKey by the storage analytics tables."""
    tformat = "{0:0>4d}{1:0>2d}{2:0>2d}T{3:0>2d}{4:0>2d}"
    ts = time.gmtime(unixTimestamp)
    return tformat.format(ts.tm_year, ts.tm_mon, ts.tm_mday,
                          ts.tm_hour, ts.tm_min)


def getStorageTableKeyRange():
    """Return (startKey, endKey) for the last full monitoring interval,
    shifted back by AzureTableDelay so the analytics rows have landed."""
    # Round down by MonitoringInterval.  Floor division is required:
    # plain "/" yields a float under Python 3 and breaks the round-down.
    endTime = int(time.time()) // MonitoringInterval * MonitoringInterval
    endTime = endTime - AzureTableDelay
    startTime = endTime - MonitoringInterval
    return getStorageTimestamp(startTime), getStorageTimestamp(endTime)


def getStorageMetrics(account, key, hostBase, table, startKey, endKey):
    """Query one minute-metrics analytics table.

    Returns the entity list, or None on failure; failures are logged and
    recorded as the latest error instead of being raised (best effort)."""
    try:
        waagent.Log("Retrieve storage metrics data.")
        tableService = TableService(account_name=account,
                                    account_key=key,
                                    host_base=hostBase)
        ofilter = ("PartitionKey ge '{0}' and PartitionKey lt '{1}'"
                   "").format(startKey, endKey)
        oselect = ("TotalRequests,TotalIngress,TotalEgress,AverageE2ELatency,"
                   "AverageServerLatency,RowKey")
        metrics = tableService.query_entities(table, ofilter, oselect)
        waagent.Log("{0} records returned.".format(len(metrics)))
        return metrics
    except Exception as e:
        waagent.Error((u"Failed to retrieve storage metrics data: {0} {1}"
                       "").format(e, traceback.format_exc()))
        updateLatestErrorRecord(FAILED_TO_RETRIEVE_STORAGE_DATA)
        AddExtensionEvent(message=FAILED_TO_RETRIEVE_STORAGE_DATA)
        return None


def getDataDisks():
    """List /sys/block entries that look like data disks (sdc..sdz).

    Returns a real list — not a lazy filter object — so callers can take
    len() and iterate it more than once under Python 3."""
    blockDevs = os.listdir('/sys/block')
    return [d for d in blockDevs if re.match("sd[c-z]", d)]


def getFirstLun(dev):
    """Return the LUN of block device *dev*.

    The scsi_disk directory contains an entry named "host:channel:target:lun";
    parse the lun field instead of taking the last character so that LUNs
    >= 10 are reported correctly."""
    path = os.path.join("/sys/block", dev, "device/scsi_disk")
    for lun in os.listdir(path):
        return int(lun.split(":")[-1])


class DiskInfo(object):
    """Maps local block devices to the Azure VHDs backing them."""

    def __init__(self, config):
        self.config = config

    def getDiskMapping(self):
        """Return {device: disk-properties} for the OS disk and every
        configured data disk (the resource disk /dev/sdb is excluded)."""
        osdiskVhd = "{0} {1}".format(self.config.getOSDiskAccount(),
                                     self.config.getOSDiskName())
        osdisk = {
            "vhd": osdiskVhd,
            "type": self.config.getOSDiskType(),
            "caching": self.config.getOSDiskCaching(),
            "iops": self.config.getOSDiskSLAIOPS(),
            "throughput": self.config.getOSDiskSLAThroughput(),
        }
        diskMapping = {
            "/dev/sda": osdisk,
        }
        dataDisks = getDataDisks()
        if not dataDisks:
            return diskMapping
        lunToDevMap = {}
        for dev in dataDisks:
            lun = getFirstLun(dev)
            lunToDevMap[lun] = dev
        diskCount = self.config.getDataDiskCount()
        for i in range(0, diskCount):
            lun = self.config.getDataDiskLun(i)
            datadiskVhd = "{0} {1}".format(self.config.getDataDiskAccount(i),
                                           self.config.getDataDiskName(i))
            datadisk = {
                "vhd": datadiskVhd,
                "type": self.config.getDataDiskType(i),
                "caching": self.config.getDataDiskCaching(i),
                "iops": self.config.getDataDiskSLAIOPS(i),
                "throughput": self.config.getDataDiskSLAThroughput(i),
            }
            if lun in lunToDevMap:
                dev = lunToDevMap[lun]
                diskMapping[dev] = datadisk
            else:
                waagent.Warn("Couldn't find disk with lun: {0}".format(lun))
        return diskMapping


def isUserRead(op):
    """True when analytics RowKey *op* is a user-initiated read."""
    if not op.startswith("user;"):
        return False
    op = op[5:]
    for prefix in ["Get", "List", "Preflight"]:
        if op.startswith(prefix):
            return True
    return False


def isUserWrite(op):
    """True when analytics RowKey *op* is a user-initiated write."""
    if not op.startswith("user;"):
        return False
    op = op[5:]
    for prefix in ["Put", "Set", "Clear", "Delete", "Create", "Snapshot"]:
        if op.startswith(prefix):
            return True
    return False


def storageStat(metrics, opFilter):
    """Aggregate analytics rows selected by *opFilter* into bytes, ops,
    request-weighted latencies and MB/s throughput.

    All values stay None when *metrics* is None (query failed)."""
    stat = {}
    stat['bytes'] = None
    stat['ops'] = None
    stat['e2eLatency'] = None
    stat['serverLatency'] = None
    stat['throughput'] = None
    if metrics is None:
        return stat
    # Materialize the selection: the rows are traversed several times below
    # and Python 3's filter() is a single-pass iterator.
    metrics = [m for m in metrics if opFilter(m.RowKey)]
    stat['bytes'] = sum(m.TotalIngress + m.TotalEgress for m in metrics)
    stat['ops'] = sum(m.TotalRequests for m in metrics)
    if stat['ops'] != 0:
        # Request-weighted averages of the per-row average latencies.
        stat['e2eLatency'] = sum(m.TotalRequests * m.AverageE2ELatency
                                 for m in metrics) / stat['ops']
        stat['serverLatency'] = sum(m.TotalRequests * m.AverageServerLatency
                                    for m in metrics) / stat['ops']
    # Rows cover one minute: convert the byte total to MB/s.
    stat['throughput'] = float(stat['bytes']) / (1024 * 1024) / 60
    return stat


class AzureStorageStat(object):
    """Read/write statistics derived from one batch of analytics rows."""

    def __init__(self, metrics):
        self.metrics = metrics
        self.rStat = storageStat(metrics, isUserRead)
        self.wStat = storageStat(metrics, isUserWrite)

    def getReadBytes(self):
        return self.rStat['bytes']

    def getReadOps(self):
        return self.rStat['ops']

    def getReadOpE2ELatency(self):
        return self.rStat['e2eLatency']

    def getReadOpServerLatency(self):
        return self.rStat['serverLatency']

    def getReadOpThroughput(self):
        return self.rStat['throughput']

    def getWriteBytes(self):
        return self.wStat['bytes']

    def getWriteOps(self):
        return self.wStat['ops']

    def getWriteOpE2ELatency(self):
        return self.wStat['e2eLatency']

    def getWriteOpServerLatency(self):
        return self.wStat['serverLatency']

    def getWriteOpThroughput(self):
        return self.wStat['throughput']
class StorageDataSource(object):
    """Collects disk-mapping counters and, for every standard storage
    account, per-minute read/write counters from its analytics table."""

    def __init__(self, config):
        self.config = config

    def collect(self):
        """Build the full list of disk and storage counters."""
        counters = []
        #The resource disk is always local and never backed by a VHD.
        counters.append(self.createCounterDiskMapping("/dev/sdb",
                                                      "not mapped to vhd"))
        #Disk mapping for the OS disk and every data disk.
        diskMapping = DiskInfo(self.config).getDiskMapping()
        # .items() works on both Python 2 and 3 (iteritems() is Py2-only).
        for dev, disk in diskMapping.items():
            counters.append(self.createCounterDiskMapping(dev,
                                                          disk.get("vhd")))
            counters.append(self.createCounterDiskType(dev,
                                                       disk.get("type")))
            counters.append(self.createCounterDiskCaching(dev,
                                                          disk.get("caching")))
            #SLA counters exist only for premium storage disks.
            if disk.get("type") == "Premium":
                counters.append(self.createCounterDiskIOPS(dev,
                                                           disk.get("iops")))
                counters.append(
                        self.createCounterDiskThroughput(dev,
                                                         disk.get("throughput")))
        accounts = self.config.getStorageAccountNames()
        for account in accounts:
            if self.config.getStorageAccountType(account) == "Standard":
                counters.extend(self.collectMetrixForStandardStorage(account))
        return counters

    def collectMetrixForStandardStorage(self, account):
        """Query the minute-metrics table of one standard account and wrap
        the aggregated statistics in counters.  (Name kept for backward
        compatibility — "Metrix" sic.)"""
        counters = []
        startKey, endKey = getStorageTableKeyRange()
        tableName = self.config.getStorageAccountMinuteTable(account)
        accountKey = self.config.getStorageAccountKey(account)
        hostBase = self.config.getStorageHostBase(account)
        metrics = getStorageMetrics(account,
                                    accountKey,
                                    hostBase,
                                    tableName,
                                    startKey,
                                    endKey)
        stat = AzureStorageStat(metrics)
        counters.append(self.createCounterStorageId(account))
        counters.append(self.createCounterReadBytes(account, stat))
        counters.append(self.createCounterReadOps(account, stat))
        counters.append(self.createCounterReadOpE2ELatency(account, stat))
        counters.append(self.createCounterReadOpServerLatency(account, stat))
        counters.append(self.createCounterReadOpThroughput(account, stat))
        counters.append(self.createCounterWriteBytes(account, stat))
        counters.append(self.createCounterWriteOps(account, stat))
        counters.append(self.createCounterWriteOpE2ELatency(account, stat))
        counters.append(self.createCounterWriteOpServerLatency(account, stat))
        counters.append(self.createCounterWriteOpThroughput(account, stat))
        return counters

    def createCounterDiskType(self, dev, diskType):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                           category="disk",
                           name="Storage Type",
                           instance=dev,
                           value=diskType)

    def createCounterDiskCaching(self, dev, caching):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                           category="disk",
                           name="Caching",
                           instance=dev,
                           value=caching)

    def createCounterDiskThroughput(self, dev, throughput):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                           category="disk",
                           name="SLA Throughput",
                           instance=dev,
                           unit="MB/sec",
                           value=throughput)

    def createCounterDiskIOPS(self, dev, iops):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                           category="disk",
                           name="SLA",
                           instance=dev,
                           unit="Ops/sec",
                           value=iops)

    def createCounterReadBytes(self, account, stat):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_LARGE,
                           category="storage",
                           name="Storage Read Bytes",
                           instance=account,
                           value=stat.getReadBytes(),
                           unit='byte',
                           refreshInterval=60)

    def createCounterReadOps(self, account, stat):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                           category="storage",
                           name="Storage Read Ops",
                           instance=account,
                           value=stat.getReadOps(),
                           refreshInterval=60)

    def createCounterReadOpE2ELatency(self, account, stat):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_DOUBLE,
                           category="storage",
                           name="Storage Read Op Latency E2E msec",
                           instance=account,
                           value=stat.getReadOpE2ELatency(),
                           unit='ms',
                           refreshInterval=60)

    def createCounterReadOpServerLatency(self, account, stat):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_DOUBLE,
                           category="storage",
                           name="Storage Read Op Latency Server msec",
                           instance=account,
                           value=stat.getReadOpServerLatency(),
                           unit='ms',
                           refreshInterval=60)

    def createCounterReadOpThroughput(self, account, stat):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_DOUBLE,
                           category="storage",
                           name="Storage Read Throughput E2E MB/sec",
                           instance=account,
                           value=stat.getReadOpThroughput(),
                           unit='MB/s',
                           refreshInterval=60)

    def createCounterWriteBytes(self, account, stat):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_LARGE,
                           category="storage",
                           name="Storage Write Bytes",
                           instance=account,
                           value=stat.getWriteBytes(),
                           unit='byte',
                           refreshInterval=60)

    def createCounterWriteOps(self, account, stat):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                           category="storage",
                           name="Storage Write Ops",
                           instance=account,
                           value=stat.getWriteOps(),
                           refreshInterval=60)

    def createCounterWriteOpE2ELatency(self, account, stat):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_DOUBLE,
                           category="storage",
                           name="Storage Write Op Latency E2E msec",
                           instance=account,
                           value=stat.getWriteOpE2ELatency(),
                           unit='ms',
                           refreshInterval=60)

    def createCounterWriteOpServerLatency(self, account, stat):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_DOUBLE,
                           category="storage",
                           name="Storage Write Op Latency Server msec",
                           instance=account,
                           value=stat.getWriteOpServerLatency(),
                           unit='ms',
                           refreshInterval=60)

    def createCounterWriteOpThroughput(self, account, stat):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_DOUBLE,
                           category="storage",
                           name="Storage Write Throughput E2E MB/sec",
                           instance=account,
                           value=stat.getWriteOpThroughput(),
                           unit='MB/s',
                           refreshInterval=60)

    def createCounterStorageId(self, account):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                           category="storage",
                           name="Storage ID",
                           instance=account,
                           value=account)

    def createCounterDiskMapping(self, dev, vhd):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                           category="storage",
                           name="Phys. Disc to Storage Mapping",
                           instance=dev,
                           value=vhd)
class HvInfo(object):
    """Hypervisor name/version probed via the bundled bin/hvinfo tool.

    Both fields stay None when the probe fails; callers treat that as
    "unknown"."""

    def __init__(self):
        self.hvName = None
        self.hvVersion = None
        root_dir = os.path.dirname(__file__)
        cmd = os.path.join(root_dir, "bin/hvinfo")
        ret, output = waagent.RunGetOutput(cmd, chk_err=False)
        # (A stray debug print of `ret` was removed here: the daemon must
        # report through waagent logging, not stdout.)
        if ret == 0 and output is not None:
            # hvinfo prints the name on line 1 and the version on line 2.
            lines = output.split("\n")
            if len(lines) >= 2:
                self.hvName = lines[0]
                self.hvVersion = lines[1]

    def getHvName(self):
        return self.hvName

    def getHvVersion(self):
        return self.hvVersion


class StaticDataSource(object):
    """Counters that rarely change: platform identity, provider version,
    over-provisioning flags and optional VM-level SLA numbers."""

    def __init__(self, config):
        self.config = config

    def collect(self):
        counters = []
        hvInfo = HvInfo()
        counters.append(self.createCounterCloudProvider())
        counters.append(self.createCounterCpuOverCommitted())
        counters.append(self.createCounterMemoryOverCommitted())
        counters.append(self.createCounterDataProviderVersion())
        counters.append(self.createCounterDataSources())
        counters.append(self.createCounterInstanceType())
        counters.append(self.createCounterVirtSln(hvInfo.getHvName()))
        counters.append(self.createCounterVirtSlnVersion(hvInfo.getHvVersion()))
        #SLA counters are optional; emit them only when configured.
        vmSLAThroughput = self.config.getVMSLAThroughput()
        if vmSLAThroughput is not None:
            counters.append(self.createCounterVMSLAThroughput(vmSLAThroughput))
        vmSLAIOPS = self.config.getVMSLAIOPS()
        if vmSLAIOPS is not None:
            counters.append(self.createCounterVMSLAIOPS(vmSLAIOPS))
        return counters

    def createCounterVMSLAThroughput(self, throughput):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                           category="config",
                           name="SLA Max Disk Bandwidth per VM",
                           unit="Ops/sec",
                           value=throughput)

    def createCounterVMSLAIOPS(self, iops):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_INT,
                           category="config",
                           name="SLA Max Disk IOPS per VM",
                           unit="Ops/sec",
                           value=iops)

    def createCounterCloudProvider(self):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                           category="config",
                           name="Cloud Provider",
                           value="Microsoft Azure")

    def createCounterVirtSlnVersion(self, hvVersion):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                           category="config",
                           name="Virtualization Solution Version",
                           value=hvVersion)

    def createCounterVirtSln(self, hvName):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                           category="config",
                           name="Virtualization Solution",
                           value=hvName)

    def createCounterInstanceType(self):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                           category="config",
                           name="Instance Type",
                           value=self.config.getVmSize())

    def createCounterDataSources(self):
        dataSource = "wad" if self.config.isLADEnabled() else "local"
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                           category="config",
                           name="Data Sources",
                           value=dataSource)

    def createCounterDataProviderVersion(self):
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                           category="config",
                           name="Data Provider Version",
                           value=AzureEnhancedMonitorVersion)

    def createCounterMemoryOverCommitted(self):
        value = "yes" if self.config.isMemoryOverCommitted() else "no"
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                           category="config",
                           name="Memory Over-Provisioning",
                           value=value)

    def createCounterCpuOverCommitted(self):
        value = "yes" if self.config.isCpuOverCommitted() else "no"
        return PerfCounter(counterType=PerfCounterType.COUNTER_TYPE_STRING,
                           category="config",
                           name="CPU Over-Provisioning",
                           value=value)


class PerfCounterType(object):
    # Wire codes for the first field of each serialized counter line.
    COUNTER_TYPE_INVALID = 0
    COUNTER_TYPE_INT = 1
    COUNTER_TYPE_DOUBLE = 2
    COUNTER_TYPE_LARGE = 3
    COUNTER_TYPE_STRING = 4


class PerfCounter(object):
    """One monitoring data point.

    Serialized (via __str__) as a semicolon-separated line:
    type;category;name;instance;isEmpty;value;unit;refreshInterval;
    timestamp;machine;  — the exact format the consumer parses."""

    def __init__(self, counterType, category, name, value, instance="",
                 unit="none", timestamp=None, refreshInterval=0):
        self.counterType = counterType
        self.category = category
        self.name = name
        self.instance = instance
        self.value = value
        self.unit = unit
        self.refreshInterval = refreshInterval
        # Default the timestamp to "now" when the caller supplied none.
        if timestamp:
            self.timestamp = timestamp
        else:
            self.timestamp = int(time.time())
        self.machine = socket.gethostname()

    def __str__(self):
        # Field 5 is an "is empty" flag: 0 when a value is present, 1 when
        # the value is None (and the value field is left blank).
        return (u"{0};{1};{2};{3};{4};{5};{6};{7};{8};{9};\n"
                "").format(self.counterType,
                           self.category,
                           self.name,
                           self.instance,
                           0 if self.value is not None else 1,
                           self.value if self.value is not None else "",
                           self.unit,
                           self.refreshInterval,
                           self.timestamp,
                           self.machine)

    __repr__ = __str__
class EnhancedMonitor(object):
    """Top-level collector: gathers counters from every data source and
    serializes them for the consuming host-control agent."""

    def __init__(self, config):
        self.dataSources = []
        self.dataSources.append(VMDataSource(config))
        self.dataSources.append(StorageDataSource(config))
        self.dataSources.append(StaticDataSource(config))
        self.writer = PerfCounterWriter()

    def run(self):
        """Collect one batch of counters and write them to the event file."""
        counters = []
        for dataSource in self.dataSources:
            counters.extend(dataSource.collect())
        clearLastErrorRecord()
        self.writer.write(counters)


EventFile = os.path.join(LibDir, "PerfCounters")


class PerfCounterWriter(object):
    """Serializes counters to the event file, retrying transient IO errors."""

    def write(self, counters, maxRetry=3, eventFile=EventFile):
        """Write *counters*, retrying up to *maxRetry* times on IOError.

        Raises the last IOError after recording the failure when every
        attempt fails."""
        lastError = None
        for i in range(0, maxRetry):
            try:
                self._write(counters, eventFile)
                waagent.Log(("Write {0} counters to event file."
                             "").format(len(counters)))
                return
            except IOError as e:
                lastError = e
                waagent.Warn((u"Write to perf counters file failed: {0}"
                              "").format(e))
                waagent.Log("Retry: {0}".format(i))
                time.sleep(1)
        waagent.Error(("Failed to serialize perf counter to file:"
                       "{0}").format(eventFile))
        updateLatestErrorRecord(FAILED_TO_SERIALIZE_PERF_COUNTERS)
        AddExtensionEvent(message=FAILED_TO_SERIALIZE_PERF_COUNTERS)
        # Re-raise the captured error explicitly: a bare `raise` here sits
        # outside any except block and fails under Python 3.
        raise lastError

    def _write(self, counters, eventFile):
        # Binary mode: the payload is utf-8 encoded bytes (writing bytes to
        # a text-mode file raises TypeError under Python 3).
        with open(eventFile, "wb") as F:
            F.write("".join(map(lambda c: str(c), counters)).encode("utf8"))
class EnhancedMonitorConfig(object):
    """Typed accessor over the merged public + protected extension settings
    plus deployment/role identity read from waagent's SharedConfig.xml.

    Settings arrive as a list of {"key": ..., "value": ...} items; values
    keep whatever type the JSON carried (int or str), so flag accessors
    accept both forms."""

    def __init__(self, publicConfig, privateConfig):
        xmldoc = minidom.parse('/var/lib/waagent/SharedConfig.xml')
        self.deployment = xmldoc.getElementsByTagName('Deployment')
        self.role = xmldoc.getElementsByTagName('Role')
        self.configData = {}
        diskCount = 0
        accountNames = []
        for item in publicConfig["cfg"]:
            self.configData[item["key"]] = item["value"]
            # Each data disk contributes one "disk.lun.<n>" key.
            if item["key"].startswith("disk.lun"):
                diskCount = diskCount + 1
            # Each storage account contributes one "<acct>.minute.name" key.
            if item["key"].endswith("minute.name"):
                accountNames.append(item["value"])
        for item in privateConfig["cfg"]:
            self.configData[item["key"]] = item["value"]
        self.configData["disk.count"] = diskCount
        self.configData["account.names"] = accountNames

    def getVmSize(self):
        return self.configData.get("vmsize")

    def getVmRoleInstance(self):
        return self.role[0].attributes['name'].value

    def getVmDeploymentId(self):
        return self.deployment[0].attributes['name'].value

    def isMemoryOverCommitted(self):
        return self.configData.get("vm.memory.isovercommitted")

    def isCpuOverCommitted(self):
        return self.configData.get("vm.cpu.isovercommitted")

    def getScriptVersion(self):
        return self.configData.get("script.version")

    def isVerbose(self):
        flag = self.configData.get("verbose")
        return flag == "1" or flag == 1

    def getVMSLAIOPS(self):
        return self.configData.get("vm.sla.iops")

    def getVMSLAThroughput(self):
        return self.configData.get("vm.sla.throughput")

    def getOSDiskName(self):
        return self.configData.get("osdisk.name")

    def getOSDiskAccount(self):
        osdiskConnMinute = self.getOSDiskConnMinute()
        return self.configData.get("{0}.name".format(osdiskConnMinute))

    def getOSDiskConnMinute(self):
        return self.configData.get("osdisk.connminute")

    def getOSDiskConnHour(self):
        return self.configData.get("osdisk.connhour")

    def getOSDiskType(self):
        return self.configData.get("osdisk.type")

    def getOSDiskCaching(self):
        return self.configData.get("osdisk.caching")

    def getOSDiskSLAIOPS(self):
        return self.configData.get("osdisk.sla.iops")

    def getOSDiskSLAThroughput(self):
        return self.configData.get("osdisk.sla.throughput")

    def getDataDiskCount(self):
        return self.configData.get("disk.count")

    def getDataDiskLun(self, index):
        return self.configData.get("disk.lun.{0}".format(index))

    def getDataDiskName(self, index):
        return self.configData.get("disk.name.{0}".format(index))

    def getDataDiskAccount(self, index):
        return self.configData.get("disk.account.{0}".format(index))

    def getDataDiskConnMinute(self, index):
        return self.configData.get("disk.connminute.{0}".format(index))

    def getDataDiskConnHour(self, index):
        return self.configData.get("disk.connhour.{0}".format(index))

    def getDataDiskType(self, index):
        return self.configData.get("disk.type.{0}".format(index))

    def getDataDiskCaching(self, index):
        return self.configData.get("disk.caching.{0}".format(index))

    def getDataDiskSLAIOPS(self, index):
        return self.configData.get("disk.sla.iops.{0}".format(index))

    def getDataDiskSLAThroughput(self, index):
        return self.configData.get("disk.sla.throughput.{0}".format(index))

    def getStorageAccountNames(self):
        return self.configData.get("account.names")

    def getStorageAccountKey(self, name):
        return self.configData.get("{0}.minute.key".format(name))

    def getStorageAccountType(self, name):
        key = "{0}.minute.ispremium".format(name)
        flag = self.configData.get(key)
        # Accept both int 1 and string "1", consistent with isVerbose()
        # and isLADEnabled(); the previous int-only check misclassified
        # premium accounts whose flag arrived as a string.
        return "Premium" if (flag == 1 or flag == "1") else "Standard"

    def getStorageHostBase(self, name):
        return get_host_base_from_uri(self.getStorageAccountMinuteUri(name))

    def getStorageAccountMinuteUri(self, name):
        return self.configData.get("{0}.minute.uri".format(name))

    def getStorageAccountMinuteTable(self, name):
        # The analytics table name is the last path segment of the uri.
        uri = self.getStorageAccountMinuteUri(name)
        pos = uri.rfind('/')
        tableName = uri[pos + 1:]
        return tableName

    def getStorageAccountHourUri(self, name):
        return self.configData.get("{0}.hour.uri".format(name))

    def isLADEnabled(self):
        flag = self.configData.get("wad.isenabled")
        return flag == "1" or flag == 1

    def getLADKey(self):
        return self.configData.get("wad.key")

    def getLADName(self):
        return self.configData.get("wad.name")

    def getLADHostBase(self):
        return get_host_base_from_uri(self.getLADUri())

    def getLADUri(self):
        return self.configData.get("wad.uri")
def enable(hutil):
    """Start the monitoring daemon.

    Idempotent: if a daemon for the current sequence number is already
    running, report success; otherwise kill any stale daemon and spawn a
    fresh one, recording its pid."""
    pidFile = os.path.join(aem.LibDir, "pid")
    #Check whether a monitor process is already running.  Keep it if the
    #sequence number is unchanged; otherwise stop it and clear the pid file.
    if os.path.isfile(pidFile):
        pid = waagent.GetFileContents(pidFile)
        if os.path.isdir(os.path.join("/proc", pid)):
            if hutil.is_seq_smaller():
                hutil.do_exit(0, 'Enable', 'success', '0',
                              'Azure Enhanced Monitor is already running')
            else:
                waagent.Log("Stop old daemon: {0}".format(pid))
                os.kill(int(pid), 9)
        os.remove(pidFile)
    args = [os.path.join(os.getcwd(), __file__), "daemon"]
    devnull = open(os.devnull, 'w')
    child = subprocess.Popen(args, stdout=devnull, stderr=devnull)
    if child.pid is None or child.pid < 1:
        hutil.do_exit(1, 'Enable', 'error', '1',
                      'Failed to launch Azure Enhanced Monitor')
    else:
        hutil.save_seq()
        waagent.SetFileContents(pidFile, str(child.pid))
        waagent.Log(("Daemon pid: {0}").format(child.pid))
        hutil.do_exit(0, 'Enable', 'success', '0',
                      'Azure Enhanced Monitor is enabled')


def disable(hutil):
    """Stop the monitoring daemon if it is running and clear the pid file."""
    pidFile = os.path.join(aem.LibDir, "pid")
    #Check whether the monitor process is running.
    #If it is, kill it.  Otherwise just clear the pid file.
    if os.path.isfile(pidFile):
        pid = waagent.GetFileContents(pidFile)
        if os.path.isdir(os.path.join("/proc", pid)):
            waagent.Log(("Stop daemon: {0}").format(pid))
            os.kill(int(pid), 9)
            os.remove(pidFile)
            hutil.do_exit(0, 'Disable', 'success', '0',
                          'Azure Enhanced Monitor is disabled')
        os.remove(pidFile)
    hutil.do_exit(0, 'Disable', 'success', '0',
                  'Azure Enhanced Monitor is not running')


def daemon(hutil):
    """Run the collection loop forever, reporting status each iteration."""
    publicConfig = hutil.get_public_settings()
    privateConfig = hutil.get_protected_settings()
    config = aem.EnhancedMonitorConfig(publicConfig, privateConfig)
    monitor = aem.EnhancedMonitor(config)
    hutil.set_verbose_log(config.isVerbose())
    InitExtensionEventLog(hutil.get_name())
    while True:
        waagent.Log("Collecting performance counter.")
        startTime = time.time()
        try:
            monitor.run()
            message = ("deploymentId={0} roleInstance={1} OK"
                       "").format(config.getVmDeploymentId(),
                                  config.getVmRoleInstance())
            hutil.do_status_report("Enable", "success", 0, message)
        except Exception as e:
            waagent.Error("{0} {1}".format(printable(e),
                                           traceback.format_exc()))
            hutil.do_status_report("Enable", "error", 0, "{0}".format(e))
        waagent.Log("Finished collection.")
        timeElapsed = time.time() - startTime
        timeToWait = (aem.MonitoringInterval - timeElapsed)
        #Keep timeToWait inside [0, MonitoringInterval) even when a
        #collection ran longer than a whole interval.
        timeToWait = timeToWait % aem.MonitoringInterval
        time.sleep(timeToWait)


def grace_exit(operation, status, msg):
    """Report *status* for *operation* and exit 0."""
    hutil = parse_context(operation)
    hutil.do_exit(0, operation, status, '0', msg)


def parse_context(operation):
    """Build a HandlerUtility wired to waagent logging for *operation*."""
    hutil = util.HandlerUtility(waagent.Log,
                                waagent.Error,
                                ExtensionShortName,
                                ExtensionFullName,
                                ExtensionVersion)
    hutil.do_parse_context(operation)
    return hutil


def main():
    """Dispatch the extension command(s) given on the command line."""
    waagent.LoggerInit('/var/log/waagent.log', '/dev/stdout')
    waagent.Log("{0} started to handle.".format(ExtensionShortName))
    if not os.path.isdir(aem.LibDir):
        os.makedirs(aem.LibDir)
    for command in sys.argv[1:]:
        if re.match("^([-/]*)(install)", command):
            grace_exit("install", "success", "Install succeeded")
        if re.match("^([-/]*)(uninstall)", command):
            grace_exit("uninstall", "success", "Uninstall succeeded")
        if re.match("^([-/]*)(update)", command):
            grace_exit("update", "success", "Update succeeded")
        hutil = None
        try:
            if re.match("^([-/]*)(enable)", command):
                hutil = parse_context("enable")
                enable(hutil)
            elif re.match("^([-/]*)(disable)", command):
                hutil = parse_context("disable")
                disable(hutil)
            elif re.match("^([-/]*)(daemon)", command):
                hutil = parse_context("enable")
                daemon(hutil)
        except Exception as e:
            if hutil is not None:
                hutil.error("{0}, {1}".format(e, traceback.format_exc()))
                hutil.do_exit(1, command, 'failed', '0',
                              '{0} failed:{1}'.format(command, e))
            else:
                #parse_context itself failed: hutil is unbound, so report
                #through waagent and propagate instead of masking the real
                #error with a NameError.
                waagent.Error("{0}, {1}".format(e, traceback.format_exc()))
                raise
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import imp
import os
import shutil
import traceback

from Utils.WAAgentUtil import waagent
import Utils.HandlerUtil as util

ExtensionShortName = 'AzureEnhancedMonitor'


def parse_context(operation):
    """Build a HandlerUtility wired to waagent logging for *operation*."""
    hutil = util.HandlerUtility(waagent.Log, waagent.Error, ExtensionShortName)
    hutil.do_parse_context(operation)
    return hutil


def find_psutil_build(buildDir):
    """Return the first bundled psutil build whose native module loads on
    this host.

    Raises Exception when no bundled build is importable."""
    for item in os.listdir(buildDir):
        try:
            build = os.path.join(buildDir, item)
            binary = os.path.join(build, '_psutil_linux.so')
            # Probe-load the shared object; failure means this build does
            # not match the host's python/glibc, so try the next one.
            imp.load_dynamic('_psutil_linux', binary)
            return build
        except Exception:
            pass
    raise Exception("Available build of psutil not found.")


def main():
    """Install step: copy a working psutil build next to the extension."""
    waagent.LoggerInit('/var/log/waagent.log', '/dev/stdout')
    waagent.Log("{0} started to handle.".format(ExtensionShortName))
    hutil = parse_context("Install")
    try:
        root = os.path.dirname(os.path.abspath(__file__))
        buildDir = os.path.join(root, "libpsutil")
        build = find_psutil_build(buildDir)
        for item in os.listdir(build):
            src = os.path.join(build, item)
            dest = os.path.join(root, item)
            if os.path.isfile(src):
                if os.path.isfile(dest):
                    os.remove(dest)
                shutil.copyfile(src, dest)
            else:
                if os.path.isdir(dest):
                    shutil.rmtree(dest)
                shutil.copytree(src, dest)
    except Exception as e:
        # Fixed: the original called .format() on the return value of
        # hutil.error() and referenced the never-imported traceback module.
        hutil.error("{0}, {1}".format(e, traceback.format_exc()))
        hutil.do_exit(1, "Install", 'failed', '0',
                      'Install failed: {0}'.format(e))


if __name__ == '__main__':
    main()
AzureEnhancedMonitor/hvinfo/bin ================================================ FILE: AzureEnhancedMonitor/ext/test/env.py ================================================ #!/usr/bin/env python # #CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import os test_dir = os.path.dirname(os.path.abspath(__file__)) root = os.path.dirname(test_dir) sys.path.append(root) azure_sdk = os.path.join(root, "Common/azure-sdk-for-python") sys.path.append(azure_sdk) ================================================ FILE: AzureEnhancedMonitor/ext/test/storage_metrics ================================================ [{"TotalRequests": 1, "RowKey": "system;All", "AverageE2ELatency": 52.0, "AverageServerLatency": 48.0, "TotalIngress": 247088, "TotalEgress": 160}, {"TotalRequests": 154, "RowKey": "user;All", "AverageE2ELatency": 6.285714, "AverageServerLatency": 5.551948, "TotalIngress": 1015225, "TotalEgress": 562321}, {"TotalRequests": 6, "RowKey": "user;ClearPage", "AverageE2ELatency": 5.0, "AverageServerLatency": 5.0, "TotalIngress": 3166, "TotalEgress": 1284}, {"TotalRequests": 1, "RowKey": "user;GetBlob", "AverageE2ELatency": 139.0, "AverageServerLatency": 31.0, "TotalIngress": 500, "TotalEgress": 524684}, {"TotalRequests": 11, "RowKey": "user;PutBlob", "AverageE2ELatency": 8.727273, "AverageServerLatency": 8.727273, "TotalIngress": 19026, "TotalEgress": 2475}, {"TotalRequests": 136, "RowKey": "user;PutPage", 
"AverageE2ELatency": 5.169118, "AverageServerLatency": 5.132353, "TotalIngress": 992533, "TotalEgress": 33878}] ================================================ FILE: AzureEnhancedMonitor/ext/test/test_aem.py ================================================ #!/usr/bin/env python # #CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import datetime import os import json import unittest import env import aem from Utils.WAAgentUtil import waagent TestPublicConfig = """\ { "cfg": [{ "key": "vmsize", "value": "Small (A1)" },{ "key": "vm.roleinstance", "value": "osupdate" },{ "key": "vm.role", "value": "IaaS" },{ "key": "vm.deploymentid", "value": "cd98461b43364478a908d03d0c3135a7" },{ "key": "vm.memory.isovercommitted", "value": 0 },{ "key": "vm.cpu.isovercommitted", "value": 0 },{ "key": "script.version", "value": "1.2.0.0" },{ "key": "verbose", "value": "0" },{ "key": "osdisk.connminute", "value": "asdf.minute" },{ "key": "osdisk.connhour", "value": "asdf.hour" },{ "key": "osdisk.name", "value": "osupdate-osupdate-2015-02-12.vhd" },{ "key": "asdf.hour.uri", "value": "https://asdf.table.core.windows.net/$metricshourprimarytransactionsblob" },{ "key": "asdf.minute.uri", "value": "https://asdf.table.core.windows.net/$metricsminuteprimarytransactionsblob" },{ "key": "asdf.hour.name", "value": "asdf" },{ "key": "asdf.minute.name", "value": "asdf" },{ "key": "wad.name", "value": "asdf" },{ "key": "wad.isenabled", "value": "1" },{ 
"key": "wad.uri", "value": "https://asdf.table.core.windows.net/wadperformancecounterstable" }] } """ TestPrivateConfig = """\ { "cfg" : [{ "key" : "asdf.minute.key", "value" : "qwer" },{ "key" : "wad.key", "value" : "qwer" }] } """ class TestAEM(unittest.TestCase): def setUp(self): waagent.LoggerInit("/dev/null", "/dev/stdout") def test_config(self): publicConfig = json.loads(TestPublicConfig) privateConfig = json.loads(TestPrivateConfig) config = aem.EnhancedMonitorConfig(publicConfig, privateConfig) self.assertNotEquals(None, config) self.assertEquals(".table.core.windows.net", config.getStorageHostBase('asdf')) self.assertEquals(".table.core.windows.net", config.getLADHostBase()) return config def test_static_datasource(self): config = self.test_config() dataSource = aem.StaticDataSource(config) counters = dataSource.collect() self.assertNotEquals(None, counters) self.assertNotEquals(0, len(counters)) name = "Cloud Provider" counter = next((c for c in counters if c.name == name)) self.assertNotEquals(None, counter) self.assertEquals("Microsoft Azure", counter.value) name = "Virtualization Solution Version" counter = next((c for c in counters if c.name == name)) self.assertNotEquals(None, counter) self.assertNotEquals(None, counter.value) name = "Virtualization Solution" counter = next((c for c in counters if c.name == name)) self.assertNotEquals(None, counter) self.assertNotEquals(None, counter.value) name = "Instance Type" counter = next((c for c in counters if c.name == name)) self.assertNotEquals(None, counter) self.assertEquals("Small (A1)", counter.value) name = "Data Sources" counter = next((c for c in counters if c.name == name)) self.assertNotEquals(None, counter) self.assertEquals("wad", counter.value) name = "Data Provider Version" counter = next((c for c in counters if c.name == name)) self.assertNotEquals(None, counter) self.assertEquals("2.0.0", counter.value) name = "Memory Over-Provisioning" counter = next((c for c in counters if c.name == name)) 
self.assertNotEquals(None, counter) self.assertEquals("no", counter.value) name = "CPU Over-Provisioning" counter = next((c for c in counters if c.name == name)) self.assertNotEquals(None, counter) self.assertEquals("no", counter.value) def test_cpuinfo(self): cpuinfo = aem.CPUInfo.getCPUInfo() self.assertNotEquals(None, cpuinfo) self.assertNotEquals(0, cpuinfo.getNumOfCoresPerCPU()) self.assertNotEquals(0, cpuinfo.getNumOfCores()) self.assertNotEquals(None, cpuinfo.getProcessorType()) self.assertEquals(float, type(cpuinfo.getFrequency())) self.assertEquals(bool, type(cpuinfo.isHyperThreadingOn())) percent = cpuinfo.getCPUPercent() self.assertEquals(float, type(percent)) self.assertTrue(percent >= 0 and percent <= 100) def test_meminfo(self): meminfo = aem.MemoryInfo() self.assertNotEquals(None, meminfo.getMemSize()) self.assertEquals(long, type(meminfo.getMemSize())) percent = meminfo.getMemPercent() self.assertEquals(float, type(percent)) self.assertTrue(percent >= 0 and percent <= 100) def test_networkinfo(self): netinfo = aem.NetworkInfo() adapterIds = netinfo.getAdapterIds() self.assertNotEquals(None, adapterIds) self.assertNotEquals(0, len(adapterIds)) adapterId = adapterIds[0] self.assertNotEquals(None, aem.getMacAddress(adapterId)) self.assertNotEquals(None, netinfo.getNetworkReadBytes()) self.assertNotEquals(None, netinfo.getNetworkWriteBytes()) self.assertNotEquals(None, netinfo.getNetworkPacketRetransmitted()) def test_hwchangeinfo(self): netinfo = aem.NetworkInfo() testHwInfoFile = "/tmp/HwInfo" aem.HwInfoFile = testHwInfoFile if os.path.isfile(testHwInfoFile): os.remove(testHwInfoFile) hwChangeInfo = aem.HardwareChangeInfo(netinfo) self.assertNotEquals(None, hwChangeInfo.getLastHardwareChange()) self.assertTrue(os.path.isfile, aem.HwInfoFile) #No hardware change lastChange = hwChangeInfo.getLastHardwareChange() hwChangeInfo = aem.HardwareChangeInfo(netinfo) self.assertEquals(lastChange, hwChangeInfo.getLastHardwareChange()) #Create mock hardware 
waagent.SetFileContents(testHwInfoFile, ("0\nma-ca-sa-ds-02")) hwChangeInfo = aem.HardwareChangeInfo(netinfo) self.assertNotEquals(None, hwChangeInfo.getLastHardwareChange()) def test_linux_metric(self): config = self.test_config() metric = aem.LinuxMetric(config) self.validate_cnm_metric(metric) #Metric for CPU, network and memory def validate_cnm_metric(self, metric): self.assertNotEquals(None, metric.getCurrHwFrequency()) self.assertNotEquals(None, metric.getMaxHwFrequency()) self.assertNotEquals(None, metric.getCurrVMProcessingPower()) self.assertNotEquals(None, metric.getGuaranteedMemAssigned()) self.assertNotEquals(None, metric.getMaxVMProcessingPower()) self.assertNotEquals(None, metric.getNumOfCoresPerCPU()) self.assertNotEquals(None, metric.getNumOfThreadsPerCore()) self.assertNotEquals(None, metric.getPhysProcessingPowerPerVCPU()) self.assertNotEquals(None, metric.getProcessorType()) self.assertNotEquals(None, metric.getReferenceComputeUnit()) self.assertNotEquals(None, metric.getVCPUMapping()) self.assertNotEquals(None, metric.getVMProcessingPowerConsumption()) self.assertNotEquals(None, metric.getCurrMemAssigned()) self.assertNotEquals(None, metric.getGuaranteedMemAssigned()) self.assertNotEquals(None, metric.getMaxMemAssigned()) self.assertNotEquals(None, metric.getVMMemConsumption()) adapterIds = metric.getNetworkAdapterIds() self.assertNotEquals(None, adapterIds) self.assertNotEquals(0, len(adapterIds)) adapterId = adapterIds[0] self.assertNotEquals(None, metric.getNetworkAdapterMapping(adapterId)) self.assertNotEquals(None, metric.getMaxNetworkBandwidth(adapterId)) self.assertNotEquals(None, metric.getMinNetworkBandwidth(adapterId)) self.assertNotEquals(None, metric.getNetworkReadBytes()) self.assertNotEquals(None, metric.getNetworkWriteBytes()) self.assertNotEquals(None, metric.getNetworkPacketRetransmitted()) self.assertNotEquals(None, metric.getLastHardwareChange()) def test_vm_datasource(self): config = self.test_config() 
config.configData["wad.isenabled"] = "0" dataSource = aem.VMDataSource(config) counters = dataSource.collect() self.assertNotEquals(None, counters) self.assertNotEquals(0, len(counters)) counterNames = [ "Current Hw Frequency", "Current VM Processing Power", "Guaranteed VM Processing Power", "Max Hw Frequency", "Max. VM Processing Power", "Number of Cores per CPU", "Number of Threads per Core", "Phys. Processing Power per vCPU", "Processor Type", "Reference Compute Unit", "vCPU Mapping", "VM Processing Power Consumption", "Current Memory assigned", "Guaranteed Memory assigned", "Max Memory assigned", "VM Memory Consumption", "Adapter Id", "Mapping", "Maximum Network Bandwidth", "Minimum Network Bandwidth", "Network Read Bytes", "Network Write Bytes", "Packets Retransmitted" ] #print "\n".join(map(lambda c: str(c), counters)) for name in counterNames: #print name counter = next((c for c in counters if c.name == name)) self.assertNotEquals(None, counter) self.assertNotEquals(None, counter.value) def test_storagemetric(self): metrics = mock_getStorageMetrics() self.assertNotEquals(None, metrics) stat = aem.AzureStorageStat(metrics) self.assertNotEquals(None, stat.getReadBytes()) self.assertNotEquals(None, stat.getReadOps()) self.assertNotEquals(None, stat.getReadOpE2ELatency()) self.assertNotEquals(None, stat.getReadOpServerLatency()) self.assertNotEquals(None, stat.getReadOpThroughput()) self.assertNotEquals(None, stat.getWriteBytes()) self.assertNotEquals(None, stat.getWriteOps()) self.assertNotEquals(None, stat.getWriteOpE2ELatency()) self.assertNotEquals(None, stat.getWriteOpServerLatency()) self.assertNotEquals(None, stat.getWriteOpThroughput()) def test_disk_info(self): config = self.test_config() mapping = aem.DiskInfo(config).getDiskMapping() self.assertNotEquals(None, mapping) def test_get_storage_key_range(self): startKey, endKey = aem.getStorageTableKeyRange() self.assertNotEquals(None, startKey) self.assertEquals(13, len(startKey)) 
self.assertNotEquals(None, endKey) self.assertEquals(13, len(endKey)) def test_storage_datasource(self): aem.getStorageMetrics = mock_getStorageMetrics config = self.test_config() dataSource = aem.StorageDataSource(config) counters = dataSource.collect() self.assertNotEquals(None, counters) self.assertNotEquals(0, len(counters)) counterNames = [ "Phys. Disc to Storage Mapping", "Storage ID", "Storage Read Bytes", "Storage Read Op Latency E2E msec", "Storage Read Op Latency Server msec", "Storage Read Ops", "Storage Read Throughput E2E MB/sec", "Storage Write Bytes", "Storage Write Op Latency E2E msec", "Storage Write Op Latency Server msec", "Storage Write Ops", "Storage Write Throughput E2E MB/sec" ] #print "\n".join(map(lambda c: str(c), counters)) for name in counterNames: #print name counter = next((c for c in counters if c.name == name)) self.assertNotEquals(None, counter) self.assertNotEquals(None, counter.value) def test_writer(self): testEventFile = "/tmp/Event" if os.path.isfile(testEventFile): os.remove(testEventFile) writer = aem.PerfCounterWriter() counters = [aem.PerfCounter(counterType = 0, category = "test", name = "test", value = "test", unit = "test")] writer.write(counters, eventFile = testEventFile) with open(testEventFile) as F: content = F.read() self.assertEquals(str(counters[0]), content) testEventFile = "/dev/console" print("==============================") print("The warning below is expected.") self.assertRaises(IOError, writer.write, counters, 2, testEventFile) print("==============================") def test_easyHash(self): hashVal = aem.easyHash('a') self.assertEquals(97, hashVal) hashVal = aem.easyHash('ab') self.assertEquals(87, hashVal) hashVal = aem.easyHash(("ciextension-SUSELinuxEnterpriseServer11SP3" "___role1___" "ciextension-SUSELinuxEnterpriseServer11SP3")) self.assertEquals(5, hashVal) def test_get_ad_key_range(self): startKey, endKey = aem.getAzureDiagnosticKeyRange() print(startKey) print(endKey) def 
test_get_mds_timestamp(self): date = datetime.datetime(2015, 1, 26, 3, 54) epoch = datetime.datetime.utcfromtimestamp(0) unixTimestamp = (int((date - epoch).total_seconds())) mdsTimestamp = aem.getMDSTimestamp(unixTimestamp) self.assertEquals(635578412400000000, mdsTimestamp) def test_get_storage_timestamp(self): date = datetime.datetime(2015, 1, 26, 3, 54) epoch = datetime.datetime.utcfromtimestamp(0) unixTimestamp = (int((date - epoch).total_seconds())) storageTimestamp = aem.getStorageTimestamp(unixTimestamp) self.assertEquals("20150126T0354", storageTimestamp) def mock_getStorageMetrics(*args, **kwargs): with open(os.path.join(env.test_dir, "storage_metrics")) as F: test_data = F.read() jsonObjs = json.loads(test_data) class ObjectView(object): def __init__(self, data): self.__dict__ = data metrics = map(lambda x : ObjectView(x), jsonObjs) return metrics if __name__ == '__main__': unittest.main() ================================================ FILE: AzureEnhancedMonitor/ext/test/test_installer.py ================================================ #!/usr/bin/env python # #CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import sys import unittest import env import os import json import datetime import installer class TestInstall(unittest.TestCase): def test_install_psutil(self): buildDir = os.path.join(env.root, "../../Common/libpsutil") build = installer.find_psutil_build(buildDir) self.assertNotEquals(None, build) if __name__ == '__main__': unittest.main() ================================================ FILE: AzureEnhancedMonitor/hvinfo/.gitignore ================================================ bin/* ================================================ FILE: AzureEnhancedMonitor/hvinfo/Makefile ================================================ CC := gcc SRCDIR := src LIBDIR := lib INCDIR := include BUILDDIR := build TARGET := bin/hvinfo SRCEXT := c SOURCES := $(shell find $(SRCDIR) -type f -name *.$(SRCEXT)) OBJECTS := $(patsubst $(SRCDIR)/%,$(BUILDDIR)/%,$(SOURCES:.$(SRCEXT)=.o)) CFLAGS := -g LDFLAGS := INC := -I $(INCDIR) LIB := -L $(LIBDIR) all : $(TARGET) $(TARGET): $(OBJECTS) @echo "Linking..." $(CC) $^ $(LDFLAGS) -o $(TARGET) $(LIB) $(BUILDDIR)/%.o: $(SRCDIR)/%.$(SRCEXT) @mkdir -p $(BUILDDIR) @echo "Compiling..." $(CC) $(CFLAGS) $(INC) -c -o $@ $< clean: @echo "Cleaning..." $(RM) -r $(BUILDDIR) $(TARGET) .PHONY: clean test ================================================ FILE: AzureEnhancedMonitor/hvinfo/src/hvinfo.c ================================================ // // Copyright 2014 Microsoft Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// #include #include #include void get_cpuid(unsigned int leaf, unsigned int *cpuid) { asm volatile ( "cpuid" : "=a" (cpuid[0]), "=b" (cpuid[1]), "=c" (cpuid[2]), "=d" (cpuid[3]) : "a" (leaf)); } void u32_to_char_arr(char* dest, unsigned int i) { dest[0] = (char)(i & 0xFF); dest[1] = (char)(i >> 8 & 0xFF); dest[2] = (char)(i >> 16 & 0xFF); dest[3] = (char)(i >> 24 & 0xFF); } int main() { unsigned int cpuid[4]; char vendor_id[13]; /* Read hypervisor name*/ memset(cpuid, 0, sizeof(unsigned int) * 4); memset(vendor_id, 0, sizeof(char) * 13); get_cpuid(0x40000000, cpuid); //cpuid[1~3] is hypervisor vendor id signature. //In hyper-v, it is: // // 0x7263694D—“Micr” // 0x666F736F—“osof” // 0x76482074—“t Hv” // u32_to_char_arr(vendor_id, cpuid[1]); u32_to_char_arr(vendor_id + 4, cpuid[2]); u32_to_char_arr(vendor_id + 8, cpuid[3]); printf("%s\n", vendor_id); /* Read hypervisor version*/ memset(cpuid, 0, sizeof(unsigned int) * 4); get_cpuid(0x40000001, cpuid); // cpuid[0] is hypervisor vendor-neutral interface identification. // 0x31237648—“Hv#1. It means the next leaf contains version info. if(0x31237648 != cpuid[0]) { return 1; } memset(cpuid, 0, sizeof(unsigned int) * 4); get_cpuid(0x40000002, cpuid); //cpuid[1] is host version. //The high-end 16 bit is major version, while the low-end is minor. 
printf("%d.%d\n", (cpuid[1] >> 16) & 0xFF, (cpuid[1]) & 0xFF); return 0; } ================================================ FILE: AzureEnhancedMonitor/nodejs/package.json ================================================ { "name": "azure-linux-tools", "author": "Microsoft Corporation", "contributors": [ "Yue, Zhang " ], "version": "1.0.0", "description": "Azure Linux VM configuration tools", "tags": [ "azure", "vm", "linux", "tools" ], "keywords": [ "node", "azure", "vm", "linux", "tools" ], "main": "setaem.js", "preferGlobal": "true", "engines": { "node": ">= 0.8.26" }, "licenses": [ { "type": "Apache", "url": "http://www.apache.org/licenses/LICENSE-2.0" } ], "dependencies": { "promise" : "6.1.0", "azure-common" : "0.9.13", "azure-storage" : "0.4.2", "azure-arm-storage" : "0.11.0", "azure-arm-compute" : "0.13.0" }, "devDependencies": { }, "homepage": "https://github.com/Azure/azure-linux-extensions", "repository": { "type": "git", "url": "git@github.com:Azure/azure-linux-extensions.git" }, "bin": { "setaem": "setaem.js" }, "scripts":{ } } ================================================ FILE: AzureEnhancedMonitor/nodejs/setaem.js ================================================ #!/usr/bin/env node // // Copyright (c) Microsoft and contributors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // // See the License for the specific language governing permissions and // limitations under the License. 
// 'use strict'; var fs = require('fs'); var path = require('path'); var Promise = require('promise'); var common = require('azure-common'); var storage = require('azure-storage'); var storageMgmt = require('azure-arm-storage'); var computeMgmt = require('azure-arm-compute'); var readFile = Promise.denodeify(fs.readFile); var debug = 0; /*Const*/ var CurrentScriptVersion = "1.0.0.0"; var aemExtPublisher = "Microsoft.OSTCExtensions"; var aemExtName = "AzureEnhancedMonitorForLinux"; var aemExtVersion = "2.0"; var ladExtName = "LinuxDiagnostic"; var ladExtPublisher = "Microsoft.OSTCExtensions"; var ladExtVersion = "2.0"; var ROLECONTENT = "IaaS"; var AzureEndpoint = "windows.net"; var BlobMetricsMinuteTable= "$MetricsMinutePrimaryTransactionsBlob"; var BlobMetricsHourTable= "$MetricsMinutePrimaryTransactionsBlob"; var ladMetricesTable= ""; /*End of Const*/ var AemConfig = function(){ this.prv = []; this.pub = []; }; AemConfig.prototype.setPublic = function(key, value){ this.pub.push({ 'key' : key, 'value' : value }); }; AemConfig.prototype.setPrivate = function(key, value){ this.prv.push({ 'key' : key, 'value' : value }); }; AemConfig.prototype.getPublic = function(){ return { 'key' : aemExtName + "PublicConfigParameter", 'value' : JSON.stringify({'cfg' : this.pub}), 'type':'Public' } }; AemConfig.prototype.getPrivate = function(){ return { 'key' : aemExtName + "PrivateConfigParameter", 'value' : JSON.stringify({'cfg' : this.prv}), 'type':'Private' } }; var setAzureVMEnhancedMonitorForLinux = function(rgpName, vmName){ var azureProfile; var currSubscription; var computeClient; var storageClient; var selectedVM; var osdiskAccount; var accounts = []; var aemConfig = new AemConfig(); return getAzureProfile().then(function(profile){ azureProfile = profile; return getDefaultSubscription(profile); }).then(function(subscription){ console.log("[INFO]Using subscription: " + subscription.name); debug && console.log(JSON.stringify(subscription, null, 4)); currSubscription = 
subscription; var cred = getCloudCredential(subscription); var baseUri = subscription.managementEndpointUrl; computeClient = computeMgmt.createComputeManagementClient(cred, baseUri); storageClient = storageMgmt.createStorageManagementClient(cred, baseUri); }).then(function(){ return getVirtualMachine(computeClient, rgpName, vmName); }).then(function(vm){ //Set vm role basic config console.log("[INFO]Found VM: " + vm.oSProfile.computerName); debug && console.log(JSON.stringify(vm, null, 4)); /* vm: { extensions: [ [Object] ], tags: {}, hardwareProfile: { virtualMachineSize: 'Standard_A1' }, storageProfile: { dataDisks: [], imageReference: [Object], oSDisk: [Object] }, oSProfile: { secrets: [], computerName: 'zhongyiubuntu4', adminUsername: 'zhongyi', linuxConfiguration: [Object] }, networkProfile: { networkInterfaces: [Object] }, diagnosticsProfile: { bootDiagnostics: [Object] }, provisioningState: 'Succeeded', id: '/subscriptions/4be8920b-2978-43d7-ab14-04d8549c1d05/resourceGroups/zhongyiubuntu4/providers/Microsoft.Compute/virtualMachines/zhongyiubuntu4', name: 'zhongyiubuntu4', type: 'Microsoft.Compute/virtualMachines', location: 'eastasia' }} */ selectedVM = vm; var cpuOverCommitted = 0; if(selectedVM.hardwareProfile.virtualMachineSize === 'ExtralSmall'){ cpuOverCommitted = 1 } aemConfig.setPublic('vmsize', selectedVM.hardwareProfile.virtualMachineSize); aemConfig.setPublic('vm.role', 'IaaS'); aemConfig.setPublic('vm.memory.isovercommitted', 0); aemConfig.setPublic('vm.cpu.isovercommitted', cpuOverCommitted); aemConfig.setPublic('script.version', CurrentScriptVersion); aemConfig.setPublic('verbose', '0'); aemConfig.setPublic('href', 'http://aka.ms/sapaem'); }).then(function(){ //Set vm disk config /* osDisk: { operatingSystemType: 'Linux', name: 'zhongyiubuntu4', virtualHardDisk: { uri: 'https://zhongyiubuntu44575.blob.core.windows.net/vhds/zhongyiubuntu4.vhd' }, caching: 'ReadWrite', createOption: 'FromImage' } */ var osdisk = selectedVM.storageProfile.oSDisk; 
osdiskAccount = getStorageAccountFromUri(osdisk.virtualHardDisk.uri); console.log("[INFO]Adding configure for OS disk."); aemConfig.setPublic('osdisk.account', osdiskAccount); aemConfig.setPublic('osdisk.name', osdisk.name); //aemConfig.setPublic('osdisk.caching', osdisk.caching); aemConfig.setPublic('osdisk.connminute', osdiskAccount + ".minute"); aemConfig.setPublic('osdisk.connhour', osdiskAccount + ".hour"); accounts.push({ name: osdiskAccount, }); /* dataDisk: { lun: 0, name: 'zhongyiubuntu4-20151112-140433', virtualHardDisk: { uri: 'https://zhongyiubuntu44575.blob.core.windows.net/vhds/zhongyiubuntu4-20151112-140433.vhd' }, caching: 'None', createOption: 'Empty', diskSizeGB: 1023 } */ for(var i = 0; i < selectedVM.storageProfile.dataDisks.length; i++){ var dataDisk = selectedVM.storageProfile.dataDisks[i]; console.log("[INFO]Adding configure for data disk: " + dataDisk.name); var datadiskAccount = getStorageAccountFromUri(dataDisk.virtualHardDisk.uri); accounts.push({ name:datadiskAccount }); //The default lun value is 0 var lun = dataDisk.lun; aemConfig.setPublic('disk.lun.' + i, lun); aemConfig.setPublic('disk.name.' + i, dataDisk.name); aemConfig.setPublic('disk.caching.' + i, dataDisk.name); aemConfig.setPublic('disk.account.' + i, datadiskAccount); aemConfig.setPublic('disk.connminute.' + i, datadiskAccount + ".minute"); aemConfig.setPublic('disk.connhour.' 
+ i, datadiskAccount + ".hour"); } }).then(function(){ //Set storage account config var promises = []; var i = -2; Object(accounts).forEach(function(account){ var promise = getResourceGroupName(storageClient, account.name) .then(function(rgpName){ account.rgp = rgpName; console.log("!!!!rgp",rgpName); return getStorageAccountKey(storageClient, rgpName, account.name); }).then(function(accountKey){ console.log("!!!!key",accountKey); account.key = accountKey; aemConfig.setPrivate(account.name + ".minute.key", accountKey); aemConfig.setPrivate(account.name + ".hour.key", accountKey); return getStorageAccountProperties(storageClient, account.rgp, account.name); }).then(function(properties){ //ispremium i += 1; if (properties.accountType.startsWith("Standard")) { if (i >= 0) aemConfig.setPublic('disk.type.' + i, "Standard"); else aemConfig.setPublic('osdisk.type' + i, "Standard"); } else { if (i >= 0) aemConfig.setPublic('disk.type.' + i, "Premium"); else aemConfig.setPublic('osdisk.type' + i, "Premium"); aemConfig.setPublic(account.name + ".hour.ispremium", 1); aemConfig.setPublic(account.name + ".minute.ispremium", 1); } //endpoints var endpoints = properties.primaryEndpoints; var tableEndpoint; var blobEndpoint; endpoints.forEach(function(endpoint){ if(endpoint.match(/.*table.*/)){ tableEndpoint = endpoint; }else if(endpoint.match(/.*blob.*/)){ blobEndpoint = endpoint; } }); account.tableEndpoint = tableEndpoint; account.blobEndpoint = blobEndpoint; var minuteUri = tableEndpoint + BlobMetricsMinuteTable; var hourUri = tableEndpoint + BlobMetricsHourTable; account.minuteUri = minuteUri aemConfig.setPublic(account.name + ".hour.uri", hourUri); aemConfig.setsetPrivate(account.name + ".hour.key", account.key); aemConfig.setPublic(account.name + ".minute.uri", minuteUri); aemConfig.setsetPrivate(account.name + ".minute.key", account.key); aemConfig.setPublic(account.name + ".hour.name", account.name); aemConfig.setPublic(account.name + ".minute.name", account.name); 
}).then(function(){ return checkStorageAccountAnalytics(account.name, account.key, account.blobEndpoint); }); promises.push(promise); }); return Promise.all(promises); }).then(function(res){ //Set Linux diagnostic config aemConfig.setPublic("wad.name", accounts[0].name); aemConfig.setPublic("wad.isenabled", 1); var ladUri = accounts[0].tableEndpoint + ladMetricesTable; console.log("[INFO]Your endpoint is: "+accounts[0].tableEndpoint); aemConfig.setPublic("wad.uri", ladUri); aemConfig.setPrivate("wad.key", accounts[0].key); }).then(function(){ //Update vm var extensions = []; var ladExtConfig = { 'name' : ladExtName, 'referenceName' : ladExtName, 'publisher' : ladExtPublisher, 'version' : ladExtVersion, 'state': 'Enable', 'resourceExtensionParameterValues' : [{ 'key' : ladExtName + "PrivateConfigParameter", 'value' : JSON.stringify({ 'storageAccountName' : accounts[0].name, 'storageAccountKey' : accounts[0].key, 'endpoint' : accounts[0].tableEndpoint.substring((accounts[0].tableEndpoint.search(/\./)) + 1, accounts[0].tableEndpoint.length) }), 'type':'Private' }] }; var aemExtConfig = { 'name' : aemExtName, 'referenceName' : aemExtName, 'publisher' : aemExtPublisher, 'version' : aemExtVersion, 'state': 'Enable', 'resourceExtensionParameterValues' : [ aemConfig.getPublic(), aemConfig.getPrivate() ] }; extensions.push(ladExtConfig); extensions.push(aemExtConfig); selectedVM.provisionGuestAgent = true; selectedVM.resourceExtensionReferences = extensions; console.log("[INFO]Updating configuration for VM: " + selectedVM.roleName); console.log("[INFO]This could take a few minutes. 
Please wait.") debug && console.log(JSON.stringify(selectedVM, null, 4)) return updateVirtualMachine(computeClient, svcName, vmName, selectedVM); }); } var updateVirtualMachine = function (client, svcName, vmName, parameters){ return new Promise(function(fullfill, reject){ client.virtualMachines.update(svcName, vmName, vmName, parameters, function(err, ret){ if(err){ reject(err) } else { fullfill(ret); } }); }); } var getStorageAccountProperties = function(storageClient, rgpName, accountName){ return new Promise(function(fullfill, reject){ storageClient.storageAccounts.getProperties(rgpName, accountName, function(err, res){ if(err){ reject(err); } else { fullfill(res.storageAccounts.properties); } }); }); }; var getResourceGroupName = function(storageClient, accountName) { return new Promise(function(fullfill, reject){ storageClient.storageAccounts.list(function(err, res){ if(err){ reject(err); } else { res.storageAccounts.forEach(function (storage) { var matchRgp = /resourceGroups\/(.+?)\/.*/.exec(storage.id); var matchAct = /storageAccounts\/(.+?)$/.exec(storage.id); if (matchAct[1] == accountName) { fullfill(matchRgp[1]); } }); } }); }); }; var getStorageAccountKey = function(storageClient, rgpName, accountName){ console.log("123"); return new Promise(function(fullfill, reject){ storageClient.storageAccounts.listKeys(rgpName, accountName, function(err, res){ console.log("??"); if (err) { reject(err); } else { fullfill(res); } }); }); }; var getStorageAccountAnalytics = function(accountName, accountKey, host){ return new Promise(function(fullfill, reject){ var blobService = storage.createBlobService(accountName, accountKey, host); blobService.getServiceProperties(null, function(err, properties, resp){ if(err){ reject(err) } else { fullfill(properties); } }); }); }; var analyticsSettings = { Logging:{ Version: '1.0', Delete: true, Read: true, Write: true, RetentionPolicy: { Enabled: true, Days: 13 } }, HourMetrics:{ Version: '1.0', Enabled: true, IncludeAPIs: 
true, RetentionPolicy: { Enabled: true, Days: 13 } }, MinuteMetrics:{ Version: '1.0', Enabled: true, IncludeAPIs: true, RetentionPolicy: { Enabled: true, Days: 13 } } }; var checkStorageAccountAnalytics = function(accountName, accountKey, host){ return getStorageAccountAnalytics(accountName, accountKey, host) .then(function(properties){ if(!properties || !properties.Logging || !properties.Logging.Read || !properties.Logging.Write || !properties.Logging.Delete || !properties.MinuteMetrics || !properties.MinuteMetrics.Enabled || !properties.MinuteMetrics.RetentionPolicy || !properties.MinuteMetrics.RetentionPolicy.Enabled || !properties.MinuteMetrics.RetentionPolicy.Days || properties.MinuteMetrics.RetentionPolicy.Days == 0 ){ console.log("[INFO] Turn on storage analytics for: " + accountName) return setStorageAccountAnalytics(accountName, accountKey, host, analyticsSettings); } }); } var setStorageAccountAnalytics = function(accountName, accountKey, host, properties){ return new Promise(function(fullfill, reject){ var blobService = storage.createBlobService(accountName, accountKey, host); blobService.setServiceProperties(properties, null, function(err, properties, resp){ if(err){ reject(err) } else { fullfill(properties); } }); }); }; var getStorageAccountFromUri = function(uri){ var match = /https:\/\/(.+?)\..*/.exec(uri); if(match){ return match[1]; } } var getVirtualMachine = function(computeClient, rgpName, vmName){ return new Promise(function(fullfill, reject){ computeClient.virtualMachines.get(rgpName, vmName, function(err, res){ if(err){ reject(err); } else { fullfill(res.virtualMachine); } }); }); } var getCloudCredential = function(subscription){ var cred; if(subscription.credential.type === 'cert'){ cred = computeMgmt.createCertificateCloudCredentials({ subscriptionId:subscription.id , cert:subscription.managementCertificate.cert, key:subscription.managementCertificate.key, }); }else{//if(subscription.credential.type === 'token'){ cred = new 
common.TokenCloudCredentials({ subscriptionId : subscription.id, token : subscription.credential.token }); } return cred; } var getAzureProfile = function(){ var profileJSON = path.join(getUserHome(), ".azure/azureProfile.json"); return readFile(profileJSON).then(function(result){ var profile = JSON.parse(result); return profile; }); } var getDefaultSubscription = function(profile){ debug && console.log(JSON.stringify(profile, null, 4)) if(profile == null || profile.subscriptions == null || profile.subscriptions.length == 0){ throw "No subscription found." } console.log("[INFO]Found available subscriptions:"); console.log(""); console.log(" Id\t\t\t\t\t\tName"); console.log(" --------------------------------------------------------"); profile.subscriptions.forEach(function(subscription){ console.log(" " + subscription.id + "\t" + subscription.name); }); console.log(""); var defaultSubscription; profile.subscriptions.every(function(subscription, index, arr){ if(subscription.isDefault){ defaultSubscription = subscription; return false; } else { return true; } }); if(defaultSubscription == null){ console.log("[WARN]No subscription is selected."); defaultSubscription = profile.subscriptions[0]; console.log("[INFO]The first subscription will be used."); console.log("[INFO]You could use the following command to select " + "another subscription."); console.log(""); console.log(" azure account set [|]"); console.log(""); } if(defaultSubscription.user){ return getTokenCredential(defaultSubscription); } else if(defaultSubscription.managementCertificate){ return getCertCredential(defaultSubscription); } else { throw "Unknown subscription type."; } } var getTokenCredential = function(subscription){ var tokensJSON = path.join(getUserHome(), ".azure/accessTokens.json"); return readFile(tokensJSON).then(function(result){ var tokens = JSON.parse(result); tokens.every(function(token, index, arr){ if(token.userId === subscription.user.name){ subscription.credential = { type : 
'token', token : token.accessToken }; return false } }); return subscription; }); } var getCertCredential = function(subscription){ subscription.credential = { type : 'cert', cert : subscription.managementCertificate }; return subscription; } function getUserHome() { return process.env[(process.platform == 'win32') ? 'USERPROFILE' : 'HOME']; } var main = function(){ var rgpName = null; var vmName = null; if(process.argv.length === 4){ vmName = process.argv[3]; rgpName = process.argv[2]; } else if(process.argv.length === 3){ if(process.argv[2] === "--help" || process.argv[2] === "-h"){ usage(); process.exit(0); } else if(process.argv[2] === "--version" || process.argv[2] === "-v"){ console.log(CurrentScriptVersion); process.exit(0); } vmName = process.argv[2]; rgpName = vmName; } else{ usage(); process.exit(1); } setAzureVMEnhancedMonitorForLinux(rgpName, vmName).done(function(){ console.log("[INFO]Azure Enhanced Monitoring Extension " + "configuration updated."); console.log("[INFO]It can take up to 15 Minutes for the " + "monitoring data to appear in the system."); process.exit(0); }, function(err){ if(err && err.statusCode == 401){ console.error("[ERROR]Token expired. 
" + "Please run the following command to login."); console.log(" "); console.log(" azure login"); console.log("or"); console.log(" azure account import "); process.exit(-1); }else{ console.log(err); console.log(err.stack); process.exit(-1); } }); } var usage = function(){ console.log(""); console.log("Usage:"); console.log(" setaem "); console.log("or"); console.log(" setaem "); console.log(""); console.log(" *if service_name and vm_name are the same, " + "service_name could be omitted."); console.log(""); console.log(" "); console.log(" -h, --help "); console.log(" Print help."); console.log(" "); console.log(" -v, --version"); console.log(" Print version."); console.log(" "); } main(); ================================================ FILE: AzureMonitorAgent/.gitignore ================================================ MetricsExtensionBin/ metrics_ext_utils/ packages/ telegraf_utils/ Utils/ waagent ================================================ FILE: AzureMonitorAgent/HandlerManifest.json ================================================ [ { "name": "AzureMonitorLinuxAgent", "version": "1.5.124", "handlerManifest": { "installCommand": "./shim.sh -install", "uninstallCommand": "./shim.sh -uninstall", "updateCommand": "./shim.sh -update", "enableCommand": "./shim.sh -enable", "disableCommand": "./shim.sh -disable", "rebootAfterInstall": false, "reportHeartbeat": false, "updateMode": "UpdateWithInstall", "continueOnUpdateFailure": true }, "resourceLimits": { "services": [ { "name": "azuremonitoragent", "cpuQuotaPercentage": 250 }, { "name": "azuremonitoragentmgr" }, { "name": "azuremonitor-agentlauncher", "cpuQuotaPercentage": 4 }, { "name": "azuremonitor-coreagent", "cpuQuotaPercentage": 200 }, { "name": "metrics-extension", "cpuQuotaPercentage": 5 }, { "name": "metrics-sourcer", "cpuQuotaPercentage": 10 } ] } } ] ================================================ FILE: AzureMonitorAgent/README.md ================================================ # 
AzureMonitorLinuxAgent Extension Allow the owner of the Azure Virtual Machines to install the Azure Monitor Linux Agent # The Latest Version is 1.6.2 The extension is currently in Public Preview and is accessible to all public cloud regions in Azure. You can read the User Guide below. * [Learn more: Azure Virtual Machine Extensions](https://azure.microsoft.com/en-us/documentation/articles/virtual-machines-extensions-features/) Azure Monitor Linux Agent Extension can: * Install the agent and pull configs from MCS # User Guide ## 1. Deploying the Extension to a VM You can deploy it using Azure CLI ### 1.1. Using Azure CLI Resource Manager You can view the availability of the Azure Monitor Linux Agent extension versions in each region by running: ``` az vm extension image list-versions -l --name AzureMonitorLinuxAgent -p Microsoft.Azure.Monitor ``` You can deploy the Azure Monitor Linux Agent Extension by running: ``` az vm extension set --name AzureMonitorLinuxAgent --publisher Microsoft.Azure.Monitor --version --resource-group --vm-name ``` To update the version of the existing installation of Azure Monitor Linux Agent extension on a VM, please add "--force-update" flag to the above command. (Currently Waagent only supports this way of upgrading. Will update once we have more info from them.)
## Supported Linux Distributions Currently Manually tested only on - * CentOS Linux 6, and 7 (x64) * Red Hat Enterprise Linux Server 6 and 7 (x64) * Ubuntu 16.04 LTS, 18.04 LTS(x64) Will Add more distros once they are tested ## Troubleshooting * The status of the extension is reported back to Azure so that user can see the status on Azure Portal * All the extension installation and config files are unzipped into - `/var/lib/waagent/Microsoft.Azure.Monitor.AzureMonitorLinuxAgent-/packages/` and the tail of the output is logged into the log directory specified in HandlerEnvironment.json and reported back to Azure * The operation log of the extension is `/var/log/azure/Microsoft.Azure.Monitor.AzureMonitorLinuxAgent-/extension.log` file. ================================================ FILE: AzureMonitorAgent/agent.py ================================================ #!/usr/bin/env python # # AzureMonitoringLinuxAgent Extension # # Copyright 2021 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Standard-library imports used throughout the handler.
from __future__ import print_function
import sys
import os
import os.path
import datetime
import signal
import pwd
import glob
import grp
import re
import filecmp
import stat
import traceback
import time
import platform
import subprocess
import json
import base64
import inspect
import shutil
import hashlib
import fileinput
import contextlib
# Helper modules shipped alongside the extension package.
import ama_tst.modules.install.supported_distros as supported_distros
from collections import OrderedDict
from hashlib import sha256
from shutil import copyfile, rmtree, copytree, copy2
from threading import Thread
import telegraf_utils.telegraf_config_handler as telhandler
import metrics_ext_utils.metrics_constants as metrics_constants
import metrics_ext_utils.metrics_ext_handler as me_handler
import metrics_ext_utils.metrics_common_utils as metrics_utils
# urllib compatibility shims: expose the same names on Python 2 and Python 3.
try:
    import urllib.request as urllib # Python 3+
except ImportError:
    import urllib2 as urllib # Python 2
try:
    from urllib.parse import urlparse # Python 3+
except ImportError:
    from urlparse import urlparse # Python 2
try:
    import urllib.error as urlerror # Python 3+
except ImportError:
    import urllib2 as urlerror # Python 2
# python shim can only make IMDS calls which shouldn't go through proxy
try:
    urllib.getproxies = lambda x = None: {}
except Exception as e:
    print('Resetting proxies failed with error: {0}'.format(e))
# Guest-agent utilities; import failure is tolerated here because callers
# check for their presence before use.
try:
    from Utils.WAAgentUtil import waagent
    import Utils.HandlerUtil as HUtil
except Exception as e:
    # These utils have checks around the use of them; this is not an exit case
    print('Importing utils failed with error: {0}'.format(e))
# This code is taken from the omsagent's extension wrapper.
# This same monkey patch fix is relevant for AMA extension as well.
# This monkey patch duplicates the one made in the waagent import above.
# It is necessary because on 2.6, the waagent monkey patch appears to be overridden
# by the python-future subprocess.check_output backport.
# On Python < 2.7 only: re-install the subprocess.check_output backport,
# because the python-future backport can clobber the waagent monkey patch
# applied during the import above.
if sys.version_info < (2,7):
    def check_output(*popenargs, **kwargs):
        r"""Backport from subprocess module from python 2.7"""
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        output, unused_err = process.communicate()
        retcode = process.poll()
        if retcode:
            cmd = kwargs.get("args")
            if cmd is None:
                cmd = popenargs[0]
            # Raised with the captured output so callers can inspect it.
            raise subprocess.CalledProcessError(retcode, cmd, output=output)
        return output

    # Exception classes used by this module.
    class CalledProcessError(Exception):
        # Mirrors the Python 2.7 subprocess.CalledProcessError interface.
        def __init__(self, returncode, cmd, output=None):
            self.returncode = returncode
            self.cmd = cmd
            self.output = output
        def __str__(self):
            return "Command '%s' returned non-zero exit status %d" % (self.cmd, self.returncode)

    # Patch the stdlib module so all callers get the backported behavior.
    subprocess.check_output = check_output
    subprocess.CalledProcessError = CalledProcessError

# Global Variables
PackagesDirectory = 'packages'
# The BundleFileName values will be replaced by actual values in the release pipeline. See apply_version.sh.
# Package bundle names; the empty values are filled in at runtime/release.
BundleFileNameDeb = 'azuremonitoragent.deb'
BundleFileNameRpm = 'azuremonitoragent.rpm'
BundleFileName = ''
TelegrafBinName = 'telegraf'
InitialRetrySleepSeconds = 30
# Detected package manager ('dpkg' or 'rpm') and its extra options; set at runtime.
PackageManager = ''
PackageManagerOptions = ''
# Agent config-cache and marker file locations.
MdsdCounterJsonPath = '/etc/opt/microsoft/azuremonitoragent/config-cache/metricCounters.json'
FluentCfgPath = '/etc/opt/microsoft/azuremonitoragent/config-cache/fluentbit/td-agent.conf'
AMASyslogConfigMarkerPath = '/etc/opt/microsoft/azuremonitoragent/config-cache/syslog.marker'
AMASyslogPortFilePath = '/etc/opt/microsoft/azuremonitoragent/config-cache/syslog.port'
AMAFluentPortFilePath = '/etc/opt/microsoft/azuremonitoragent/config-cache/fluent.port'
PreviewFeaturesDirectory = '/etc/opt/microsoft/azuremonitoragent/config-cache/previewFeatures/'
ArcSettingsFile = '/var/opt/azcmagent/localconfig.json'
AMAAstTransformConfigMarkerPath = '/etc/opt/microsoft/azuremonitoragent/config-cache/agenttransform.marker'
# logrotate configs and uninstall bookkeeping.
AMAExtensionLogRotateFilePath = '/etc/logrotate.d/azuremonitoragentextension'
WAGuestAgentLogRotateFilePath = '/etc/logrotate.d/waagent-extn.logrotate'
AmaUninstallContextFile = '/var/opt/microsoft/uninstall-context'
AmaDataPath = '/var/opt/microsoft/azuremonitoragent/'
SupportedArch = set(['x86_64', 'aarch64'])
MDSDFluentPort = 0
MDSDSyslogPort = 0
# Error codes
GenericErrorCode = 1
UnsupportedOperatingSystem = 51
IndeterminateOperatingSystem = 51
MissingorInvalidParameterErrorCode = 53
DPKGOrRPMLockedErrorCode = 56
MissingDependency = 52
# Settings
GenevaConfigKey = "genevaConfiguration"
AzureMonitorConfigKey = "azureMonitorConfiguration"
# Configuration
HUtilObject = None
SettingsSequenceNumber = None
HandlerEnvironment = None
SettingsDict = None
def main():
    """
    Main method
    Parse out operation from argument, invoke the operation, and finish.
""" init_waagent_logger() waagent_log_info('Azure Monitoring Agent for Linux started to handle.') # Determine the operation being executed operation = None try: option = sys.argv[1] if re.match('^([-/]*)(disable)', option): operation = 'Disable' elif re.match('^([-/]*)(uninstall)', option): operation = 'Uninstall' elif re.match('^([-/]*)(install)', option): operation = 'Install' elif re.match('^([-/]*)(enable)', option): operation = 'Enable' elif re.match('^([-/]*)(update)', option): operation = 'Update' elif re.match('^([-/]*)(metrics)', option): operation = 'Metrics' elif re.match('^([-/]*)(syslogconfig)', option): operation = 'Syslogconfig' elif re.match('^([-/]*)(transformconfig)', option): operation = 'Transformconfig' except Exception as e: waagent_log_error(str(e)) if operation is None: log_and_exit('Unknown', GenericErrorCode, 'No valid operation provided') # Set up for exit code and any error messages exit_code = 0 message = '{0} succeeded'.format(operation) # Avoid entering broken state where manual purge actions are necessary in low disk space scenario destructive_operations = ['Disable', 'Uninstall'] if operation not in destructive_operations: exit_code = check_disk_space_availability() if exit_code != 0: message = '{0} failed due to low disk space'.format(operation) log_and_exit(operation, exit_code, message) # Invoke operation try: global HUtilObject HUtilObject = parse_context(operation) exit_code, output = operations[operation]() # Exit code 1 indicates a general problem that doesn't have a more # specific error code; it often indicates a missing dependency if exit_code == 1 and operation == 'Install': message = 'Install failed with exit code 1. 
For error details, check logs ' \ 'in /var/log/azure/Microsoft.Azure.Monitor' \ '.AzureMonitorLinuxAgent' elif exit_code is DPKGOrRPMLockedErrorCode and operation == 'Install': message = 'Install failed with exit code {0} because the ' \ 'package manager on the VM is currently locked: ' \ 'please wait and try again'.format(DPKGOrRPMLockedErrorCode) elif exit_code != 0: message = '{0} failed with exit code {1} {2}'.format(operation, exit_code, output) except AzureMonitorAgentForLinuxException as e: exit_code = e.error_code message = e.get_error_message(operation) except Exception as e: exit_code = GenericErrorCode message = '{0} failed with error: {1}\n' \ 'Stacktrace: {2}'.format(operation, e, traceback.format_exc()) # Finish up and log messages log_and_exit(operation, exit_code, message) def check_disk_space_availability(): """ Check if there is the required space on the machine. """ try: if get_free_space_mb("/var") < 700 or get_free_space_mb("/etc") < 500 or get_free_space_mb("/opt") < 500 : # 52 is the exit code for missing dependency i.e. disk space # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr return MissingDependency else: return 0 except: print('Failed to check disk usage.') return 0 def get_free_space_mb(dirname): """ Get the free space in MB in the directory path. """ st = os.statvfs(dirname) return (st.f_bavail * st.f_frsize) // (1024 * 1024) def is_systemd(): """ Check if the system is using systemd """ return os.path.isdir("/run/systemd/system") def get_service_command(service, *operations): """ Get the appropriate service command [sequence] for the provided service name and operation(s) """ if is_systemd(): return " && ".join(["systemctl {0} {1}".format(operation, service) for operation in operations]) else: hutil_log_info("The VM doesn't have systemctl. 
Using the init.d service to start {0}.".format(service)) return '/etc/init.d/{0} {1}'.format(service, operations[0]) def check_kill_process(pstring): for line in os.popen("ps ax | grep " + pstring + " | grep -v grep"): fields = line.split() pid = fields[0] os.kill(int(pid), signal.SIGKILL) def compare_and_copy_bin(src, dest): # Check if previous file exist at the location, compare the two binaries, # If the files are not same, remove the older file, and copy the new one # If they are the same, then we ignore it and don't copy if os.path.isfile(src ): if os.path.isfile(dest): if not filecmp.cmp(src, dest): # Removing the file in case it is already being run in a process, # in which case we can get an error "text file busy" while copying os.remove(dest) copyfile(src, dest) else: # No previous binary exist, simply copy it and make it executable copyfile(src, dest) os.chmod(dest, stat.S_IXGRP | stat.S_IRGRP | stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR | stat.S_IXOTH | stat.S_IROTH) def set_metrics_binaries(): current_arch = platform.machine() # Rename the Arch appropriate metrics extension binary to MetricsExtension MetricsExtensionDir = os.path.join(os.getcwd(), 'MetricsExtensionBin') SupportedMEPath = os.path.join(MetricsExtensionDir, 'metricsextension_'+current_arch) if os.path.exists(SupportedMEPath): os.rename(SupportedMEPath, os.path.join(MetricsExtensionDir, 'MetricsExtension')) # Cleanup unused ME binaries for f in os.listdir(MetricsExtensionDir): if f != 'MetricsExtension': os.remove(os.path.join(MetricsExtensionDir, f)) def copy_amacoreagent_binaries(): current_arch = platform.machine() amacoreagent_bin_local_path = os.getcwd() + "/amaCoreAgentBin/amacoreagent_" + current_arch amacoreagent_bin = "/opt/microsoft/azuremonitoragent/bin/amacoreagent" compare_and_copy_bin(amacoreagent_bin_local_path, amacoreagent_bin) if current_arch == 'x86_64': #libgrpc_bin_local_path = os.getcwd() + "/amaCoreAgentBin/libgrpc_csharp_ext.x64.so" #libgrpc_bin = 
"/opt/microsoft/azuremonitoragent/bin/libgrpc_csharp_ext.x64.so" #compare_and_copy_bin(libgrpc_bin_local_path, libgrpc_bin) liblz4x64_bin_local_path = os.getcwd() + "/amaCoreAgentBin/liblz4x64.so" liblz4x64_bin = "/opt/microsoft/azuremonitoragent/bin/liblz4x64.so" compare_and_copy_bin(liblz4x64_bin_local_path, liblz4x64_bin) #elif current_arch == 'aarch64': #libgrpc_bin_local_path = os.getcwd() + "/amaCoreAgentBin/libgrpc_csharp_ext.arm64.so" #libgrpc_bin = "/opt/microsoft/azuremonitoragent/bin/libgrpc_csharp_ext.arm64.so" #compare_and_copy_bin(libgrpc_bin_local_path, libgrpc_bin) agentlauncher_bin_local_path = os.getcwd() + "/agentLauncherBin/agentlauncher_" + current_arch agentlauncher_bin = "/opt/microsoft/azuremonitoragent/bin/agentlauncher" compare_and_copy_bin(agentlauncher_bin_local_path, agentlauncher_bin) def copy_mdsd_fluentbit_binaries(): current_arch = platform.machine() mdsd_bin_local_path = os.getcwd() + "/mdsdBin/mdsd_" + current_arch mdsdmgr_bin_local_path = os.getcwd() + "/mdsdBin/mdsdmgr_" + current_arch fluentbit_bin_local_path = os.getcwd() + "/fluentBitBin/fluent-bit_" + current_arch mdsd_bin = "/opt/microsoft/azuremonitoragent/bin/mdsd" mdsdmgr_bin = "/opt/microsoft/azuremonitoragent/bin/mdsdmgr" fluentbit_bin = "/opt/microsoft/azuremonitoragent/bin/fluent-bit" # copy the required libs to our test directory first lib_dir = os.path.join(os.getcwd(), "lib") if os.path.exists(lib_dir): rmtree(lib_dir) if sys.version_info >= (3, 8): # dirs_exist_ok parameter was added in Python 3.8 copytree("/opt/microsoft/azuremonitoragent/lib", lib_dir, dirs_exist_ok=True) else: copytree("/opt/microsoft/azuremonitoragent/lib", lib_dir) canUseSharedmdsd, _ = run_command_and_log('ldd ' + mdsd_bin_local_path + ' | grep "not found"') canUseSharedmdsdmgr, _ = run_command_and_log('ldd ' + mdsdmgr_bin_local_path + ' | grep "not found"') if canUseSharedmdsd != 0 and canUseSharedmdsdmgr != 0: compare_and_copy_bin(mdsd_bin_local_path, mdsd_bin) 
compare_and_copy_bin(mdsdmgr_bin_local_path, mdsdmgr_bin) canUseSharedfluentbit, _ = run_command_and_log('ldd ' + fluentbit_bin_local_path + ' | grep "not found"') if canUseSharedfluentbit != 0: compare_and_copy_bin(fluentbit_bin_local_path, fluentbit_bin) rmtree(os.getcwd() + "/lib") def get_installed_package_version(): """ Returns if Azure Monitor Agent is installed and a list of installed version of the Azure Monitor Agent package. Returns: (is_installed, version_list) """ if PackageManager == "dpkg": # In the case of dpkg, we specify only Package and Version as architecture is written as amd64/arm64 instead of x86_64/aarch64. cmd = "dpkg-query -W -f='${Package}_${Version}\n' 'azuremonitoragent*' 2>/dev/null" elif PackageManager == "rpm": cmd = "rpm -q azuremonitoragent" else: hutil_log_error("Could not determine package manager.") return False, [] exit_code, output = run_command_and_log(cmd, check_error=False) if exit_code != 0 or not output: hutil_log_info("Azure Monitor Agent package not found after running {0}.".format(cmd)) return False, [] version_list = output.strip().split('\n') return True, version_list def get_current_bundle_file(): if PackageManager == 'dpkg': return BundleFileNameDeb.rsplit('.deb', 1)[0] # Remove .deb extension elif PackageManager == 'rpm': return BundleFileNameRpm.rsplit('.rpm', 1)[0] # Remove .rpm extension return "" def install(): """ Ensure that this VM distro and version are supported. Install the Azure Monitor Linux Agent package, using retries. Note: install operation times out from WAAgent at 15 minutes, so do not wait longer. 
""" exit_if_vm_not_supported('Install') find_package_manager("Install") set_os_arch('Install') vm_dist, vm_ver = find_vm_distro('Install') # Check if Debian 12 and 13 VMs have rsyslog package (required for AMA 1.31+) if (vm_dist.startswith('debian')) and ((vm_ver.startswith('12') or vm_ver.startswith('13')) or int(vm_ver.split('.')[0]) >= 12): check_rsyslog, _ = run_command_and_log("dpkg -s rsyslog") if check_rsyslog != 0: hutil_log_info("'rsyslog' package missing from Debian {0} machine, installing to allow AMA to run.".format(vm_ver)) rsyslog_exit_code, rsyslog_output = run_command_and_log("DEBIAN_FRONTEND=noninteractive apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y rsyslog") if rsyslog_exit_code != 0: return rsyslog_exit_code, rsyslog_output # Check if Amazon 2023 VMs have rsyslog package (required for AMA 1.31+) if (vm_dist.startswith('amzn')) and vm_ver.startswith('2023'): check_rsyslog, _ = run_command_and_log("dnf list installed | grep rsyslog.x86_64") if check_rsyslog != 0: hutil_log_info("'rsyslog' package missing from Amazon Linux 2023 machine, installing to allow AMA to run.") rsyslog_exit_code, rsyslog_output = run_command_and_log("dnf install -y rsyslog") if rsyslog_exit_code != 0: return rsyslog_exit_code, rsyslog_output # Flag to handle the case where the same package is already installed same_package_installed = False # Check if the package is already installed with the correct version is_installed, installed_versions = get_installed_package_version() # Check if the package is already installed, if so determine if it is the same as the bundle or not if is_installed: hutil_log_info("Found installed azuremonitoragent version(s): {0}".format(installed_versions)) # Check if already have this version of AMA installed, if so, no-op for install of AMA if len(installed_versions) == 1: current_bundle = get_current_bundle_file() hutil_log_info("Current bundle file: {0}".format(current_bundle)) package_name = installed_versions[0] # 
This is to make sure dpkg's package name is in the same format as the BundleFileNameDeb if PackageManager == 'dpkg': architecture = '' if platform.machine() == 'x86_64': architecture = '_x86_64' elif platform.machine() == 'aarch64': architecture = '_aarch64' # need to change the ending from amd64 to x86_64 and arm64 to aarch64 package_name = package_name + architecture if current_bundle == package_name: hutil_log_info("This version of azuremonitoragent package is already installed. Skipping package install.") same_package_installed = True else: hutil_log_error("Multiple versions of azuremonitoragent package found: {0}\n This is undefined behavior, we recommend running the following:".format(installed_versions)) if PackageManager == 'dpkg': hutil_log_error("Run the following command first:\n dpkg --purge azuremonitoragent. If this does not work try the following with caution:\n" "'rm /var/lib/dpkg/info/azuremonitoragent.*' followed by 'dpkg --force-all -P azuremonitoragent'") elif PackageManager == 'rpm': # For reference AzureLinux 3.0 also falls under this category hutil_log_error("Run the following command first: ") hutil_log_error("'rpm -q azuremonitoragent' and for each version run: rpm -e azuremonitoragent-(version)-(bundle_number).(architecture), or rpm -e --deleteall azuremonitoragent\n An example of the command is as follows: rpm -e {0}".format(installed_versions[0])) hutil_log_error("If the following does not work please try the following: rpm -e --noscripts --nodeps azuremonitoragent-(version)-(bundle_number).(architecture). I.e. 
rpm -e --noscripts --nodeps {0}".format(installed_versions[0])) # If the same bundle of Azure Monitor Agent package is not already installed, proceed with installation if not same_package_installed: hutil_log_info("Installing Azure Monitor Agent package.") package_directory = os.path.join(os.getcwd(), PackagesDirectory) bundle_path = os.path.join(package_directory, BundleFileName) os.chmod(bundle_path, 100) print(PackageManager, " and ", BundleFileName) AMAInstallCommand = "{0} {1} -i {2}".format(PackageManager, PackageManagerOptions, bundle_path) hutil_log_info('Running command "{0}"'.format(AMAInstallCommand)) # Try to install with retry, since install can fail due to concurrent package operations exit_code, output = run_command_with_retries_output(AMAInstallCommand, retries = 15, retry_check = retry_if_dpkg_or_rpm_locked, final_check = final_check_if_dpkg_or_rpm_locked) # Retry install for aarch64 rhel8 VMs as initial install fails to create symlink to /etc/systemd/system/azuremonitoragent.service # in /etc/systemd/system/multi-user.target.wants/azuremonitoragent.service if vm_dist.replace(' ','').lower().startswith('redhat') and vm_ver == '8.6' and platform.machine() == 'aarch64': exit_code, output = run_command_with_retries_output(AMAInstallCommand, retries = 15, retry_check = retry_if_dpkg_or_rpm_locked, final_check = final_check_if_dpkg_or_rpm_locked) if exit_code != 0: return exit_code, output # System daemon reload is required for systemd to pick up the new service exit_code, output = run_command_and_log("systemctl daemon-reload") if exit_code != 0: return exit_code, output # Copy the AMACoreAgent and agentlauncher binaries copy_amacoreagent_binaries() set_metrics_binaries() # Copy AstExtension binaries # Needs to be revisited for aarch64 copy_astextension_binaries() # Copy mdsd and fluent-bit with OpenSSL dynamically linked if is_feature_enabled('useDynamicSSL'): # Check if they have libssl.so.1.1 since AMA is built against this version libssl1_1, _ = 
run_command_and_log('ldconfig -p | grep libssl.so.1.1') if libssl1_1 == 0: copy_mdsd_fluentbit_binaries() # Set task limits to max of 65K in suse 12 # Based on Task 9764411: AMA broken after 1.7 in sles 12 - https://dev.azure.com/msazure/One/_workitems/edit/9764411 vm_dist, _ = find_vm_distro('Install') if (vm_dist.startswith('suse') or vm_dist.startswith('sles')): try: suse_exit_code, suse_output = run_command_and_log("mkdir -p /etc/systemd/system/azuremonitoragent.service.d") if suse_exit_code != 0: return suse_exit_code, suse_output suse_exit_code, suse_output = run_command_and_log("echo '[Service]' > /etc/systemd/system/azuremonitoragent.service.d/override.conf") if suse_exit_code != 0: return suse_exit_code, suse_output suse_exit_code, suse_output = run_command_and_log("echo 'TasksMax=65535' >> /etc/systemd/system/azuremonitoragent.service.d/override.conf") if suse_exit_code != 0: return suse_exit_code, suse_output suse_exit_code, suse_output = run_command_and_log("systemctl daemon-reload") if suse_exit_code != 0: return suse_exit_code, suse_output except: log_and_exit("install", MissingorInvalidParameterErrorCode, "Failed to update /etc/systemd/system/azuremonitoragent.service.d for suse 12,15" ) return 0, "Azure Monitor Agent package installed successfully" def uninstall(): """ Uninstall the Azure Monitor Linux Agent. Whether it is a purge of all files or preserve of log files depends on the uninstall context file. Note: uninstall operation times out from WAAgent at 5 minutes """ exit_if_vm_not_supported('Uninstall') find_package_manager("Uninstall") # Before we uninstall, we need to ensure AMA is installed to begin with is_installed, installed_versions = get_installed_package_version() if not is_installed: hutil_log_info("Azure Monitor Agent is not installed, nothing to uninstall.") return 0, "Azure Monitor Agent is not installed, nothing to uninstall." 
if PackageManager != "dpkg" and PackageManager != "rpm": log_and_exit("Uninstall", UnsupportedOperatingSystem, "The OS has neither rpm nor dpkg." ) # For clean uninstall, gather the file list BEFORE running the uninstall command # This ensures we have the complete list even after the package manager removes its database package_files_for_cleanup = [] hutil_log_info("Gathering package file list for clean uninstall before removing package") package_files_for_cleanup = _get_package_files_for_cleanup() # Attempt to uninstall each specific # Try a specific package uninstall for rpm if PackageManager == "rpm": purge_cmd_template = "rpm -e {0}" # Process each package for package_name in installed_versions: if not package_name.strip(): continue package_name = package_name.strip() # Clean the package name and create uninstall command uninstall_command = purge_cmd_template.format(package_name) hutil_log_info("Removing package: {0} by running {1}".format(package_name, uninstall_command)) # Execute uninstall command with retries exit_code, output = run_command_with_retries_output( uninstall_command, retries=4, retry_check=retry_if_dpkg_or_rpm_locked, final_check=final_check_if_dpkg_or_rpm_locked ) elif PackageManager == "dpkg": AMAUninstallCommand = "dpkg -P azuremonitoragent" hutil_log_info("Removing package: azuremonitoragent by running {0}".format(AMAUninstallCommand)) exit_code, output = run_command_with_retries_output( AMAUninstallCommand, retries=4, retry_check=retry_if_dpkg_or_rpm_locked, final_check=final_check_if_dpkg_or_rpm_locked ) remove_localsyslog_configs() uninstall_azureotelcollector() # remove the logrotate config if os.path.exists(AMAExtensionLogRotateFilePath): try: os.remove(AMAExtensionLogRotateFilePath) except Exception as ex: output = 'Logrotate removal failed with error: {0}\n' \ 'Stacktrace: {1}'.format(ex, traceback.format_exc()) hutil_log_info(output) # Retry, since uninstall can fail due to concurrent package operations try: exit_code, output = 
def force_uninstall_azure_monitor_agent():
    """
    Force uninstall the Azure Monitor Linux Agent package with possibility of
    multiple existing Azure Monitor Agent Linux packages.
    Just for rpm, this function will attempt to uninstall each package in the
    installed_versions list. If it still persists, a force uninstall is done.
    Returns: (exit_code, output_message or installed_versions (list of remaining packages))
    """
    # Check if azuremonitoragent is still installed, exit code will be non-zero if it is not.
    is_installed, remaining_packages = get_installed_package_version()
    commands_used = []
    if is_installed:
        # Since the previous uninstall failed we are going down the route of uninstall without dep and pre/post
        hutil_log_info("Initial uninstall command did not remove all packages. Remaining packages: {0}".format(remaining_packages))
        AMAUninstallCommandForce = ""
        if PackageManager == "dpkg":
            # we can remove the post and pre scripts first then purge
            RemoveScriptsCommand = "rm /var/lib/dpkg/info/azuremonitoragent.*"
            run_command_with_retries_output(RemoveScriptsCommand, retries = 4, retry_check = retry_if_dpkg_or_rpm_locked, final_check = final_check_if_dpkg_or_rpm_locked)
            AMAUninstallCommandForce = "dpkg --force-all -P azuremonitoragent"
            hutil_log_info('Running command "{0}"'.format(AMAUninstallCommandForce))
            exit_code, output = run_command_with_retries_output(AMAUninstallCommandForce, retries = 4, retry_check = retry_if_dpkg_or_rpm_locked, final_check = final_check_if_dpkg_or_rpm_locked)
            commands_used.extend([RemoveScriptsCommand, AMAUninstallCommandForce])
        elif PackageManager == "rpm":
            # First try to mass uninstall AMA by using the --allmatches flag for rpm
            # This is a more robust version of uninstall() since it uses the --allmatches flag
            AMAUninstallCommand = "rpm -e --allmatches azuremonitoragent"
            hutil_log_info('Running command "{0}"'.format(AMAUninstallCommand))
            exit_code, output = run_command_with_retries_output(AMAUninstallCommand, retries = 4, retry_check = retry_if_dpkg_or_rpm_locked, final_check = final_check_if_dpkg_or_rpm_locked)
            # BUG FIX: this log line previously formatted AMAUninstallCommandForce, which is
            # still "" at this point, so the log showed an empty command. Log the command
            # that was actually executed.
            hutil_log_info("Force uninstall command {0} returned exit code {1} and output: {2}".format(AMAUninstallCommand, exit_code, output))
            commands_used.append(AMAUninstallCommand)
            # Query to see what is left after using the --allmatches uninstall
            is_still_installed, remaining_packages = get_installed_package_version()
            # If the above command fails, we will try to force uninstall each package by using the --noscripts and --nodeps flags
            if is_still_installed:
                hutil_log_info("Failed to uninstall azuremonitoragent with --allmatches, trying to force uninstall each package individually.")
                # --noscripts and --nodeps flags are used to avoid running any pre/post scripts and skip dependencies test
                # https://jfearn.fedorapeople.org/en-US/RPM/4/html/RPM_Guide/ch03s03s03.html
                for package in remaining_packages:
                    # Clean the package name and create uninstall command
                    package = package.strip()
                    if not package:
                        continue
                    AMAUninstallCommandForce = "rpm -e --noscripts --nodeps {0}".format(package)
                    commands_used.append(AMAUninstallCommandForce)
                    hutil_log_info('Running command "{0}"'.format(AMAUninstallCommandForce))
                    exit_code, output = run_command_with_retries_output(AMAUninstallCommandForce, retries = 4, retry_check = retry_if_dpkg_or_rpm_locked, final_check = final_check_if_dpkg_or_rpm_locked)
                    hutil_log_info("Force uninstall command {0} returned exit code {1} and output: {2}".format(AMAUninstallCommandForce, exit_code, output))

        # Check if packages are still installed
        is_still_installed, remaining_packages = get_installed_package_version()
        if is_still_installed:
            output = "Force uninstall did not remove all packages, remaining packages: {0}".format(remaining_packages)
            hutil_log_info("Force uninstall did not remove all packages, remaining packages: {0}".format(remaining_packages))
            return 1, output
        else:
            hutil_log_info("Force uninstall removed all packages successfully after using: {0}".format(", ".join(commands_used)))
            return 0, "Azure Monitor Agent packages uninstalled successfully after using: {0}".format(", ".join(commands_used))
    # Since there was no indication of AMA, we can assume it was uninstalled successfully
    else:
        hutil_log_info("Azure Monitor Agent has been uninstalled.")
        return 0, "Azure Monitor Agent has been uninstalled."
Returns: tuple: (files_list, directories_to_add) where files_list contains package files and directories_to_add contains directories that need explicit cleanup """ try: # Get list of files installed by the package if PackageManager == "dpkg": # For Debian-based systems cmd = "dpkg -L azuremonitoragent" elif PackageManager == "rpm": # For RPM-based systems cmd = "rpm -ql azuremonitoragent" else: hutil_log_info("Unknown package manager, cannot list package files") return [] exit_code, output = run_command_and_log(cmd, check_error=False) if exit_code != 0 or not output: hutil_log_info("Could not get package file list for cleanup") return [] # Parse the file list files = [line.strip() for line in output.strip().split('\n') if line.strip()] # Collect all azuremonitor-related paths azuremonitoragent_files = [] for file_path in files: # Only include files/directories that have "azuremonitor" in their path # This covers both "azuremonitoragent" and "azuremonitor-*" service files if "azuremonitor" in file_path: azuremonitoragent_files.append(file_path) else: hutil_log_info("Skipping non-azuremonitor path: {0}".format(file_path)) return azuremonitoragent_files except Exception as ex: hutil_log_error("Error gathering package files for cleanup: {0}\n Is Azure Monitor Agent Installed?".format(ex)) return [] def _remove_package_files_from_list(package_files): """ Remove all files and directories from the provided list that were installed by the provided azuremonitoragent spec. This function works with a pre-gathered list of files from _get_package_files_for_cleanup(), allowing it to work even after the package has been uninstalled. 
Args: package_files (list): List of file/directory paths to remove """ try: if not package_files: hutil_log_info("No package files provided for removal") return # Build consolidated list of paths to clean up cleanup_paths = set(package_files) if package_files else set() # Add directories that need explicit cleanup since on rpm systems # the initial list for this path does not remove the directories and files cleanup_paths.add("/opt/microsoft/azuremonitoragent/") # Determine uninstall context based on if the context file exists uninstall_context = _get_uninstall_context() hutil_log_info("Uninstall context: {0}".format(uninstall_context)) if uninstall_context == 'complete': hutil_log_info("Complete uninstall context - removing everything") cleanup_paths.add(AmaDataPath) # Sort paths by depth (deepest first) to avoid removing parent before children sorted_paths = sorted(cleanup_paths, key=lambda x: x.count('/'), reverse=True) hutil_log_info("Removing {0} azuremonitor paths".format(len(sorted_paths))) items_removed = 0 for item_path in sorted_paths: try: if os.path.exists(item_path): if os.path.isdir(item_path): rmtree(item_path) hutil_log_info("Removed directory: {0}".format(item_path)) else: os.remove(item_path) hutil_log_info("Removed file: {0}".format(item_path)) items_removed += 1 except Exception as ex: hutil_log_info("Failed to remove {0}: {1}".format(item_path, ex)) hutil_log_info("Removed {0} items total".format(items_removed)) except Exception as ex: hutil_log_error("Error during file removal from list: {0}\n Were these files removed already?".format(ex)) def enable(): """ Start the Azure Monitor Linux Agent Service This call will return non-zero or throw an exception if the settings provided are incomplete or incorrect. 
def enable():
    """
    Start the Azure Monitor Linux Agent Service
    This call will return non-zero or throw an exception if
    the settings provided are incomplete or incorrect.
    Note: enable operation times out from WAAgent at 5 minutes
    """
    public_settings, protected_settings = get_settings()

    exit_if_vm_not_supported('Enable')

    # Map of service name -> whether it must be running after enable completes
    ensure = OrderedDict([
        ("azuremonitoragent", False),
        ("azuremonitoragentmgr", False)
    ])

    # Set traceFlags in publicSettings to enable mdsd tracing. For example, the EventIngest flag can be enabled via "traceFlags": "0x2"
    flags = ""
    if public_settings is not None and "traceFlags" in public_settings:
        flags = "-T {} ".format(public_settings.get("traceFlags"))

    # Use an Ordered Dictionary to ensure MDSD_OPTIONS (and other dependent variables) are written after their dependencies
    default_configs = OrderedDict([
        ("MDSD_CONFIG_DIR", "/etc/opt/microsoft/azuremonitoragent"),
        ("MDSD_LOG_DIR", "/var/opt/microsoft/azuremonitoragent/log"),
        ("MDSD_ROLE_PREFIX", "/run/azuremonitoragent/default"),
        ("MDSD_SPOOL_DIRECTORY", "/var/opt/microsoft/azuremonitoragent"),
        ("MDSD_OPTIONS", "\"{}-A -R -c /etc/opt/microsoft/azuremonitoragent/mdsd.xml -d -r $MDSD_ROLE_PREFIX -S $MDSD_SPOOL_DIRECTORY/eh -L $MDSD_SPOOL_DIRECTORY/events\"".format(flags)),
        ("MDSD_USE_LOCAL_PERSISTENCY", "true"),
        ("MDSD_TCMALLOC_RELEASE_FREQ_SEC", "1"),
        ("MONITORING_USE_GENEVA_CONFIG_SERVICE", "false"),
        ("ENABLE_MCS", "false")
    ])

    ssl_cert_var_name, ssl_cert_var_value = get_ssl_cert_info('Enable')
    default_configs[ssl_cert_var_name] = ssl_cert_var_value

    """
    Decide the mode and configuration.
    There are two supported configuration schema, mix-and-match between schemas is disallowed:
        Legacy: allows one of [MCS, GCS single tenant, or GCS multi tenant ("Auto-Config")] modes
        Next-Generation: allows MCS, GCS multi tenant, or both
    """
    is_gcs_single_tenant = False
    GcsEnabled, McsEnabled = get_control_plane_mode()

    # Next-generation schema
    if public_settings is not None and (public_settings.get(GenevaConfigKey) or public_settings.get(AzureMonitorConfigKey)):
        geneva_configuration = public_settings.get(GenevaConfigKey)
        azure_monitor_configuration = public_settings.get(AzureMonitorConfigKey)

        # Check for mix-and match of next-generation and legacy schema content
        if len(public_settings) > 1 and ((geneva_configuration and not azure_monitor_configuration) or (azure_monitor_configuration and not geneva_configuration)):
            log_and_exit("Enable", MissingorInvalidParameterErrorCode, 'Mixing genevaConfiguration or azureMonitorConfiguration with other configuration schemas is not allowed')

        if geneva_configuration and geneva_configuration.get("enable") == True:
            hutil_log_info("Detected Geneva+ mode; azuremonitoragentmgr service will be started to handle Geneva tenants")
            ensure["azuremonitoragentmgr"] = True

        if azure_monitor_configuration and azure_monitor_configuration.get("enable") == True:
            hutil_log_info("Detected Azure Monitor+ mode; azuremonitoragent service will be started to handle Azure Monitor tenant")
            ensure["azuremonitoragent"] = True
            azure_monitor_public_settings = azure_monitor_configuration.get("configuration")
            azure_monitor_protected_settings = protected_settings.get(AzureMonitorConfigKey) if protected_settings is not None else None
            handle_mcs_config(azure_monitor_public_settings, azure_monitor_protected_settings, default_configs)

    # Legacy schema
    elif public_settings is not None and public_settings.get("GCS_AUTO_CONFIG") == True:
        hutil_log_info("Detected Auto-Config mode; azuremonitoragentmgr service will be started to handle Geneva tenants")
        ensure["azuremonitoragentmgr"] = True

    elif (protected_settings is None or len(protected_settings) == 0) or (public_settings is not None and "proxy" in public_settings and "mode" in public_settings.get("proxy") and public_settings.get("proxy").get("mode") == "application"):
        hutil_log_info("Detected Azure Monitor mode; azuremonitoragent service will be started to handle Azure Monitor configuration")
        ensure["azuremonitoragent"] = True
        handle_mcs_config(public_settings, protected_settings, default_configs)

    else:
        hutil_log_info("Detected Geneva mode; azuremonitoragent service will be started to handle Geneva configuration")
        ensure["azuremonitoragent"] = True
        is_gcs_single_tenant = True
        handle_gcs_config(public_settings, protected_settings, default_configs)

    # generate local syslog configuration files as in auto config syslog is not driven from DCR
    # Note that internally AMCS with geneva config path can be used in which case syslog should be handled same way as default 1P
    # generate local syslog configuration files as in 1P syslog is not driven from DCR
    if GcsEnabled:
        generate_localsyslog_configs(uses_gcs=True, uses_mcs=McsEnabled)

    # Atomically rewrite the environment file: write a temp file, then rename over the original
    config_file = "/etc/default/azuremonitoragent"
    temp_config_file = "/etc/default/azuremonitoragent_temp"
    try:
        if os.path.isfile(config_file):
            new_config = "\n".join(["export {0}={1}".format(key, value) for key, value in default_configs.items()]) + "\n"
            with open(temp_config_file, "w") as f:
                f.write(new_config)
            if not os.path.isfile(temp_config_file):
                log_and_exit("Enable", GenericErrorCode, "Error while updating environment variables in {0}".format(config_file))
            os.remove(config_file)
            os.rename(temp_config_file, config_file)
        else:
            log_and_exit("Enable", GenericErrorCode, "Could not find the file {0}".format(config_file))
    except Exception as e:
        log_and_exit("Enable", GenericErrorCode, "Failed to add environment variables to {0}: {1}".format(config_file, e))

    if "ENABLE_MCS" in default_configs and default_configs["ENABLE_MCS"] == "true":
        # enable processes for Custom Logs
        ensure["azuremonitor-agentlauncher"] = True
        ensure["azuremonitor-coreagent"] = True
        # start the metrics, agent transform and syslog watchers only in 3P mode
        start_metrics_process()
        start_syslogconfig_process()
    elif ensure.get("azuremonitoragentmgr") or is_gcs_single_tenant:
        # In GCS scenarios, ensure that AMACoreAgent is running
        ensure["azuremonitor-coreagent"] = True

    hutil_log_info('Handler initiating onboarding.')

    if HUtilObject and HUtilObject.is_seq_smaller():
        # Either upgrade has just happened (in which case we need to start), or enable was called with no change to extension config
        hutil_log_info("Current sequence number, " + HUtilObject._context._seq_no + ", is not greater than the LKG sequence number. Starting service(s) only if it is not yet running.")
        operations = ["start", "enable"]
    else:
        # Either this is a clean install (in which case restart is effectively start), or extension config has changed
        hutil_log_info("Current sequence number, " + HUtilObject._context._seq_no + ", is greater than the LKG sequence number. Restarting service(s) to pick up the new config.")
        operations = ["restart", "enable"]

    output = ""

    # Ensure non-required services are not running; do not block if this step fails
    for service in [s for s in ensure.keys() if not ensure[s]]:
        exit_code, disable_output = run_command_and_log(get_service_command(service, "stop", "disable"))
        output += disable_output

    # Start/restart and enable each required service; bail out on the first failure
    for service in [s for s in ensure.keys() if ensure[s]]:
        exit_code, enable_output = run_command_and_log(get_service_command(service, *operations))
        output += enable_output
        if exit_code != 0:
            status_command = get_service_command(service, "status")
            status_exit_code, status_output = run_command_and_log(status_command)
            if status_exit_code != 0:
                output += "Output of '{0}':\n{1}".format(status_command, status_output)
            return exit_code, output

    if platform.machine() != 'aarch64':
        if "ENABLE_MCS" in default_configs and default_configs["ENABLE_MCS"] == "true":
            # start/enable ast extension only in 3P mode and non aarch64
            _, ast_output = run_command_and_log(get_service_command("azuremonitor-astextension", *operations))
            output += ast_output
            # do not block if ast start fails
            # start transformation config watcher process
            start_transformconfig_process()

    # Service(s) were successfully configured and started; increment sequence number
    HUtilObject.save_seq()

    return exit_code, output
def handle_gcs_config(public_settings, protected_settings, default_configs):
    """
    Populate the defaults for legacy-path GCS mode.

    Reads certificates/credentials and GCS parameters from protected_settings,
    validates that the required ones are present, and writes the resulting
    environment variables into default_configs (plus cert/key files on disk).
    Exits via log_and_exit when required parameters are missing or unreadable.
    """
    # look for LA protected settings
    for var in list(protected_settings.keys()):
        if "_key" in var or "_id" in var:
            default_configs[var] = protected_settings.get(var)

    # check if required GCS params are available
    MONITORING_GCS_CERT_CERTFILE = None
    if "certificate" in protected_settings:
        MONITORING_GCS_CERT_CERTFILE = base64.standard_b64decode(protected_settings.get("certificate"))

    # certificatePath (a file on disk) takes precedence over the inline certificate
    if "certificatePath" in protected_settings:
        try:
            with open(protected_settings.get("certificatePath"), 'r') as f:
                MONITORING_GCS_CERT_CERTFILE = f.read()
        except Exception as ex:
            log_and_exit('Enable', MissingorInvalidParameterErrorCode, 'Failed to read certificate {0}: {1}'.format(protected_settings.get("certificatePath"), ex))

    MONITORING_GCS_CERT_KEYFILE = None
    if "certificateKey" in protected_settings:
        MONITORING_GCS_CERT_KEYFILE = base64.standard_b64decode(protected_settings.get("certificateKey"))

    if "certificateKeyPath" in protected_settings:
        try:
            with open(protected_settings.get("certificateKeyPath"), 'r') as f:
                MONITORING_GCS_CERT_KEYFILE = f.read()
        except Exception as ex:
            log_and_exit('Enable', MissingorInvalidParameterErrorCode, 'Failed to read certificate key {0}: {1}'.format(protected_settings.get("certificateKeyPath"), ex))

    MONITORING_GCS_ENVIRONMENT = ""
    if "monitoringGCSEnvironment" in protected_settings:
        MONITORING_GCS_ENVIRONMENT = protected_settings.get("monitoringGCSEnvironment")

    MONITORING_GCS_NAMESPACE = ""
    if "namespace" in protected_settings:
        MONITORING_GCS_NAMESPACE = protected_settings.get("namespace")

    MONITORING_GCS_ACCOUNT = ""
    if "monitoringGCSAccount" in protected_settings:
        MONITORING_GCS_ACCOUNT = protected_settings.get("monitoringGCSAccount")

    MONITORING_GCS_REGION = ""
    if "monitoringGCSRegion" in protected_settings:
        MONITORING_GCS_REGION = protected_settings.get("monitoringGCSRegion")

    MONITORING_CONFIG_VERSION = ""
    if "configVersion" in protected_settings:
        MONITORING_CONFIG_VERSION = protected_settings.get("configVersion")

    MONITORING_GCS_AUTH_ID_TYPE = ""
    if "monitoringGCSAuthIdType" in protected_settings:
        MONITORING_GCS_AUTH_ID_TYPE = protected_settings.get("monitoringGCSAuthIdType")

    MONITORING_GCS_AUTH_ID = ""
    if "monitoringGCSAuthId" in protected_settings:
        MONITORING_GCS_AUTH_ID = protected_settings.get("monitoringGCSAuthId")

    MONITORING_TENANT = ""
    if "monitoringTenant" in protected_settings:
        MONITORING_TENANT = protected_settings.get("monitoringTenant")

    MONITORING_ROLE = ""
    if "monitoringRole" in protected_settings:
        MONITORING_ROLE = protected_settings.get("monitoringRole")

    MONITORING_ROLE_INSTANCE = ""
    if "monitoringRoleInstance" in protected_settings:
        MONITORING_ROLE_INSTANCE = protected_settings.get("monitoringRoleInstance")

    # Either a cert+key pair or an auth-id type must be present, plus all core GCS params
    if ((MONITORING_GCS_CERT_CERTFILE is None or MONITORING_GCS_CERT_KEYFILE is None) and (MONITORING_GCS_AUTH_ID_TYPE == "")) or MONITORING_GCS_ENVIRONMENT == "" or MONITORING_GCS_NAMESPACE == "" or MONITORING_GCS_ACCOUNT == "" or MONITORING_GCS_REGION == "" or MONITORING_CONFIG_VERSION == "":
        log_and_exit("Enable", MissingorInvalidParameterErrorCode, 'Not all required GCS parameters are provided')
    else:
        # set the values for GCS
        default_configs["MONITORING_USE_GENEVA_CONFIG_SERVICE"] = "true"
        default_configs["MONITORING_GCS_ENVIRONMENT"] = MONITORING_GCS_ENVIRONMENT
        default_configs["MONITORING_GCS_NAMESPACE"] = MONITORING_GCS_NAMESPACE
        default_configs["MONITORING_GCS_ACCOUNT"] = MONITORING_GCS_ACCOUNT
        default_configs["MONITORING_GCS_REGION"] = MONITORING_GCS_REGION
        default_configs["MONITORING_CONFIG_VERSION"] = MONITORING_CONFIG_VERSION

        # write the certificate and key to disk (owned by syslog, mode 400)
        uid = pwd.getpwnam("syslog").pw_uid
        gid = grp.getgrnam("syslog").gr_gid

        if MONITORING_GCS_AUTH_ID_TYPE != "":
            default_configs["MONITORING_GCS_AUTH_ID_TYPE"] = MONITORING_GCS_AUTH_ID_TYPE

        if MONITORING_GCS_AUTH_ID != "":
            default_configs["MONITORING_GCS_AUTH_ID"] = MONITORING_GCS_AUTH_ID

        if MONITORING_GCS_CERT_CERTFILE is not None:
            # NOTE(review): when certificatePath was used, this value is str (text read);
            # writing str to a "wb" handle fails on Python 3 — confirm intended input type.
            default_configs["MONITORING_GCS_CERT_CERTFILE"] = "/etc/opt/microsoft/azuremonitoragent/gcscert.pem"
            with open("/etc/opt/microsoft/azuremonitoragent/gcscert.pem", "wb") as f:
                f.write(MONITORING_GCS_CERT_CERTFILE)
            os.chown("/etc/opt/microsoft/azuremonitoragent/gcscert.pem", uid, gid)
            os.system('chmod {1} {0}'.format("/etc/opt/microsoft/azuremonitoragent/gcscert.pem", 400))

        if MONITORING_GCS_CERT_KEYFILE is not None:
            default_configs["MONITORING_GCS_CERT_KEYFILE"] = "/etc/opt/microsoft/azuremonitoragent/gcskey.pem"
            with open("/etc/opt/microsoft/azuremonitoragent/gcskey.pem", "wb") as f:
                f.write(MONITORING_GCS_CERT_KEYFILE)
            os.chown("/etc/opt/microsoft/azuremonitoragent/gcskey.pem", uid, gid)
            os.system('chmod {1} {0}'.format("/etc/opt/microsoft/azuremonitoragent/gcskey.pem", 400))

        if MONITORING_TENANT != "":
            default_configs["MONITORING_TENANT"] = MONITORING_TENANT

        if MONITORING_ROLE != "":
            default_configs["MONITORING_ROLE"] = MONITORING_ROLE

        # BUG FIX: this was previously gated on MONITORING_TENANT (copy-paste error),
        # which dropped a provided role instance when tenant was empty and wrote an
        # empty role instance when tenant was set; gate on the value itself.
        if MONITORING_ROLE_INSTANCE != "":
            default_configs["MONITORING_ROLE_INSTANCE"] = MONITORING_ROLE_INSTANCE
default_configs["MDSD_PROXY_PASSWORD"]) proxySet = True else: log_and_exit("Enable", MissingorInvalidParameterErrorCode, 'Parameter "username" and "password" not in proxy protected setting') else: set_proxy(default_configs["MDSD_PROXY_ADDRESS"], "", "") proxySet = True # is this Arc? If so, check for proxy if os.path.isfile(ArcSettingsFile): f = open(ArcSettingsFile, "r") data = f.read() if (data != ''): json_data = json.loads(data) BypassProxy = False if json_data is not None and "proxy.bypass" in json_data: bypass = json_data["proxy.bypass"] # proxy.bypass is an array if "AMA" in bypass: BypassProxy = True if not BypassProxy and json_data is not None and "proxy.url" in json_data: url = json_data["proxy.url"] # only non-authenticated proxy config is supported if url != '': default_configs["MDSD_PROXY_ADDRESS"] = url set_proxy(default_configs["MDSD_PROXY_ADDRESS"], "", "") proxySet = True if not proxySet: unset_proxy() # set arc autonomous endpoints az_environment, _ = get_azure_environment_and_region() if az_environment == me_handler.ArcACloudName: try: _, mcs_endpoint = me_handler.get_arca_endpoints_from_himds() except Exception as ex: log_and_exit("Enable", MissingorInvalidParameterErrorCode, 'Failed to get Arc autonomous endpoints. {0}'.format(ex)) default_configs["customRegionalEndpoint"] = mcs_endpoint default_configs["customGlobalEndpoint"] = mcs_endpoint default_configs["customResourceEndpoint"] = "https://monitoring.azs" # add managed identity settings if they were provided identifier_name, identifier_value, error_msg = get_managed_identity() if error_msg: log_and_exit("Enable", MissingorInvalidParameterErrorCode, 'Failed to determine managed identity settings. 
{0}.'.format(error_msg)) if identifier_name and identifier_value: default_configs["MANAGED_IDENTITY"] = "{0}#{1}".format(identifier_name, identifier_value) def get_control_plane_mode(): """ Identify which control plane is in use """ public_settings, protected_settings = get_settings() GcsEnabled = False McsEnabled = False if public_settings is not None and (public_settings.get(GenevaConfigKey) or public_settings.get(AzureMonitorConfigKey)): geneva_configuration = public_settings.get(GenevaConfigKey) azure_monitor_configuration = public_settings.get(AzureMonitorConfigKey) if geneva_configuration and geneva_configuration.get("enable") == True: GcsEnabled = True if azure_monitor_configuration and azure_monitor_configuration.get("enable") == True: McsEnabled = True # Legacy schema elif public_settings is not None and public_settings.get("GCS_AUTO_CONFIG") == True: GcsEnabled = True elif (protected_settings is None or len(protected_settings) == 0) or (public_settings is not None and "proxy" in public_settings and "mode" in public_settings.get("proxy") and public_settings.get("proxy").get("mode") == "application"): McsEnabled = True else: GcsEnabled = True return GcsEnabled, McsEnabled def disable(): """ Disable Azure Monitor Linux Agent process on the VM. 
Note: disable operation times out from WAAgent at 15 minutes """ #stop the metrics process stop_metrics_process() #stop syslog config watcher process stop_syslogconfig_process() #stop agent transform config watcher process stop_transformconfig_process() # stop amacoreagent and agent launcher hutil_log_info('Handler initiating Core Agent and agent launcher') if is_systemd(): exit_code, output = run_command_and_log('systemctl stop azuremonitor-coreagent && systemctl disable azuremonitor-coreagent') exit_code, output = run_command_and_log('systemctl stop azuremonitor-agentlauncher && systemctl disable azuremonitor-agentlauncher') # in case AL is not cleaning up properly check_kill_process('/opt/microsoft/azuremonitoragent/bin/fluent-bit') # Stop and disable systemd services so they are not started after system reboot. for service in ["azuremonitoragent", "azuremonitoragentmgr"]: exit_code, output = run_command_and_log(get_service_command(service, "stop", "disable")) if exit_code != 0: status_command = get_service_command(service, "status") status_exit_code, status_output = run_command_and_log(status_command) if status_exit_code != 0: output += "Output of '{0}':\n{1}".format(status_command, status_output) if platform.machine() != 'aarch64': # stop ast extensionso that is not started after system reboot. Do not block if it fails. ast_exit_code, disable_output = run_command_and_log(get_service_command("azuremonitor-astextension", "stop", "disable")) if ast_exit_code != 0: hutil_log_info(disable_output) status_command = get_service_command("azuremonitor-astextension", "status") _, ast_status_output = run_command_and_log(status_command) hutil_log_info(ast_status_output) return exit_code, output def update(): """ This function is called when the extension is updated. It marks the uninstall context to indicate that the next run should be treated as an update rather than a clean install. 
Always returns 0 """ hutil_log_info("Update operation called for Azure Monitor Agent") try: state_dir = os.path.dirname(AmaUninstallContextFile) if not os.path.exists(state_dir): os.makedirs(state_dir) with open(AmaUninstallContextFile, 'w') as f: f.write('update\n') f.write(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')) # Timestamp for debugging hutil_log_info("Marked uninstall context as 'update'") except Exception as ex: hutil_log_error("Failed to set uninstall context: {0}\n The uninstall operation will not behave as expected with the uninstall context file missing, defaulting to an uninstall that removes {1}.".format(ex, AmaDataPath)) return 0, "Update succeeded" def _get_uninstall_context(): """ Determine the context of this uninstall operation Returns the context as a string: 'complete' - if this is a clean uninstall 'update' - if this is an update operation Also returns as 'complete' if it fails to read the context file. """ try: if os.path.exists(AmaUninstallContextFile): with open(AmaUninstallContextFile, 'r') as f: context = f.read().strip().split('\n')[0] hutil_log_info("Found uninstall context: {0}".format(context)) return context else: hutil_log_info("Uninstall context file does not exist, defaulting to 'complete'") except Exception as ex: hutil_log_error("Failed to read uninstall context file: {0}\n The uninstall operation will not behave as expected with the uninstall context file missing, defaulting to an uninstall that removes {1}.".format(ex, AmaDataPath)) return 'complete' def _cleanup_uninstall_context(): """ Clean up uninstall context marker """ try: if os.path.exists(AmaUninstallContextFile): os.remove(AmaUninstallContextFile) hutil_log_info("Removed uninstall context file") else: hutil_log_info("Uninstall context file does not exist, nothing to remove") except Exception as ex: hutil_log_error("Failed to cleanup uninstall context: {0}\n This may result in unintended behavior as described.\nIf the marker file exists and cannot be 
def restart_launcher():
    # start agent launcher
    hutil_log_info('Handler initiating agent launcher')
    if is_systemd():
        exit_code, output = run_command_and_log('systemctl restart azuremonitor-agentlauncher && systemctl enable azuremonitor-agentlauncher')


def restart_astextension():
    # start agent transformation extension process
    hutil_log_info('Handler initiating agent transformation extension (AstExtension) restart and enable')
    if is_systemd():
        exit_code, output = run_command_and_log('systemctl restart azuremonitor-astextension && systemctl enable azuremonitor-astextension')


def set_proxy(address, username, password):
    """
    # Set proxy http_proxy env var in dependent services
    Writes systemd drop-in proxy.conf files for azuremonitor-coreagent and
    metrics-extension, then reloads systemd and restarts both services.
    """
    try:
        http_proxy = address
        address = address.replace("http://","")
        if username:
            # NOTE(review): credentials are embedded in the proxy URL and passed
            # through shell echo commands below — the password may appear in
            # process listings / shell history; confirm this is acceptable.
            http_proxy = "http://" + username + ":" + password + "@" + address

        # Update Coreagent
        run_command_and_log("mkdir -p /etc/systemd/system/azuremonitor-coreagent.service.d")
        run_command_and_log("echo '[Service]' > /etc/systemd/system/azuremonitor-coreagent.service.d/proxy.conf")
        run_command_and_log("echo 'Environment=\"http_proxy={0}\"' >> /etc/systemd/system/azuremonitor-coreagent.service.d/proxy.conf".format(http_proxy))
        run_command_and_log("echo 'Environment=\"https_proxy={0}\"' >> /etc/systemd/system/azuremonitor-coreagent.service.d/proxy.conf".format(http_proxy))
        os.system('chmod {1} {0}'.format("/etc/systemd/system/azuremonitor-coreagent.service.d/proxy.conf", 400))

        # Update ME
        run_command_and_log("mkdir -p /etc/systemd/system/metrics-extension.service.d")
        run_command_and_log("echo '[Service]' > /etc/systemd/system/metrics-extension.service.d/proxy.conf")
        run_command_and_log("echo 'Environment=\"http_proxy={0}\"' >> /etc/systemd/system/metrics-extension.service.d/proxy.conf".format(http_proxy))
        run_command_and_log("echo 'Environment=\"https_proxy={0}\"' >> /etc/systemd/system/metrics-extension.service.d/proxy.conf".format(http_proxy))
        os.system('chmod {1} {0}'.format("/etc/systemd/system/metrics-extension.service.d/proxy.conf", 400))

        run_command_and_log("systemctl daemon-reload")
        run_command_and_log('systemctl restart azuremonitor-coreagent')
        run_command_and_log('systemctl restart metrics-extension')
    except:
        log_and_exit("enable", MissingorInvalidParameterErrorCode, "Failed to update /etc/systemd/system/azuremonitor-coreagent.service.d and /etc/systemd/system/metrics-extension.service.d" )


def unset_proxy():
    """
    # Unset proxy http_proxy env var in dependent services
    Removes the drop-in proxy.conf files (if present) and, only when something
    was removed, reloads systemd and restarts the affected services.
    """
    try:
        hasSettings=False
        # Update Coreagent
        if os.path.exists("/etc/systemd/system/azuremonitor-coreagent.service.d/proxy.conf"):
            os.remove("/etc/systemd/system/azuremonitor-coreagent.service.d/proxy.conf")
            hasSettings=True
        # Update ME
        if os.path.exists("/etc/systemd/system/metrics-extension.service.d/proxy.conf"):
            os.remove("/etc/systemd/system/metrics-extension.service.d/proxy.conf")
            hasSettings=True
        if hasSettings:
            run_command_and_log("systemctl daemon-reload")
            run_command_and_log('systemctl restart azuremonitor-coreagent')
            run_command_and_log('systemctl restart metrics-extension')
    except:
        log_and_exit("enable", MissingorInvalidParameterErrorCode, "Failed to remove /etc/systemd/system/azuremonitor-coreagent.service.d and /etc/systemd/system/metrics-extension.service.d" )
public_settings.get(AzureMonitorConfigKey): azure_monitor_configuration = public_settings.get(AzureMonitorConfigKey) if azure_monitor_configuration and azure_monitor_configuration.get("enable") == True: public_settings = azure_monitor_configuration.get("configuration") if public_settings is not None and "authentication" in public_settings and "managedIdentity" in public_settings.get("authentication"): managedIdentity = public_settings.get("authentication").get("managedIdentity") if "identifier-name" not in managedIdentity or "identifier-value" not in managedIdentity: return identifier_name, identifier_value, 'Parameters "identifier-name" and "identifier-value" are both required in authentication.managedIdentity public setting' identifier_name = managedIdentity.get("identifier-name") identifier_value = managedIdentity.get("identifier-value") if identifier_name not in ["client_id", "mi_res_id", "object_id"]: return identifier_name, identifier_value, 'Invalid identifier-name provided; must be "client_id" or "mi_res_id" or "object_id"' if not identifier_value: return identifier_name, identifier_value, 'Invalid identifier-value provided; cannot be empty' if identifier_name in ["object_id", "client_id"]: guid_re = re.compile(r'[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}') if not guid_re.search(identifier_value): return identifier_name, identifier_value, 'Invalid identifier-value provided for {0}; must be a GUID'.format(identifier_name) return identifier_name, identifier_value, "" def azureotelcollector_is_active(): """ Checks if `azureotelcollector` is installed to run as a systemd service. 
""" if is_systemd(): try: rc = subprocess.call(["systemctl", "is-active", "--quiet", "azureotelcollector-watcher.path"]) return rc == 0 except OSError: return False return False def install_azureotelcollector(): """ This method will install the azureotelcollector package and start a systemd file watcher service that watches for configuration file changes. MetricsExtension is responsible for writing the configuration file. Only if configuration is present, otelcollector process will start to run, the watcher service is responsible to monitor the configuration file. """ if is_systemd(): find_package_manager("Install") azureotelcollector_install_command = get_otelcollector_installation_command() hutil_log_info('Running command "{0}"'.format(azureotelcollector_install_command)) # Retry, since install can fail due to concurrent package operations exit_code, output = run_command_with_retries_output( azureotelcollector_install_command, retries = 5, retry_check = retry_if_dpkg_or_rpm_locked, final_check = final_check_if_dpkg_or_rpm_locked ) if exit_code == 0: hutil_log_info('Successfully installed azureotelcollector') return True hutil_log_error('Error installing azureotelcollector "{0}"'.format(output)) return False def get_otelcollector_installation_command(): """ This method provides the installation command to install an azureotelcollector package as a systemd service """ find_package_manager("Install") dir_path = os.getcwd() + "/azureotelcollector/" if PackageManager == "dpkg": package_path = find_otelcollector_package_file(dir_path, "deb") elif PackageManager == "rpm": package_path = find_otelcollector_package_file(dir_path, "rpm") else: raise Exception("Unsupported package manager to install azureotelcollector: {0}.".format(PackageManager)) return "{0} {1} --install {2}".format(PackageManager, PackageManagerOptions, package_path) def find_otelcollector_package_file(directory, pkg_type): """ Finds the otelcollector package in a given path for a given package type 
using name globbing. """ arch = platform.machine() # Create pattern based on type and arch if pkg_type == "deb": if arch == "x86_64": pattern = "azureotelcollector_*_amd64.deb" elif arch == "aarch64": pattern = "azureotelcollector_*_arm64.deb" else: raise Exception("Unsupported architecture for deb package: {0}".format(arch)) elif pkg_type == "rpm": pattern = "azureotelcollector-*{0}.rpm".format(arch) else: raise Exception("Unsupported package type to install azureotelcollector: {0}".format(pkg_type)) search_pattern = os.path.join(directory, pattern) matches = glob.glob(search_pattern) if not matches: raise IOError("No {0} package found for arch '{1}' in {2} with pattern '{3}'".format(pkg_type, arch, directory, pattern)) # Return the most recently modified match return max(matches, key=os.path.getmtime) def uninstall_azureotelcollector(): """ This method will uninstall azureotelcollector services. No need to stop it separately as the package maintainer script handles it upon uninstalling. 
""" if is_feature_enabled("enableAzureOTelCollector"): # Only remove azureotelcollector if file exists if os.path.exists("/lib/systemd/system/azureotelcollector-watcher.path"): azureotelcollector_uninstall_command = "" find_package_manager("Uninstall") if PackageManager == "dpkg": azureotelcollector_uninstall_command = "dpkg --purge azureotelcollector" elif PackageManager == "rpm": azureotelcollector_uninstall_command = "rpm --erase azureotelcollector" else: log_and_exit("Uninstall", UnsupportedOperatingSystem, "The OS has neither rpm nor dpkg" ) hutil_log_info('Running command "{0}"'.format(azureotelcollector_uninstall_command)) exit_code, output = run_command_with_retries_output( azureotelcollector_uninstall_command, retries = 5, retry_check = retry_if_dpkg_or_rpm_locked, final_check = final_check_if_dpkg_or_rpm_locked ) if exit_code == 0: hutil_log_info('Successfully removed azureotelcollector') else: hutil_log_error('Error removing azureotelcollector "{0}"'.format(output)) def stop_metrics_process(): if telhandler.is_running(is_lad=False): #Stop the telegraf and ME services tel_out, tel_msg = telhandler.stop_telegraf_service(is_lad=False) if tel_out: hutil_log_info(tel_msg) else: hutil_log_error(tel_msg) #Delete the telegraf and ME services tel_rm_out, tel_rm_msg = telhandler.remove_telegraf_service(is_lad=False) if tel_rm_out: hutil_log_info(tel_rm_msg) else: hutil_log_error(tel_rm_msg) if me_handler.is_running(is_lad=False): me_out, me_msg = me_handler.stop_metrics_service(is_lad=False) if me_out: hutil_log_info(me_msg) else: hutil_log_error(me_msg) me_rm_out, me_rm_msg = me_handler.remove_metrics_service(is_lad=False) if me_rm_out: hutil_log_info(me_rm_msg) else: hutil_log_error(me_rm_msg) pids_filepath = os.path.join(os.getcwd(),'amametrics.pid') # kill existing metrics watcher if os.path.exists(pids_filepath): with open(pids_filepath, "r") as f: for pid in f.readlines(): # Verify the pid actually belongs to AMA metrics watcher. 
cmd_file = os.path.join("/proc", str(pid.strip("\n")), "cmdline") if os.path.exists(cmd_file): with open(cmd_file, "r") as pidf: cmdline = pidf.readlines() if len(cmdline) > 0 and cmdline[0].find("agent.py") >= 0 and cmdline[0].find("-metrics") >= 0: kill_cmd = "kill " + pid run_command_and_log(kill_cmd) run_command_and_log("rm "+pids_filepath) def stop_syslogconfig_process(): pids_filepath = os.path.join(os.getcwd(),'amasyslogconfig.pid') # kill existing syslog config watcher if os.path.exists(pids_filepath): with open(pids_filepath, "r") as f: for pid in f.readlines(): # Verify the pid actually belongs to AMA syslog watcher. cmd_file = os.path.join("/proc", str(pid.strip("\n")), "cmdline") if os.path.exists(cmd_file): with open(cmd_file, "r") as pidf: cmdline = pidf.readlines() if len(cmdline) > 0 and cmdline[0].find("agent.py") >= 0 and cmdline[0].find("-syslogconfig") >= 0: kill_cmd = "kill " + pid run_command_and_log(kill_cmd) run_command_and_log("rm "+ pids_filepath) def is_metrics_process_running(): pids_filepath = os.path.join(os.getcwd(),'amametrics.pid') if os.path.exists(pids_filepath): with open(pids_filepath, "r") as f: for pid in f.readlines(): # Verify the pid actually belongs to AMA metrics watcher. cmd_file = os.path.join("/proc", str(pid.strip("\n")), "cmdline") if os.path.exists(cmd_file): with open(cmd_file, "r") as pidf: cmdline = pidf.readlines() if len(cmdline) > 0 and cmdline[0].find("agent.py") >= 0 and cmdline[0].find("-metrics") >= 0: return True return False def is_syslogconfig_process_running(): pids_filepath = os.path.join(os.getcwd(),'amasyslogconfig.pid') if os.path.exists(pids_filepath): with open(pids_filepath, "r") as f: for pid in f.readlines(): # Verify the pid actually belongs to AMA syslog watcher. 
cmd_file = os.path.join("/proc", str(pid.strip("\n")), "cmdline") if os.path.exists(cmd_file): with open(cmd_file, "r") as pidf: cmdline = pidf.readlines() if len(cmdline) > 0 and cmdline[0].find("agent.py") >= 0 and cmdline[0].find("-syslogconfig") >= 0: return True return False def is_transformconfig_process_running(): pids_filepath = os.path.join(os.getcwd(),'amatransformconfig.pid') if os.path.exists(pids_filepath): with open(pids_filepath, "r") as f: for pid in f.readlines(): # Verify the pid actually belongs to AMA transform config watcher. cmd_file = os.path.join("/proc", str(pid.strip("\n")), "cmdline") if os.path.exists(cmd_file): with open(cmd_file, "r") as pidf: cmdline = pidf.readlines() if len(cmdline) > 0 and cmdline[0].find("agent.py") >= 0 and cmdline[0].find("-transformconfig") >= 0: return True return False def start_metrics_process(): """ Start metrics process that performs periodic monitoring activities :return: None """ # if metrics process is already running, it should manage lifecycle of telegraf, ME, # process to refresh ME MSI token and look for new config changes if counters change, etc, so this is no-op if not is_metrics_process_running(): stop_metrics_process() # Start metrics watcher ama_path = os.path.join(os.getcwd(), 'agent.py') args = [sys.executable, ama_path, '-metrics'] log = open(os.path.join(os.getcwd(), 'daemon.log'), 'w') hutil_log_info('start watcher process '+str(args)) subprocess.Popen(args, stdout=log, stderr=log) def start_syslogconfig_process(): """ Start syslog check process that performs periodic DCR monitoring activities and looks for syslog config changes :return: None """ # test if not is_syslogconfig_process_running(): stop_syslogconfig_process() # Start syslog config watcher ama_path = os.path.join(os.getcwd(), 'agent.py') args = [sys.executable, ama_path, '-syslogconfig'] log = open(os.path.join(os.getcwd(), 'daemon.log'), 'w') hutil_log_info('start syslog watcher process '+str(args)) subprocess.Popen(args, 
stdout=log, stderr=log) def start_transformconfig_process(): """ Start agent transform check process that performs periodic DCR monitoring activities and looks for agent transformation config changes :return: None """ # test if not is_transformconfig_process_running(): stop_transformconfig_process() # Start agent transform config watcher ama_path = os.path.join(os.getcwd(), 'agent.py') args = [sys.executable, ama_path, '-transformconfig'] log = open(os.path.join(os.getcwd(), 'daemon.log'), 'w') hutil_log_info('start agent transform config watcher process '+str(args)) subprocess.Popen(args, stdout=log, stderr=log) def stop_transformconfig_process(): pids_filepath = os.path.join(os.getcwd(),'amatransformconfig.pid') # kill existing agent transform config watcher if os.path.exists(pids_filepath): with open(pids_filepath, "r") as f: for pid in f.readlines(): # Verify the pid actually belongs to AMA tranform config watcher. cmd_file = os.path.join("/proc", str(pid.strip("\n")), "cmdline") if os.path.exists(cmd_file): with open(cmd_file, "r") as pidf: cmdline = pidf.readlines() if len(cmdline) > 0 and cmdline[0].find("agent.py") >= 0 and cmdline[0].find("-transformconfig") >= 0: kill_cmd = "kill " + pid run_command_and_log(kill_cmd) run_command_and_log("rm "+ pids_filepath) def metrics_watcher(hutil_error, hutil_log): """ Watcher thread to monitor metric configuration changes and to take action on them """ global MDSDFluentPort # Check every 30 seconds sleepTime = 30 # Retrieve managed identity info that may be needed for token retrieval identifier_name, identifier_value, error_msg = get_managed_identity() if error_msg: hutil_error('Failed to determine managed identity settings; MSI token retreival will rely on default identity, if any. 
{0}.'.format(error_msg)) if identifier_name and identifier_value: managed_identity_str = "uai#{0}#{1}".format(identifier_name, identifier_value) else: managed_identity_str = "sai" # Sleep before starting the monitoring time.sleep(sleepTime) last_crc = None last_crc_fluent = None me_msi_token_expiry_epoch = None enabled_me_CMv2_mode = False log_messages = "" while True: try: if not azureotelcollector_is_active(): install_azureotelcollector() if not me_handler.is_running(is_lad=False): me_service_template_path = os.getcwd() + "/services/metrics-extension.service" try: if is_feature_enabled("enableAzureOTelCollector"): if os.path.exists(me_service_template_path): os.remove(me_service_template_path) copyfile(os.getcwd() + "/services/metrics-extension-cmv2.service", me_service_template_path) me_handler.setup_me( is_lad=False, managed_identity=managed_identity_str, HUtilObj=HUtilObject, is_local_control_channel=False, user="azuremetricsext", group="azuremonitoragent") enabled_me_CMv2_mode, log_messages = me_handler.start_metrics_cmv2() elif is_feature_enabled("enableCMV2"): if os.path.exists(me_service_template_path): os.remove(me_service_template_path) copyfile(os.getcwd() + "/services/metrics-extension-otlp.service", me_service_template_path) me_handler.setup_me( is_lad=False, managed_identity=managed_identity_str, HUtilObj=HUtilObject, is_local_control_channel=False) enabled_me_CMv2_mode, log_messages = me_handler.start_metrics_cmv2() except Exception as e: hutil_log_error("Error in setting up metrics-extension.service in CMv2 mode. 
Exception={0}".format(e)) if enabled_me_CMv2_mode: hutil_log_info("Successfully started metrics-extension.") elif log_messages: hutil_log_error(log_messages) # update fluent config for fluent port if needed fluent_port = '' if os.path.isfile(AMAFluentPortFilePath): f = open(AMAFluentPortFilePath, "r") fluent_port = validate_port_number(f.read(), "fluent") f.close() if fluent_port != '' and os.path.isfile(FluentCfgPath) and fluent_port != MDSDFluentPort: portSetting = " Port " + fluent_port + "\n" defaultPortSetting = 'Port' portUpdated = True with open(FluentCfgPath, 'r') as f: for line in f: found = re.search(r'^\s{0,}Port\s{1,}' + fluent_port + '$', line) if found: portUpdated = False if portUpdated == True: with contextlib.closing(fileinput.FileInput(FluentCfgPath, inplace=True, backup='.bak')) as file: for line in file: if defaultPortSetting in line: print(portSetting, end='') else: print(line, end='') os.chmod(FluentCfgPath, stat.S_IRGRP | stat.S_IRUSR | stat.S_IWUSR | stat.S_IROTH) MDSDFluentPort = fluent_port # add SELinux rules if needed if os.path.exists('/etc/selinux/config') and fluent_port != '': sedisabled, _ = run_command_and_log('getenforce | grep -i "Disabled"',log_cmd=False, log_output=False) if sedisabled != 0: check_semanage, _ = run_command_and_log("which semanage",log_cmd=False, log_output=False) if check_semanage == 0: fluentPortEnabled, _ = run_command_and_log('grep -Rnw /var/lib/selinux -e ' + fluent_port,log_cmd=False, log_output=False) if fluentPortEnabled != 0: # also check SELinux config paths for Oracle/RH fluentPortEnabled, _ = run_command_and_log('grep -Rnw /etc/selinux -e ' + fluent_port,log_cmd=False, log_output=False) if fluentPortEnabled != 0: # allow the fluent port in SELinux run_command_and_log('semanage port -a -t http_port_t -p tcp ' + fluent_port,log_cmd=False, log_output=False) if os.path.isfile(FluentCfgPath): f = open(FluentCfgPath, "r") data = f.read() if (data != ''): crc_fluent = 
hashlib.sha256(data.encode('utf-8')).hexdigest() if (crc_fluent != last_crc_fluent): restart_launcher() last_crc_fluent = crc_fluent if os.path.isfile(MdsdCounterJsonPath): f = open(MdsdCounterJsonPath, "r") data = f.read() if (data != ''): json_data = json.loads(data) if len(json_data) == 0: last_crc = hashlib.sha256(data.encode('utf-8')).hexdigest() if telhandler.is_running(is_lad=False): # Stop the telegraf and ME services tel_out, tel_msg = telhandler.stop_telegraf_service(is_lad=False) if tel_out: hutil_log(tel_msg) else: hutil_error(tel_msg) # Delete the telegraf and ME services tel_rm_out, tel_rm_msg = telhandler.remove_telegraf_service(is_lad=False) if tel_rm_out: hutil_log(tel_rm_msg) else: hutil_error(tel_rm_msg) if not enabled_me_CMv2_mode and me_handler.is_running(is_lad=False): me_out, me_msg = me_handler.stop_metrics_service(is_lad=False) if me_out: hutil_log(me_msg) else: hutil_error(me_msg) me_rm_out, me_rm_msg = me_handler.remove_metrics_service(is_lad=False) if me_rm_out: hutil_log(me_rm_msg) else: hutil_error(me_rm_msg) else: crc = hashlib.sha256(data.encode('utf-8')).hexdigest() if(crc != last_crc): # Resetting the me_msi_token_expiry_epoch variable if we set up ME again. 
me_msi_token_expiry_epoch = None hutil_log("Start processing metric configuration") hutil_log(data) telegraf_config, telegraf_namespaces = telhandler.handle_config( json_data, "unix:///run/azuremetricsext/mdm_influxdb.socket", "unix:///run/azuremonitoragent/default_influx.socket", is_lad=False) start_telegraf_res, log_messages = telhandler.start_telegraf(is_lad=False) if start_telegraf_res: hutil_log("Successfully started metrics-sourcer.") else: hutil_error(log_messages) if not enabled_me_CMv2_mode: me_service_template_path = os.getcwd() + "/services/metrics-extension.service" if os.path.exists(me_service_template_path): os.remove(me_service_template_path) copyfile(os.getcwd() + "/services/metrics-extension-cmv1.service", me_service_template_path) me_handler.setup_me(is_lad=False, managed_identity=managed_identity_str, HUtilObj=HUtilObject) start_metrics_out, log_messages = me_handler.start_metrics(is_lad=False, managed_identity=managed_identity_str) if start_metrics_out: hutil_log("Successfully started metrics-extension.") else: hutil_error(log_messages) last_crc = crc generate_token = False me_token_path = os.path.join(os.getcwd(), "/config/metrics_configs/AuthToken-MSI.json") if me_msi_token_expiry_epoch is None or me_msi_token_expiry_epoch == "": if os.path.isfile(me_token_path): with open(me_token_path, "r") as f: authtoken_content = f.read() if authtoken_content and "expires_on" in authtoken_content: me_msi_token_expiry_epoch = authtoken_content["expires_on"] else: generate_token = True else: generate_token = True if me_msi_token_expiry_epoch: currentTime = datetime.datetime.now() token_expiry_time = datetime.datetime.fromtimestamp(int(me_msi_token_expiry_epoch)) if token_expiry_time - currentTime < datetime.timedelta(minutes=30): # The MSI Token will expire within 30 minutes. 
We need to refresh the token generate_token = True if generate_token: generate_token = False msi_token_generated, me_msi_token_expiry_epoch, log_messages = me_handler.generate_MSI_token(identifier_name, identifier_value, is_lad=False) if msi_token_generated: hutil_log("Successfully refreshed metrics-extension MSI Auth token.") else: hutil_error(log_messages) telegraf_restart_retries = 0 me_restart_retries = 0 max_restart_retries = 10 # Check if telegraf is running, if not, then restart if not telhandler.is_running(is_lad=False): if telegraf_restart_retries < max_restart_retries: telegraf_restart_retries += 1 hutil_log("Telegraf binary process is not running. Restarting telegraf now. Retry count - {0}".format(telegraf_restart_retries)) tel_out, tel_msg = telhandler.stop_telegraf_service(is_lad=False) if tel_out: hutil_log(tel_msg) else: hutil_error(tel_msg) start_telegraf_res, log_messages = telhandler.start_telegraf(is_lad=False) if start_telegraf_res: hutil_log("Successfully started metrics-sourcer.") else: hutil_error(log_messages) else: hutil_error("Telegraf binary process is not running. Failed to restart after {0} retries. Please check telegraf.log".format(max_restart_retries)) else: telegraf_restart_retries = 0 # Check if ME is running, if not, then restart if not me_handler.is_running(is_lad=False): if me_restart_retries < max_restart_retries: me_restart_retries += 1 hutil_log("MetricsExtension binary process is not running. Restarting MetricsExtension now. Retry count - {0}".format(me_restart_retries)) me_out, me_msg = me_handler.stop_metrics_service(is_lad=False) if me_out: hutil_log(me_msg) else: hutil_error(me_msg) start_metrics_out, log_messages = me_handler.start_metrics(is_lad=False, managed_identity=managed_identity_str) if start_metrics_out: hutil_log("Successfully started metrics-extension.") else: hutil_error(log_messages) else: hutil_error("MetricsExtension binary process is not running. Failed to restart after {0} retries. 
Please check /var/log/syslog for ME logs".format(max_restart_retries)) else: me_restart_retries = 0 except IOError as e: hutil_error('I/O error in setting up or monitoring metrics. Exception={0}'.format(e)) except Exception as e: hutil_error('Error in setting up or monitoring metrics. Exception={0}'.format(e)) finally: time.sleep(sleepTime) def syslogconfig_watcher(hutil_error, hutil_log): """ Watcher thread to monitor syslog configuration changes and to take action on them """ syslog_enabled = False # Check for config changes every 30 seconds sleepTime = 30 # Sleep before starting the monitoring time.sleep(sleepTime) GcsEnabled, McsEnabled = get_control_plane_mode() while True: try: if os.path.isfile(AMASyslogConfigMarkerPath): f = open(AMASyslogConfigMarkerPath, "r") data = f.read() if (data != ''): if "true" in data: syslog_enabled = True f.close() elif GcsEnabled: # 1P Syslog is always enabled as each tenant could be having different mdsd.xml configuration syslog_enabled = True if syslog_enabled: # place syslog local configs syslog_enabled = False generate_localsyslog_configs(uses_gcs=GcsEnabled, uses_mcs=McsEnabled) else: # remove syslog local configs remove_localsyslog_configs() except IOError as e: hutil_error('I/O error in setting up syslog config watcher. Exception={0}'.format(e)) except Exception as e: hutil_error('Error in setting up syslog config watcher. 
Exception={0}'.format(e)) finally: time.sleep(sleepTime)

def transformconfig_watcher(hutil_error, hutil_log):
    """
    Watcher thread to monitor agent transformation configuration changes and to take action on them

    :param hutil_error: callable used to log errors
    :param hutil_log: callable used to log informational messages
    """
    # Check for config changes every 30 seconds
    sleepTime = 30

    # Sleep before starting the monitoring
    time.sleep(sleepTime)

    last_crc = None

    while True:
        try:
            if os.path.isfile(AMAAstTransformConfigMarkerPath):
                f = open(AMAAstTransformConfigMarkerPath, "r")
                data = f.read()
                if (data != ''):
                    # Restart only when the marker file content actually changed.
                    crc = hashlib.sha256(data.encode('utf-8')).hexdigest()
                    if (crc != last_crc):
                        restart_astextension()
                        last_crc = crc
                # NOTE(review): f leaks if hashing/restart raises; the broad
                # except below swallows the error but the handle stays open.
                f.close()
        except IOError as e:
            hutil_error('I/O error in setting up agent transform config watcher. Exception={0}'.format(e))
        except Exception as e:
            hutil_error('Error in setting up agent transform config watcher. Exception={0}'.format(e))
        finally:
            time.sleep(sleepTime)

def generate_localsyslog_configs(uses_gcs = False, uses_mcs = False):
    """
    Install local syslog configuration files if not present and restart syslog

    :param uses_gcs: True when the 1P (GCS) control plane is configured
    :param uses_mcs: True when the MCS control plane is configured
    """
    global MDSDSyslogPort
    # don't deploy any configuration if no control plane is configured
    if not uses_gcs and not uses_mcs:
        return
    public_settings, _ = get_settings()
    syslog_port = ''
    if os.path.isfile(AMASyslogPortFilePath):
        f = open(AMASyslogPortFilePath, "r")
        syslog_port = validate_port_number(f.read(), "syslog")
        f.close()
    useSyslogTcp = False
    # Port unchanged since last deployment: nothing to do.
    if syslog_port == MDSDSyslogPort:
        return
    # always use syslog tcp port, unless
    # - the distro is Red Hat based and doesn't have semanage
    # these distros seem to have SELinux on by default and we shouldn't be installing semanage ourselves
    if not os.path.exists('/etc/selinux/config'):
        useSyslogTcp = True
    else:
        sedisabled, _ = run_command_and_log('getenforce | grep -i "Disabled"',log_cmd=False, log_output=False)
        if sedisabled == 0:
            useSyslogTcp = True
        else:
            check_semanage, _ = run_command_and_log("which semanage",log_cmd=False, log_output=False)
            if check_semanage == 0 and syslog_port != '':
                syslogPortEnabled, _ = run_command_and_log('grep -Rnw /var/lib/selinux -e ' + syslog_port,log_cmd=False, log_output=False)
                if syslogPortEnabled != 0:
                    # also check SELinux config paths for Oracle/RH
                    syslogPortEnabled, _ = run_command_and_log('grep -Rnw /etc/selinux -e ' + syslog_port,log_cmd=False, log_output=False)
                if syslogPortEnabled != 0:
                    # allow the syslog port in SELinux
                    run_command_and_log('semanage port -a -t syslogd_port_t -p tcp ' + syslog_port,log_cmd=False, log_output=False)
                useSyslogTcp = True
    if syslog_port != '':
        MDSDSyslogPort = syslog_port
    # 1P tenants use omuxsock, so keep using that for customers using 1P
    if useSyslogTcp == True and syslog_port != '':
        # TCP (omfwd) deployment path for rsyslog.
        if os.path.exists('/etc/rsyslog.d/'):
            restartRequired = False
            if uses_gcs and not os.path.exists('/etc/rsyslog.d/05-azuremonitoragent-loadomuxsock.conf'):
                copyfile("/etc/opt/microsoft/azuremonitoragent/syslog/rsyslogconf/05-azuremonitoragent-loadomuxsock.conf","/etc/rsyslog.d/05-azuremonitoragent-loadomuxsock.conf")
                restartRequired = True
            if not os.path.exists('/etc/rsyslog.d/10-azuremonitoragent-omfwd.conf'):
                # Switch from the omuxsock variant to the omfwd (TCP) variant.
                if os.path.exists('/etc/rsyslog.d/05-azuremonitoragent-loadomuxsock.conf'):
                    os.remove("/etc/rsyslog.d/05-azuremonitoragent-loadomuxsock.conf")
                if os.path.exists('/etc/rsyslog.d/10-azuremonitoragent.conf'):
                    os.remove("/etc/rsyslog.d/10-azuremonitoragent.conf")
                copyfile("/etc/opt/microsoft/azuremonitoragent/syslog/rsyslogconf/10-azuremonitoragent-omfwd.conf","/etc/rsyslog.d/10-azuremonitoragent-omfwd.conf")
                os.chmod('/etc/rsyslog.d/10-azuremonitoragent-omfwd.conf', stat.S_IRGRP | stat.S_IRUSR | stat.S_IWUSR | stat.S_IROTH)
                restartRequired = True
            portSetting = 'Port="' + syslog_port + '"'
            defaultPortSetting = 'Port="28330"'
            portUpdated = False
            with open('/etc/rsyslog.d/10-azuremonitoragent-omfwd.conf') as f:
                if portSetting not in f.read():
                    portUpdated = True
            if portUpdated == True:
                # Re-copy the pristine template, then substitute the default port in place.
                copyfile("/etc/opt/microsoft/azuremonitoragent/syslog/rsyslogconf/10-azuremonitoragent-omfwd.conf","/etc/rsyslog.d/10-azuremonitoragent-omfwd.conf")
                with contextlib.closing(fileinput.FileInput('/etc/rsyslog.d/10-azuremonitoragent-omfwd.conf', inplace=True, backup='.bak')) as file:
                    for line in file:
                        print(line.replace(defaultPortSetting, portSetting), end='')
                os.chmod('/etc/rsyslog.d/10-azuremonitoragent-omfwd.conf', stat.S_IRGRP | stat.S_IRUSR | stat.S_IWUSR | stat.S_IROTH)
                restartRequired = True
            if restartRequired == True:
                run_command_and_log(get_service_command("rsyslog", "restart"))
                hutil_log_info("Installed local syslog configuration files and restarted syslog")
        # TCP deployment path for syslog-ng.
        if os.path.exists('/etc/syslog-ng/syslog-ng.conf'):
            restartRequired = False
            if not os.path.exists('/etc/syslog-ng/conf.d/azuremonitoragent-tcp.conf'):
                if os.path.exists('/etc/syslog-ng/conf.d/azuremonitoragent.conf'):
                    os.remove("/etc/syslog-ng/conf.d/azuremonitoragent.conf")
                syslog_ng_confpath = os.path.join('/etc/syslog-ng/', 'conf.d')
                if not os.path.exists(syslog_ng_confpath):
                    os.makedirs(syslog_ng_confpath)
                copyfile("/etc/opt/microsoft/azuremonitoragent/syslog/syslog-ngconf/azuremonitoragent-tcp.conf","/etc/syslog-ng/conf.d/azuremonitoragent-tcp.conf")
                os.chmod('/etc/syslog-ng/conf.d/azuremonitoragent-tcp.conf', stat.S_IRGRP | stat.S_IRUSR | stat.S_IWUSR | stat.S_IROTH)
                restartRequired = True
            portSetting = "port(" + syslog_port + ")"
            defaultPortSetting = "port(28330)"
            portUpdated = False
            with open('/etc/syslog-ng/conf.d/azuremonitoragent-tcp.conf') as f:
                if portSetting not in f.read():
                    portUpdated = True
            if portUpdated == True:
                # Re-copy the pristine template, then substitute the default port in place.
                copyfile("/etc/opt/microsoft/azuremonitoragent/syslog/syslog-ngconf/azuremonitoragent-tcp.conf","/etc/syslog-ng/conf.d/azuremonitoragent-tcp.conf")
                with contextlib.closing(fileinput.FileInput('/etc/syslog-ng/conf.d/azuremonitoragent-tcp.conf', inplace=True, backup='.bak')) as file:
                    for line in file:
                        print(line.replace(defaultPortSetting, portSetting), end='')
                os.chmod('/etc/syslog-ng/conf.d/azuremonitoragent-tcp.conf', stat.S_IRGRP | stat.S_IRUSR | stat.S_IWUSR | stat.S_IROTH)
                restartRequired = True
            if restartRequired == True:
                run_command_and_log(get_service_command("syslog-ng", "restart"))
                hutil_log_info("Installed local syslog configuration files and restarted syslog")
    else:
        # omuxsock (non-TCP) deployment path.
        if os.path.exists('/etc/rsyslog.d/') and not os.path.exists('/etc/rsyslog.d/10-azuremonitoragent.conf'):
            if os.path.exists('/etc/rsyslog.d/10-azuremonitoragent-omfwd.conf'):
                os.remove("/etc/rsyslog.d/10-azuremonitoragent-omfwd.conf")
            copyfile("/etc/opt/microsoft/azuremonitoragent/syslog/rsyslogconf/05-azuremonitoragent-loadomuxsock.conf","/etc/rsyslog.d/05-azuremonitoragent-loadomuxsock.conf")
            copyfile("/etc/opt/microsoft/azuremonitoragent/syslog/rsyslogconf/10-azuremonitoragent.conf","/etc/rsyslog.d/10-azuremonitoragent.conf")
            os.chmod('/etc/rsyslog.d/05-azuremonitoragent-loadomuxsock.conf', stat.S_IRGRP | stat.S_IRUSR | stat.S_IWUSR | stat.S_IROTH)
            os.chmod('/etc/rsyslog.d/10-azuremonitoragent.conf', stat.S_IRGRP | stat.S_IRUSR | stat.S_IWUSR | stat.S_IROTH)
            run_command_and_log(get_service_command("rsyslog", "restart"))
            hutil_log_info("Installed local syslog configuration files and restarted syslog")
        if os.path.exists('/etc/syslog-ng/syslog-ng.conf') and not os.path.exists('/etc/syslog-ng/conf.d/azuremonitoragent.conf'):
            if os.path.exists('/etc/syslog-ng/conf.d/azuremonitoragent-tcp.conf'):
                os.remove("/etc/syslog-ng/conf.d/azuremonitoragent-tcp.conf")
            syslog_ng_confpath = os.path.join('/etc/syslog-ng/', 'conf.d')
            if not os.path.exists(syslog_ng_confpath):
                os.makedirs(syslog_ng_confpath)
            copyfile("/etc/opt/microsoft/azuremonitoragent/syslog/syslog-ngconf/azuremonitoragent.conf","/etc/syslog-ng/conf.d/azuremonitoragent.conf")
            os.chmod('/etc/syslog-ng/conf.d/azuremonitoragent.conf', stat.S_IRGRP | stat.S_IRUSR | stat.S_IWUSR | stat.S_IROTH)
            run_command_and_log(get_service_command("syslog-ng", "restart"))
            hutil_log_info("Installed local syslog configuration files and restarted syslog")

def remove_localsyslog_configs():
    """
    Remove local syslog configuration files if present and restart syslog
    """
    # Either rsyslog variant present: remove all AMA rsyslog drop-ins and restart.
    if os.path.exists('/etc/rsyslog.d/10-azuremonitoragent.conf') or os.path.exists('/etc/rsyslog.d/10-azuremonitoragent-omfwd.conf'):
        if os.path.exists('/etc/rsyslog.d/10-azuremonitoragent-omfwd.conf'):
            os.remove("/etc/rsyslog.d/10-azuremonitoragent-omfwd.conf")
        if os.path.exists('/etc/rsyslog.d/05-azuremonitoragent-loadomuxsock.conf'):
            os.remove("/etc/rsyslog.d/05-azuremonitoragent-loadomuxsock.conf")
        if os.path.exists('/etc/rsyslog.d/10-azuremonitoragent.conf'):
            os.remove("/etc/rsyslog.d/10-azuremonitoragent.conf")
        run_command_and_log(get_service_command("rsyslog", "restart"))
        hutil_log_info("Removed local syslog configuration files if found and restarted syslog")
    # Either syslog-ng variant present: remove AMA syslog-ng drop-ins and restart.
    if os.path.exists('/etc/syslog-ng/conf.d/azuremonitoragent.conf') or os.path.exists('/etc/syslog-ng/conf.d/azuremonitoragent-tcp.conf'):
        if os.path.exists('/etc/syslog-ng/conf.d/azuremonitoragent-tcp.conf'):
            os.remove("/etc/syslog-ng/conf.d/azuremonitoragent-tcp.conf")
        if os.path.exists('/etc/syslog-ng/conf.d/azuremonitoragent.conf'):
            os.remove("/etc/syslog-ng/conf.d/azuremonitoragent.conf")
        run_command_and_log(get_service_command("syslog-ng", "restart"))
        hutil_log_info("Removed local syslog configuration files if found and restarted syslog")

def metrics():
    """
    Take care of setting up telegraf and ME for metrics if configuration is present

    Writes this process's pid to <cwd>/amametrics.pid, then blocks on the
    metrics watcher thread. Returns (0, "") if the thread ever exits.
    """
    pids_filepath = os.path.join(os.getcwd(), 'amametrics.pid')
    py_pid = os.getpid()
    with open(pids_filepath, 'w') as f:
        f.write(str(py_pid) + '\n')
    watcher_thread = Thread(target = metrics_watcher, args = [hutil_log_error, hutil_log_info])
    watcher_thread.start()
    watcher_thread.join()
    return 0, ""

def syslogconfig(): """ Take care of setting up syslog configuration change watcher """ pids_filepath = os.path.join(os.getcwd(), 'amasyslogconfig.pid') py_pid = os.getpid() with open(pids_filepath, 'w') as f: 
f.write(str(py_pid) + '\n') watcher_thread = Thread(target = syslogconfig_watcher, args = [hutil_log_error, hutil_log_info]) watcher_thread.start() watcher_thread.join() return 0, "" def transformconfig(): """ Take care of setting up agent transformation configuration change watcher """ pids_filepath = os.path.join(os.getcwd(), 'amatransformconfig.pid') py_pid = os.getpid() with open(pids_filepath, 'w') as f: f.write(str(py_pid) + '\n') watcher_thread = Thread(target = transformconfig_watcher, args = [hutil_log_error, hutil_log_info]) watcher_thread.start() watcher_thread.join() return 0, "" # Dictionary of operations strings to methods operations = {'Disable' : disable, 'Uninstall' : uninstall, 'Install' : install, 'Enable' : enable, 'Update' : update, 'Metrics' : metrics, 'Syslogconfig' : syslogconfig, 'Transformconfig' : transformconfig } def parse_context(operation): """ Initialize a HandlerUtil object for this operation. If the required modules have not been imported, this will return None. """ hutil = None if ('Utils.WAAgentUtil' in sys.modules and 'Utils.HandlerUtil' in sys.modules): try: logFileName = 'extension.log' hutil = HUtil.HandlerUtility(waagent.Log, waagent.Error, logFileName=logFileName) hutil.do_parse_context(operation) # As per VM extension team, we have to manage rotation for our extension.log # for now, this is our extension code, but to be moved to HUtil library. 
if os.path.exists(WAGuestAgentLogRotateFilePath): if os.path.exists(AMAExtensionLogRotateFilePath): try: os.remove(AMAExtensionLogRotateFilePath) except Exception as ex: output = 'Logrotate removal failed with error: {0}\nStacktrace: {1}'.format(ex, traceback.format_exc()) hutil_log_info(output) else: if not os.path.exists(AMAExtensionLogRotateFilePath): logrotateFilePath = os.path.join(os.getcwd(), 'azuremonitoragentextension.logrotate') copyfile(logrotateFilePath,AMAExtensionLogRotateFilePath) # parse_context may throw KeyError if necessary JSON key is not # present in settings except KeyError as e: waagent_log_error('Unable to parse context with error: ' \ '{0}'.format(e)) raise ParameterMissingException return hutil def set_os_arch(operation): """ Checks if the current system architecture is present in the SupportedArch set and replaces the package names accordingly """ global BundleFileName, SupportedArch current_arch = platform.machine() if current_arch in SupportedArch: # Replace the AMA package name according to architecture BundleFileName = BundleFileName.replace('x86_64', current_arch) def find_package_manager(operation): """ Checks if the dist is debian based or centos based and assigns the package manager accordingly """ global PackageManager, PackageManagerOptions, BundleFileName dist, _ = find_vm_distro(operation) dpkg_set = set(["debian", "ubuntu"]) rpm_set = set(["oracle", "ol", "redhat", "centos", "red hat", "suse", "sles", "opensuse", "cbl-mariner", "mariner", "azurelinux", "rhel", "rocky", "alma", "amzn"]) for dpkg_dist in dpkg_set: if dist.startswith(dpkg_dist): PackageManager = "dpkg" # OK to replace the /etc/default/azuremonitoragent, since the placeholders gets replaced again. 
def find_vm_distro(operation):
    """
    Find the Linux distribution this VM is running on by directly parsing
    distribution-specific files, instead of relying on platform APIs.

    Probes, in order: /etc/os-release, /etc/system-release, /etc/SuSE-release,
    /etc/redhat-release, /etc/lsb-release, /etc/debian_version, /proc/version.

    Returns a (vm_dist, vm_ver) tuple of lowercase strings. If no file yields
    a distribution name, calls log_and_exit() with IndeterminateOperatingSystem
    and does not return.
    """
    vm_dist = vm_ver = ""
    detection_files_checked = []

    # Try to read from /etc/os-release first (most modern distributions)
    if os.path.exists('/etc/os-release'):
        detection_files_checked.append('/etc/os-release')
        try:
            with open('/etc/os-release', 'r') as fp:
                os_release = {}
                # Parse KEY=VALUE pairs, stripping surrounding quotes
                for line in fp:
                    if line.strip() and '=' in line:
                        k, v = line.strip().split('=', 1)
                        os_release[k] = v.strip('"\'').strip()
                if 'ID' in os_release:
                    vm_dist = os_release['ID'].lower()
                    # Clean up the ID by removing any vendor-specific suffixes
                    vm_dist = vm_dist.split('-')[0]
                if 'VERSION_ID' in os_release:
                    vm_ver = os_release['VERSION_ID'].lower()
                # Fallback for ID_LIKE if direct ID isn't recognized
                if not vm_dist and 'ID_LIKE' in os_release:
                    # Get first value from ID_LIKE
                    vm_dist = os_release['ID_LIKE'].lower().split()[0].strip('"\'')
                    vm_dist = vm_dist.split('-')[0]
            hutil_log_info("OS detected from /etc/os-release: {0} {1}".format(vm_dist, vm_ver))
        except Exception as e:
            hutil_log_error("Error reading /etc/os-release: {0}".format(str(e)))

    # If we couldn't get the distribution from /etc/os-release, try other files.
    # NOTE(review): only the /etc/system-release probe below runs when vm_dist
    # was found but vm_ver is missing; the remaining probes are guarded by
    # `not vm_dist`, so a known distro with a missing VERSION_ID keeps an
    # empty version string — confirm this is intended.
    if not vm_dist or not vm_ver:
        # Try /etc/system-release first (used by Amazon Linux and others)
        if os.path.exists('/etc/system-release'):
            detection_files_checked.append('/etc/system-release')
            try:
                with open('/etc/system-release', 'r') as fp:
                    content = fp.read().lower()
                    if 'amazon' in content:
                        vm_dist = 'amzn'
                        # Try to extract version
                        version_match = re.search(r'release\s+(\d+(\.\d+)?)', content)
                        if version_match:
                            vm_ver = version_match.group(1)
                hutil_log_info("OS detected from /etc/system-release: {0} {1}".format(vm_dist, vm_ver))
            except Exception as e:
                hutil_log_error("Error reading /etc/system-release: {0}".format(str(e)))

        # SUSE specific detection
        if not vm_dist and os.path.exists('/etc/SuSE-release'):
            detection_files_checked.append('/etc/SuSE-release')
            try:
                with open('/etc/SuSE-release', 'r') as fp:
                    content = fp.read()
                    if 'SUSE Linux Enterprise Server' in content:
                        vm_dist = 'sles'
                    elif 'openSUSE' in content:
                        vm_dist = 'opensuse'
                    else:
                        vm_dist = 'suse'
                    # Try to extract the version
                    version_match = re.search(r'VERSION\s*=\s*(\d+)', content)
                    if version_match:
                        vm_ver = version_match.group(1)
                        # Also look for service pack level
                        sp_match = re.search(r'PATCHLEVEL\s*=\s*(\d+)', content)
                        if sp_match and vm_ver:
                            vm_ver = '{0}.{1}'.format(vm_ver, sp_match.group(1))
                hutil_log_info("OS detected from /etc/SuSE-release: {0} {1}".format(vm_dist, vm_ver))
            except Exception as e:
                hutil_log_error("Error reading /etc/SuSE-release: {0}".format(str(e)))

        # Red Hat based systems
        if not vm_dist and os.path.exists('/etc/redhat-release'):
            detection_files_checked.append('/etc/redhat-release')
            try:
                with open('/etc/redhat-release', 'r') as fp:
                    content = fp.read().lower()
                    if 'red hat' in content:
                        vm_dist = 'redhat'
                    elif 'centos' in content:
                        vm_dist = 'centos'
                    elif 'oracle' in content:
                        vm_dist = 'oracle'
                    elif 'fedora' in content:
                        vm_dist = 'fedora'
                    elif 'rocky' in content:
                        vm_dist = 'rocky'
                    elif 'alma' in content:
                        vm_dist = 'alma'
                    else:
                        vm_dist = 'redhat'  # Default to redhat for RHEL-based systems
                    # Try to extract version using a more flexible pattern
                    # This handles formats like "release 8.6" or "release 7.9.2009"
                    version_match = re.search(r'release\s+(\d+(\.\d+){0,2})', content)
                    if version_match:
                        vm_ver = version_match.group(1)
                hutil_log_info("OS detected from /etc/redhat-release: {0} {1}".format(vm_dist, vm_ver))
            except Exception as e:
                hutil_log_error("Error reading /etc/redhat-release: {0}".format(str(e)))

        # Debian based systems with lsb-release
        if not vm_dist and os.path.exists('/etc/lsb-release'):
            detection_files_checked.append('/etc/lsb-release')
            try:
                lsb_data = {}
                with open('/etc/lsb-release', 'r') as fp:
                    for line in fp:
                        if line.strip() and '=' in line:
                            k, v = line.strip().split('=', 1)
                            lsb_data[k] = v.strip('"\'')
                if 'DISTRIB_ID' in lsb_data:
                    vm_dist = lsb_data['DISTRIB_ID'].lower()
                if 'DISTRIB_RELEASE' in lsb_data:
                    vm_ver = lsb_data['DISTRIB_RELEASE'].lower()
                hutil_log_info("OS detected from /etc/lsb-release: {0} {1}".format(vm_dist, vm_ver))
            except Exception as e:
                hutil_log_error("Error reading /etc/lsb-release: {0}".format(str(e)))

        # Debian specific detection
        if not vm_dist and os.path.exists('/etc/debian_version'):
            detection_files_checked.append('/etc/debian_version')
            try:
                with open('/etc/debian_version', 'r') as fp:
                    vm_ver = fp.read().strip()
                vm_dist = 'debian'
                hutil_log_info("OS detected from /etc/debian_version: {0} {1}".format(vm_dist, vm_ver))
            except Exception as e:
                hutil_log_error("Error reading /etc/debian_version: {0}".format(str(e)))

        # Final fallback - try /proc/version
        if not vm_dist and os.path.exists('/proc/version'):
            detection_files_checked.append('/proc/version')
            try:
                with open('/proc/version', 'r') as fp:
                    content = fp.read().lower()
                    if 'debian' in content:
                        vm_dist = 'debian'
                    elif 'ubuntu' in content:
                        vm_dist = 'ubuntu'
                    elif 'red hat' in content or 'redhat' in content:
                        vm_dist = 'redhat'
                    elif 'suse' in content:
                        vm_dist = 'suse'
                    # Try to extract version - not always reliable from /proc/version
                hutil_log_info("OS detected from /proc/version: {0}".format(vm_dist))
            except Exception as e:
                hutil_log_error("Error reading /proc/version: {0}".format(str(e)))

    # If we still couldn't determine the OS, log what we tried and throw an error
    if not vm_dist:
        error_msg = 'Indeterminate operating system. Files checked: {0}'.format(", ".join(detection_files_checked))
        log_and_exit(operation, IndeterminateOperatingSystem, error_msg)

    # Normalize distribution names so downstream matching sees canonical IDs
    if vm_dist == 'rhel' or vm_dist == 'red hat':
        vm_dist = 'redhat'
    elif vm_dist == 'ol':
        vm_dist = 'oracle'

    if vm_ver and '.' in vm_ver and vm_dist != 'ubuntu':
        # For Ubuntu, keep major.minor format (e.g., "18.04")
        # For other distributions, extract only the major version
        # This is needed for matching with supported_distros.py
        vm_ver = vm_ver.split('.')[0]

    # Add debugging info
    hutil_log_info("Final OS detection result: {0} {1}".format(vm_dist.lower(), vm_ver.lower()))
    return vm_dist.lower(), vm_ver.lower()
def is_vm_supported_for_extension(operation):
    """
    Check whether the distro/version this VM runs on is supported by
    AzureMonitorLinuxAgent.

    Distribution version strings vary widely in format (e.g. '7.3.1611' for a
    CentOS 7 VM), so a VM version is accepted when its leading numeric
    components match every component of a supported-version entry from
    supported_distros. All other distros get error code 51 from the caller.

    Returns (vm_supported, vm_dist, vm_ver).
    """
    # aarch64 and x86_64 have separate support matrices
    if platform.machine() == 'aarch64':
        supported_dists = supported_distros.supported_dists_aarch64
    else:
        supported_dists = supported_distros.supported_dists_x86_64

    vm_supported = False
    vm_dist, vm_ver = find_vm_distro(operation)

    # Find this VM distribution in the supported list
    for supported_dist in list(supported_dists.keys()):
        if not vm_dist.startswith(supported_dist):
            continue
        # Check if this VM distribution version is supported
        vm_ver_split = vm_ver.split('.')
        for supported_ver in supported_dists[supported_dist]:
            supported_ver_split = supported_ver.split('.')
            # If vm_ver is at least as precise (at least as many digits) as
            # supported_ver and matches all the supported_ver digits, then
            # this VM is guaranteed to be supported
            vm_ver_match = True
            for idx, supported_ver_num in enumerate(supported_ver_split):
                try:
                    supported_ver_num = int(supported_ver_num)
                    vm_ver_num = int(vm_ver_split[idx])
                except (IndexError, ValueError):
                    # IndexError: vm_ver has fewer components than the
                    # supported entry.
                    # ValueError (fix): a non-numeric component (e.g. 'n/a'
                    # or 'rolling') previously escaped this handler and
                    # crashed the entire support check; treat it as a
                    # non-match instead.
                    vm_ver_match = False
                    break
                if vm_ver_num != supported_ver_num:
                    vm_ver_match = False
                    break
            if vm_ver_match:
                vm_supported = True
                break
        if vm_supported:
            break

    return vm_supported, vm_dist, vm_ver
""" vm_supported, vm_dist, vm_ver = is_vm_supported_for_extension(operation) if not vm_supported: log_and_exit(operation, UnsupportedOperatingSystem, 'Unsupported operating system: ' \ '{0} {1}'.format(vm_dist, vm_ver)) return 0 def is_feature_enabled(feature): """ Checks if the feature is enabled in the current region """ feature_support_matrix = { 'useDynamicSSL' : ['all'], 'enableCMV2' : ['all'], 'enableAzureOTelCollector' : ['all'] } featurePreviewFlagPath = PreviewFeaturesDirectory + feature if os.path.exists(featurePreviewFlagPath): return True featurePreviewDisabledFlagPath = PreviewFeaturesDirectory + feature + 'Disabled' if os.path.exists(featurePreviewDisabledFlagPath): return False _, region = get_azure_environment_and_region() if feature in feature_support_matrix.keys(): if region in feature_support_matrix[feature] or "all" in feature_support_matrix[feature]: return True return False def get_ssl_cert_info(operation): """ Get the appropriate SSL_CERT_DIR / SSL_CERT_FILE based on the Linux distro """ name = value = None distro, version = find_vm_distro(operation) for name in ['ubuntu', 'debian']: if distro.startswith(name): return 'SSL_CERT_DIR', '/etc/ssl/certs' for name in ['centos', 'redhat', 'red hat', 'oracle', 'ol', 'cbl-mariner', 'mariner', 'azurelinux', 'rhel', 'rocky', 'alma', 'amzn']: if distro.startswith(name): return 'SSL_CERT_FILE', '/etc/pki/tls/certs/ca-bundle.crt' for name in ['suse', 'sles', 'opensuse']: if distro.startswith(name): if version.startswith('12'): return 'SSL_CERT_DIR', '/var/lib/ca-certificates/openssl' elif version.startswith('15') or version.startswith('16'): return 'SSL_CERT_DIR', '/etc/ssl/certs' log_and_exit(operation, GenericErrorCode, 'Unable to determine values for SSL_CERT_DIR or SSL_CERT_FILE') def copy_astextension_binaries(): astextension_bin_local_path = os.getcwd() + "/AstExtensionBin/" astextension_bin = "/opt/microsoft/azuremonitoragent/bin/astextension/" astextension_runtimesbin = 
"/opt/microsoft/azuremonitoragent/bin/astextension/runtimes/" if os.path.exists(astextension_runtimesbin): # only for versions of AMA with .NET runtimes rmtree(astextension_runtimesbin) # only for versions with Ast .net cleanup .NET files as it is causing issues with AOT runtime for f in os.listdir(astextension_bin): if f != 'AstExtension' and f != 'appsettings.json': os.remove(os.path.join(astextension_bin, f)) for f in os.listdir(astextension_bin_local_path): compare_and_copy_bin(astextension_bin_local_path + f, astextension_bin + f) def is_arc_installed(): """ Check if this is an Arc machine """ # Using systemctl to check this since Arc only supports VMs that have systemd check_arc = os.system('systemctl status himdsd 1>/dev/null 2>&1') return check_arc == 0 def get_arc_endpoint(): """ Find the endpoint for Arc IMDS """ endpoint_filepath = '/lib/systemd/system.conf.d/azcmagent.conf' endpoint = '' try: with open(endpoint_filepath, 'r') as f: data = f.read() endpoint = data.split("\"IMDS_ENDPOINT=")[1].split("\"\n")[0] except: hutil_log_error('Unable to load Arc IMDS endpoint from {0}'.format(endpoint_filepath)) return endpoint def get_imds_endpoint(): """ Find the appropriate endpoint (Azure or Arc) for IMDS """ azure_imds_endpoint = 'http://169.254.169.254/metadata/instance?api-version=2018-10-01' if (is_arc_installed()): hutil_log_info('Arc is installed, loading Arc-specific IMDS endpoint') imds_endpoint = get_arc_endpoint() if imds_endpoint: imds_endpoint += '/metadata/instance?api-version=2019-08-15' else: # Fall back to the traditional IMDS endpoint; the cloud domain and VM # resource id detection logic are resilient to failed queries to IMDS imds_endpoint = azure_imds_endpoint hutil_log_info('Falling back to default Azure IMDS endpoint') else: imds_endpoint = azure_imds_endpoint hutil_log_info('Using IMDS endpoint "{0}"'.format(imds_endpoint)) return imds_endpoint def get_azure_environment_and_region(): """ Retreive the Azure environment and region from 
def get_azure_environment_and_region():
    """
    Retrieve the Azure environment and region from Azure or Arc IMDS.

    Returns an (environment, region) tuple of lowercase strings; either may be
    None when IMDS is unreachable or the response lacks the fields.
    NOTE(review): `urllib`/`urlerror` here are the py2/py3 compatibility
    aliases set up elsewhere in this file — confirm before reusing.
    """
    imds_endpoint = get_imds_endpoint()
    req = urllib.Request(imds_endpoint)
    # IMDS requires this header on every request
    req.add_header('Metadata', 'True')

    environment = region = None

    try:
        response = json.loads(urllib.urlopen(req).read().decode('utf-8', 'ignore'))

        if ('compute' in response):
            if ('azEnvironment' in response['compute']):
                environment = response['compute']['azEnvironment'].lower()
            if ('location' in response['compute']):
                region = response['compute']['location'].lower()
    except urlerror.HTTPError as e:
        hutil_log_error('Request to Metadata service URL failed with an HTTPError: {0}'.format(e))
        hutil_log_error('Response from Metadata service: {0}'.format(e.read()))
    except Exception as e:
        hutil_log_error('Unexpected error from Metadata service: {0}'.format(e))

    hutil_log_info('Detected environment: {0}, region: {1}'.format(environment, region))
    return environment, region


def run_command_and_log(cmd, check_error = True, log_cmd = True, log_output = True):
    """
    Run the provided shell command and log its output, including stdout and
    stderr. The output should not contain any PII, but the command might.
    In this case, log_cmd should be set to False.

    Returns (exit_code, output); exit_code may be rewritten to
    MissingDependency for two known external failure signatures (see below).
    """
    exit_code, output = run_get_output(cmd, check_error, log_cmd)
    if log_cmd:
        hutil_log_info('Output of command "{0}": \n{1}'.format(cmd.rstrip(), output))
    elif log_output:
        hutil_log_info('Output: \n{0}'.format(output))

    # Fail-fast mappings for known environment problems:
    if "cannot open Packages database" in output:
        # Install failures
        # External issue. Package manager db is either corrupt or needs cleanup
        # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr
        exit_code = MissingDependency
        output += "Package manager database is in a bad state. Please recover package manager, db cache and try install again later."
    elif "Permission denied" in output:
        # Enable failures
        # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr
        exit_code = MissingDependency

    return exit_code, output
elif "Permission denied" in output: # Enable failures # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr exit_code = MissingDependency return exit_code, output def run_command_with_retries_output(cmd, retries, retry_check, final_check = None, check_error = True, log_cmd = True, initial_sleep_time = InitialRetrySleepSeconds, sleep_increase_factor = 1): """ Caller provides a method, retry_check, to use to determine if a retry should be performed. This must be a function with two parameters: exit_code and output The final_check can be provided as a method to perform a final check after retries have been exhausted Logic used: will retry up to retries times with initial_sleep_time in between tries If the retry_check retuns True for retry_verbosely, we will try cmd with the standard -v verbose flag added """ try_count = 0 sleep_time = initial_sleep_time run_cmd = cmd run_verbosely = False while try_count <= retries: if run_verbosely: run_cmd = cmd + ' -v' exit_code, output = run_command_and_log(run_cmd, check_error, log_cmd) should_retry, retry_message, run_verbosely = retry_check(exit_code, output) if not should_retry: break try_count += 1 hutil_log_info(retry_message) time.sleep(sleep_time) sleep_time *= sleep_increase_factor if final_check is not None: exit_code = final_check(exit_code, output) return exit_code, output def is_dpkg_or_rpm_locked(exit_code, output): """ If dpkg is locked, the output will contain a message similar to 'dpkg status database is locked by another process' """ if exit_code != 0: dpkg_locked_search = r'^.*dpkg.+lock.*$' dpkg_locked_re = re.compile(dpkg_locked_search, re.M) if dpkg_locked_re.search(output): return True rpm_locked_search = r'^.*rpm.+lock.*$' rpm_locked_re = re.compile(rpm_locked_search, re.M) if rpm_locked_re.search(output): return True return False def retry_if_dpkg_or_rpm_locked(exit_code, output): """ Some commands fail because the package manager 
def retry_if_dpkg_or_rpm_locked(exit_code, output):
    """
    Retry predicate for run_command_with_retries_output: some commands fail
    because the package manager is locked (apt-get/dpkg only); this allows
    retries on failing commands.

    Returns (should_retry, retry_message, retry_verbosely).
    """
    retry_verbosely = False
    dpkg_or_rpm_locked = is_dpkg_or_rpm_locked(exit_code, output)
    if dpkg_or_rpm_locked:
        return True, 'Retrying command because package manager is locked.', \
               retry_verbosely
    else:
        return False, '', False


def final_check_if_dpkg_or_rpm_locked(exit_code, output):
    """
    If dpkg or rpm is still locked after the retries, return the dedicated
    DPKGOrRPMLockedErrorCode instead of the command's own exit code.
    """
    dpkg_or_rpm_locked = is_dpkg_or_rpm_locked(exit_code, output)
    if dpkg_or_rpm_locked:
        exit_code = DPKGOrRPMLockedErrorCode
    return exit_code


def get_settings():
    """
    Retrieve the configuration for this extension operation.

    Resolution order: HandlerUtil (HUtilObject) when available, then the
    module-level SettingsDict cache, then the latest <seq_no>.settings file
    on disk. protectedSettings are base64-decoded and decrypted with
    'openssl cms' (FIPS 140-3 compatible), falling back to 'openssl smime'.

    Returns (public_settings, protected_settings); exits via log_and_exit()
    when decryption of present protectedSettings fails.
    """
    global SettingsDict
    public_settings = None
    protected_settings = None

    if HUtilObject is not None:
        # HandlerUtil has already parsed the settings for us
        public_settings = HUtilObject.get_public_settings()
        protected_settings = HUtilObject.get_protected_settings()
    elif SettingsDict is not None:
        # A previous call already parsed the on-disk settings; reuse the cache
        public_settings = SettingsDict['public_settings']
        protected_settings = SettingsDict['protected_settings']
    else:
        SettingsDict = {}
        handler_env = get_handler_env()
        try:
            config_dir = str(handler_env['handlerEnvironment']['configFolder'])
        except:
            config_dir = os.path.join(os.getcwd(), 'config')

        seq_no = get_latest_seq_no()
        settings_path = os.path.join(config_dir, '{0}.settings'.format(seq_no))
        # Fix: h_settings must be defined even when loading the settings file
        # fails; previously it was left unbound and the protectedSettings
        # check below raised NameError instead of degrading gracefully.
        h_settings = None
        try:
            with open(settings_path, 'r') as settings_file:
                settings_txt = settings_file.read()
            settings = json.loads(settings_txt)
            h_settings = settings['runtimeSettings'][0]['handlerSettings']
            public_settings = h_settings['publicSettings']
            SettingsDict['public_settings'] = public_settings
        except:
            hutil_log_error('Unable to load handler settings from ' \
                            '{0}'.format(settings_path))

        if (h_settings is not None
                and 'protectedSettings' in h_settings
                and 'protectedSettingsCertThumbprint' in h_settings
                and h_settings['protectedSettings'] is not None
                and h_settings['protectedSettingsCertThumbprint'] is not None):
            encoded_settings = h_settings['protectedSettings']
            settings_thumbprint = h_settings['protectedSettingsCertThumbprint']
            # waagent drops the certificate/key pair for this thumbprint here
            encoded_cert_path = os.path.join('/var/lib/waagent',
                                             '{0}.crt'.format(settings_thumbprint))
            encoded_key_path = os.path.join('/var/lib/waagent',
                                            '{0}.prv'.format(settings_thumbprint))
            decoded_settings = base64.standard_b64decode(encoded_settings)
            # FIPS 140-3: use 'openssl cms' (supports AES256 & DES_EDE3_CBC) with fallback to legacy 'openssl smime'
            cms_cmd = 'openssl cms -inform DER -decrypt -recip {0} -inkey {1}'.format(encoded_cert_path, encoded_key_path)
            smime_cmd = 'openssl smime -inform DER -decrypt -recip {0} -inkey {1}'.format(encoded_cert_path, encoded_key_path)
            protected_settings_str = None
            for decrypt_cmd in [cms_cmd, smime_cmd]:
                try:
                    session = subprocess.Popen([decrypt_cmd], shell=True,
                                               stdin=subprocess.PIPE,
                                               stderr=subprocess.STDOUT,
                                               stdout=subprocess.PIPE)
                    output = session.communicate(decoded_settings)
                    # success only if return code is 0 and we have output
                    if session.returncode == 0 and output[0]:
                        protected_settings_str = output[0]
                        if decrypt_cmd == cms_cmd:
                            hutil_log_info('Decrypted protectedSettings using openssl cms.')
                        else:
                            hutil_log_info('Decrypted protectedSettings using openssl smime fallback.')
                        break
                    else:
                        hutil_log_info('Attempt to decrypt protectedSettings with "{0}" failed (rc={1}).'.format(decrypt_cmd, session.returncode))
                except OSError:
                    # openssl binary missing/unusable for this variant; try next
                    pass
            if protected_settings_str is None:
                log_and_exit('Enable', GenericErrorCode, 'Failed decrypting protectedSettings')
            protected_settings = ''
            try:
                protected_settings = json.loads(protected_settings_str)
            except:
                hutil_log_error('JSON exception decoding protected settings')
            SettingsDict['protected_settings'] = protected_settings

    return public_settings, protected_settings
def update_status_file(operation, exit_code, exit_status, message):
    """
    Mimic HandlerUtil method do_status_report in case hutil method is not
    available. Write the status JSON for the latest sequence number to the
    status folder, using a write-to-temp + rename so the agent never reads a
    partially written file.
    """
    handler_env = get_handler_env()
    try:
        extension_version = str(handler_env['version'])
        status_dir = str(handler_env['handlerEnvironment']['statusFolder'])
    except:
        extension_version = "1.0"
        status_dir = os.path.join(os.getcwd(), 'status')

    # Shape required by the VM agent's status-file contract
    status_txt = [{
        "version" : extension_version,
        "timestampUTC" : time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "status" : {
            "name" : "Microsoft.Azure.Monitor.AzureMonitorLinuxAgent",
            "operation" : operation,
            "status" : exit_status,
            "code" : exit_code,
            "formattedMessage" : {
                "lang" : "en-US",
                "message" : message
            }
        }
    }]
    status_json = json.dumps(status_txt)

    # Find the most recently changed config file and then use the
    # corresponding status file
    latest_seq_no = get_latest_seq_no()
    status_path = os.path.join(status_dir, '{0}.status'.format(latest_seq_no))
    status_tmp = '{0}.tmp'.format(status_path)
    with open(status_tmp, 'w+') as tmp_file:
        tmp_file.write(status_json)
    os.rename(status_tmp, status_path)


def get_handler_env():
    """
    Set and retrieve the contents of HandlerEnvironment.json as JSON,
    caching the parsed result in the module-level HandlerEnvironment.
    Returns None (after logging) if the file cannot be read or parsed.
    """
    global HandlerEnvironment
    if HandlerEnvironment is None:
        handler_env_path = os.path.join(os.getcwd(), 'HandlerEnvironment.json')
        try:
            with open(handler_env_path, 'r') as handler_env_file:
                handler_env_txt = handler_env_file.read()
            handler_env = json.loads(handler_env_txt)
            # The agent wraps the object in a single-element list
            if type(handler_env) == list:
                handler_env = handler_env[0]
            HandlerEnvironment = handler_env
        except Exception as e:
            waagent_log_error(str(e))
    return HandlerEnvironment


def get_latest_seq_no():
    """
    Determine the latest operation settings number to use: the sequence
    number of the most recently modified <N>.settings file in the config
    folder. Cached in the module-level SettingsSequenceNumber; defaults to 0
    when no settings file is found.
    """
    global SettingsSequenceNumber
    if SettingsSequenceNumber is None:
        handler_env = get_handler_env()
        try:
            config_dir = str(handler_env['handlerEnvironment']['configFolder'])
        except:
            config_dir = os.path.join(os.getcwd(), 'config')

        latest_seq_no = -1
        cur_seq_no = -1
        latest_time = None
        try:
            for dir_name, sub_dirs, file_names in os.walk(config_dir):
                for file_name in file_names:
                    file_basename = os.path.basename(file_name)
                    # Only files named like "<digits>.settings" qualify
                    match = re.match(r'[0-9]{1,10}\.settings', file_basename)
                    if match is None:
                        continue
                    cur_seq_no = int(file_basename.split('.')[0])
                    file_path = os.path.join(config_dir, file_name)
                    cur_time = os.path.getmtime(file_path)
                    # Most recent mtime wins, not the highest number
                    if latest_time is None or cur_time > latest_time:
                        latest_time = cur_time
                        latest_seq_no = cur_seq_no
        except:
            pass
        if latest_seq_no < 0:
            latest_seq_no = 0
        SettingsSequenceNumber = latest_seq_no

    return SettingsSequenceNumber
def run_get_output(cmd, chk_err = False, log_cmd = True):
    """
    Mimic waagent method RunGetOutput in case waagent is not available.
    Run a shell command and return (exit_code, output) with the output
    normalized to a stripped text string on both Python 2 and 3.
    """
    if 'Utils.WAAgentUtil' in sys.modules:
        # WALinuxAgent-2.0.14 allows only 2 parameters for RunGetOutput
        # If checking the number of parameters fails, pass 2
        try:
            sig = inspect.signature(waagent.RunGetOutput)
            params = sig.parameters
            waagent_params = len(params)
        except:
            try:
                # Older Pythons without inspect.signature
                spec = inspect.getargspec(waagent.RunGetOutput)
                params = spec.args
                waagent_params = len(params)
            except:
                waagent_params = 2
        if waagent_params >= 3:
            exit_code, output = waagent.RunGetOutput(cmd, chk_err, log_cmd)
        else:
            exit_code, output = waagent.RunGetOutput(cmd, chk_err)
    else:
        try:
            output = subprocess.check_output(cmd, stderr = subprocess.STDOUT, shell = True)
            exit_code = 0
        except subprocess.CalledProcessError as e:
            exit_code = e.returncode
            output = e.output

    # Python 2: encode unicode -> UTF-8 bytes (str). Python 3: decode bytes -> str.
    try:
        # Python 2
        if isinstance(output, unicode):  # type: ignore # noqa: F821
            output = output.encode('utf-8', 'ignore')
    except NameError:
        # Python 3 ('unicode' is undefined)
        if isinstance(output, (bytes, bytearray)):
            output = bytes(output).decode('utf-8', 'ignore')

    return exit_code, output.strip()


def init_waagent_logger():
    """
    Initialize waagent logger
    If waagent has not been imported, catch the exception
    """
    try:
        waagent.LoggerInit('/var/log/waagent.log', '/dev/stdout', True)
    except Exception as e:
        print('Unable to initialize waagent log because of exception ' \
              '{0}'.format(e))


def waagent_log_info(message):
    """
    Log informational message, being cautious of possibility that waagent
    may not be imported
    """
    if 'Utils.WAAgentUtil' in sys.modules:
        waagent.Log(message)
    else:
        print('Info: {0}'.format(message))


def waagent_log_error(message):
    """
    Log error message, being cautious of possibility that waagent may not
    be imported
    """
    if 'Utils.WAAgentUtil' in sys.modules:
        waagent.Error(message)
    else:
        print('Error: {0}'.format(message))


def hutil_log_info(message):
    """
    Log informational message, being cautious of possibility that hutil may
    not be imported and configured
    """
    if HUtilObject is not None:
        HUtilObject.log(message)
    else:
        print('Info: {0}'.format(message))


def hutil_log_error(message):
    """
    Log error message, being cautious of possibility that hutil may not be
    imported and configured
    """
    if HUtilObject is not None:
        HUtilObject.error(message)
    else:
        print('Error: {0}'.format(message))


def log_and_exit(operation, exit_code = GenericErrorCode, message = ''):
    """
    Log the exit message and perform the exit: via HandlerUtil's do_exit when
    available, otherwise by writing the status file directly and calling
    sys.exit.
    """
    if exit_code == 0:
        waagent_log_info(message)
        hutil_log_info(message)
        exit_status = 'success'
    else:
        waagent_log_error(message)
        hutil_log_error(message)
        exit_status = 'failed'

    if HUtilObject is not None:
        # do_exit terminates the process itself
        HUtilObject.do_exit(exit_code, operation, exit_status, str(exit_code), message)
    else:
        update_status_file(operation, str(exit_code), exit_status, message)
        sys.exit(exit_code)
def validate_port_number(port_value, port_name):
    """
    Validate that a port value parses as an integer in the TCP/UDP port
    range 1-65535.

    Args:
        port_value: The port value to validate (string)
        port_name:  The name of the port for error messages
                    (e.g., "fluent", "syslog")

    Returns:
        The normalized port number as a string, or '' when the value is
        empty or invalid. Invalid values are logged via hutil_log_error,
        never raised.
    """
    if not port_value:
        return ''
    try:
        parsed_port = int(port_value.strip())
    except ValueError:
        hutil_log_error('Invalid {0} port value: {1}. Must be an integer.'.format(port_name, port_value))
        return ''
    if 1 <= parsed_port <= 65535:
        return str(parsed_port)
    hutil_log_error('Invalid {0} port number: {1}. Must be between 1-65535.'.format(port_name, parsed_port))
    return ''
================================================ FILE: AzureMonitorAgent/ama_tst/AMA-Troubleshooting-Tool.md ================================================ # Troubleshooting Tool for Azure Monitor Linux Agent The following document provides quick information on the AMA Troubleshooting Tool, including how to use it and its checks. # Table of Contents - [Troubleshooter Basics](#troubleshooter-basics) - [Using the Troubleshooter](#using-the-troubleshooter) - [Requirements](#requirements) - [Scenarios Covered](#scenarios-covered) ## Troubleshooter Basics The Azure Monitor Linux Agent Troubleshooter is designed in order to help find and diagnose issues with the agent, as well as general health checks. At the current moment, the AMA TST can run checks to verify agent installation, connection, and general heartbeat, as well as collect AMA-related logs automatically from the affected Linux VM. In addition, more checks are being added regularly, to help increase the number of scenarios the AMA TST can catch. ## Using the Troubleshooter The AMA Linux Troubleshooter is automatically installed upon installation of AMA, and can be located and run by the following commands: 1. Go to the troubleshooter's installed location: `cd /var/lib/waagent/Microsoft.Azure.Monitor.AzureMonitorLinuxAgent-/ama_tst` 2. Run the troubleshooter: `sudo sh ama_troubleshooter.sh` If the troubleshooter isn't properly installed, or needs to be updated, the newest version can be downloaded and run by following the steps below. 1. Copy the troubleshooter bundle onto your machine: `wget https://github.com/Azure/azure-linux-extensions/raw/master/AzureMonitorAgent/ama_tst/ama_tst.tgz` 2. Unpack the bundle: `tar -xzvf ama_tst.tgz` 3. Run the troubleshooter: `sudo sh ama_troubleshooter.sh` ## Requirements The AMA Linux Troubleshooter requires Python 2.6+ installed on the machine, but will work with either Python2 or Python3. 
In addition, the following Python packages are required to run (all should be present on a default install of Python2 or Python3): | Python Package | Required for Python2? | Required for Python3? | | --- | --- | --- | | copy | **yes** | **yes** | | datetime | **yes** | **yes** | | json | **yes** | **yes** | | os | **yes** | **yes** | | platform | **yes** | **yes** | | re | **yes** | **yes** | | requests | no | **yes** | | shutil | **yes** | **yes** | | subprocess | **yes** | **yes** | | urllib | **yes** | no | | xml.dom.minidom | **yes** | **yes** | ## Scenarios Covered 1. Agent having installation issues * Supported OS / version * Available disk space * Package manager is available (dpkg/rpm) * Submodules are installed successfully * AMA installed properly * Syslog available (rsyslog/syslog-ng) * Using newest version of AMA * Syslog user generated successfully 2. Agent doesn't start, can't connect to Log Analytics * AMA parameters set up * AMA DCR created successfully * Connectivity to endpoints * Submodules started * IMDS/HIMDS metadata and MSI tokens available 3. Agent is unhealthy, heartbeat doesn't work properly * Submodule status * Parse error files 4. Agent has high CPU / memory usage * Check logrotate * Monitor CPU/memory usage in 5 minutes (interaction mode only) 5. Agent syslog collection doesn't work properly * Rsyslog / syslog-ng set up and running * Syslog configuration being pulled / used * Syslog socket is accessible 6. Agent custom log collection doesn't work properly * Custom log configuration being pulled / used * Log file paths is valid 7. Agent metrics collection doesn't work properly * Runs the metrics troubleshooter script * Produces `MdmDataCollectionOutput_*.tar.gz` for investigation 8. (A) Run all scenarios * Run through scenarios 1-7 in order 9. 
#!/usr/bin/env bash
# Entry point for the AMA Troubleshooting Tool: locates a usable Python
# interpreter, then dispatches to modules/main.py (or prints help/version).
COMMAND="./modules/main.py"
PYTHON=""
TST_VERSION="1.7" # update when changes are made to TST
ARG="$@"

# Print the supported command-line options.
display_help() {
    echo "OPTIONS"
    echo " -A Run All Troubleshooting Tool checks"
    echo " -L Run Log Collector"
    echo " -v, --version Print Troubleshooting Tool version"
}

# Store the name of the first available interpreter (python3, python2, or
# platform-python) into the variable named by $1, via eval.
find_python() {
    local python_exec_command=$1
    if command -v python3 >/dev/null 2>&1 ; then
        eval ${python_exec_command}="python3"
    elif command -v python2 >/dev/null 2>&1 ; then
        eval ${python_exec_command}="python2"
    elif command -v /usr/libexec/platform-python >/dev/null 2>&1 ; then
        # If a user-installed python isn't available, check for a platform-python. This is typically only used in RHEL 8.0.
        echo "User-installed python not found. Using /usr/libexec/platform-python as the python interpreter."
        eval ${python_exec_command}="/usr/libexec/platform-python"
    fi
}

find_python PYTHON
if [ -z "$PYTHON" ] # If python is not installed, we will fail the install with the following error, requiring cx to have python pre-installed
then
    echo "No Python interpreter found, which is an AMA extension dependency. Please install Python 3, or Python 2 if the former is unavailable." >&2
    exit 1
else
    echo "Python version being used is:"
    ${PYTHON} --version 2>&1
    echo ""
fi

# Dispatch: help / version / run the troubleshooter with all original args.
if [ "$1" = "--help" ] || [ "$1" = "-h" ]
then
    display_help
elif [ "$1" = "--version" ] || [ "$1" = "-v" ]
then
    echo "AMA Troubleshooting Tool v.$TST_VERSION"
else
    echo "Starting AMA Troubleshooting Tool v.$TST_VERSION..."
    echo ""
    PYTHONPATH=${PYTHONPATH} ${PYTHON} ${COMMAND} ${ARG}
fi
exit $?
echo "" PYTHONPATH=${PYTHONPATH} ${PYTHON} ${COMMAND} ${ARG} fi exit $? ================================================ FILE: AzureMonitorAgent/ama_tst/modules/__init__.py ================================================ # AMA troubleshooter modules ================================================ FILE: AzureMonitorAgent/ama_tst/modules/connect/__init__.py ================================================ # Connection check helper script for AMA ================================================ FILE: AzureMonitorAgent/ama_tst/modules/connect/check_endpts.py ================================================ import subprocess import traceback from error_codes import * from errors import error_info from helpers import geninfo_lookup, find_dce SSL_CMD = "echo | openssl s_client -connect {0}:443 -brief" CURL_CMD = "curl -s -S -k https://{0}/ping" GLOBAL_HANDLER_URL = "global.handler.control.monitor.azure.com" REGION_HANDLER_URL = "{0}.handler.control.monitor.azure.com" ODS_URL = "{0}.ods.opinsights.azure.com" ME_URL = "management.azure.com" ME_REGION_URL = "{0}.monitoring.azure.com" def _log_ssl_error(context, exception, show_traceback=True): """Helper function to log SSL errors cleanly""" print("{0}:".format(context)) print(" Type: {0}".format(type(exception).__name__)) print(" Message: {0}".format(str(exception))) # For CalledProcessError, show command details if isinstance(exception, subprocess.CalledProcessError): print(" Command: {0}".format(getattr(exception, 'cmd', 'Unknown'))) print(" Return code: {0}".format(getattr(exception, 'returncode', 'Unknown'))) if hasattr(exception, 'output') and exception.output: print(" Output: {0}".format(exception.output.strip())) # Show traceback if requested if show_traceback: print(" Traceback:") print(traceback.format_exc()) def check_endpt_ssl(ssl_cmd, endpoint): """ openssl connect to specific endpoint """ try: ssl_output = subprocess.check_output(ssl_cmd.format(endpoint), shell=True,\ stderr=subprocess.STDOUT, 
def check_internet_connect():
    """Check general internet connectivity via an SSL probe of docs.microsoft.com.

    Returns NO_ERROR on a verified connection, WARN_INTERNET when the
    connection succeeded but certificate verification did not, and
    WARN_INTERNET_CONN when no connection could be made.
    """
    probe_host = "docs.microsoft.com"
    connected, verified, _ = check_endpt_ssl(SSL_CMD, probe_host)

    if connected and verified:
        return NO_ERROR

    # Either verification or the connection itself failed; queue the exact
    # command so the user can reproduce the failure by hand.
    error_info.append((SSL_CMD.format(probe_host),))
    return WARN_INTERNET if connected else WARN_INTERNET_CONN
def check_ama_endpts():
    """Check DNS resolution, SSL handshake and (for AMCS endpoints) ping health
    of every Azure Monitor endpoint this agent is configured to reach.

    Returns NO_ERROR on success, otherwise the first error/warning code hit
    (ERR_INFO_MISSING, ERR_DCE, ERR_RESOLVE_IP, ERR_ENDPT, WARN_OPENSSL_PROXY,
    or a code propagated from check_endpt_curl).
    """
    # compose URLs to check
    endpoints = [GLOBAL_HANDLER_URL]
    regions = geninfo_lookup('DCR_REGION')
    workspace_ids = geninfo_lookup('DCR_WORKSPACE_ID')
    if regions is None or workspace_ids is None:
        return ERR_INFO_MISSING
    for region in regions:
        endpoints.append(REGION_HANDLER_URL.format(region))
    for workspace_id in workspace_ids:
        endpoints.append(ODS_URL.format(workspace_id))
    me_regions = geninfo_lookup('ME_REGION')
    if me_regions is not None:
        endpoints.append(ME_URL)
        for me_region in me_regions:
            endpoints.append(ME_REGION_URL.format(me_region))

    # modify URLs if URL suffix is .us (Azure Government) or .cn (Azure China)
    url_suffix = geninfo_lookup('URL_SUFFIX')
    if url_suffix is not None and url_suffix != '.com':
        # BUGFIX: str.replace returns a new string; the old code discarded the
        # result, so sovereign-cloud suffixes were never actually applied.
        # Also guard against a missing URL_SUFFIX (None) to avoid a TypeError.
        endpoints = [endpoint.replace('.com', url_suffix) for endpoint in endpoints]

    dce, e = find_dce()
    if e != None:
        error_info.append((e,))
        return ERR_DCE
    endpoints.extend(dce)

    for endpoint in endpoints:
        # check if IP address can be resolved using nslookup
        resolved, e = resolve_ip(endpoint)
        if not resolved:
            error_info.append((endpoint, e))
            return ERR_RESOLVE_IP
        # skip openssl check with authenticated proxy
        if geninfo_lookup('MDSD_PROXY_USERNAME') is not None:
            return WARN_OPENSSL_PROXY
        # check ssl handshake
        command = SSL_CMD
        proxy = geninfo_lookup('MDSD_PROXY_ADDRESS')
        if proxy is not None:
            proxy = proxy.replace('http://', '')
            command = command + ' -proxy {0}'.format(proxy)
        if geninfo_lookup('SSL_CERT_DIR') is not None:
            command = command + " -CApath " + geninfo_lookup('SSL_CERT_DIR')
        if geninfo_lookup('SSL_CERT_FILE') is not None:
            command = command + " -CAfile " + geninfo_lookup('SSL_CERT_FILE')
        (connected, verified, e) = check_endpt_ssl(command, endpoint)
        if not connected or not verified:
            error_info.append((endpoint, command.format(endpoint), e))
            return ERR_ENDPT
        # check AMCS ping results
        if "handler.control.monitor" in endpoint:
            checked_curl = check_endpt_curl(endpoint)
            if checked_curl != NO_ERROR:
                return checked_curl
    return NO_ERROR
def check_token():
    """Verify that a managed-identity access token can be fetched via IMDS
    (Azure VM) or HIMDS (Arc-enabled machine).

    Returns NO_ERROR when the response contains 'access_token', otherwise
    ERR_ACCESS_TOKEN with the failing command queued in error_info.
    """
    command = ARC_TOKEN_CMD if is_arc_installed() else AZURE_TOKEN_CMD
    try:
        # When AMA is configured with a user-assigned identity, scope the
        # token request to that identity by injecting its mi_res_id into
        # the query string.
        managed_identity = geninfo_lookup('MANAGED_IDENTITY')
        if managed_identity is not None:
            managed_identity = managed_identity.replace('mi_res_id#', 'mi_res_id=')
            command = command.replace('token?', 'token?{0}&'.format(managed_identity))

        output = subprocess.check_output(command, shell=True,
                                         stderr=subprocess.STDOUT,
                                         universal_newlines=True)
        if 'access_token' not in json.loads(output):
            error_info.append((command, output))
            return ERR_ACCESS_TOKEN
    except Exception as e:
        error_info.append((command, e))
        return ERR_ACCESS_TOKEN
    return NO_ERROR
def check_parameters():
    """Parse exported KEY=VALUE pairs from /etc/default/azuremonitoragent
    into the module-global general_info dict.

    Returns NO_ERROR on success, ERR_AMA_PARAMETERS when the file is missing
    or unparsable.
    """
    global general_info
    try:
        with open('/etc/default/azuremonitoragent', 'r') as fp:
            for line in fp:
                line = line.strip()
                # Only 'export KEY=VALUE' lines carry configuration; blank or
                # comment lines previously raised an uncaught IndexError.
                if not line.startswith('export'):
                    continue
                line = line.split('export', 1)[1].strip()
                # Split on the FIRST '=' only so values that themselves
                # contain '=' (URLs, base64 material) are kept intact;
                # the old split('=')[1] truncated them.
                key, sep, value = line.partition('=')
                if not sep:
                    continue
                general_info[key] = value
    except (FileNotFoundError, AttributeError) as e:
        error_info.append((e,))
        return ERR_AMA_PARAMETERS
    return NO_ERROR
def check_customlog_input():
    """Verify that every configured custom-log input path matches at least
    one existing file.

    Returns NO_ERROR when all paths resolve, ERR_CL_INPUT otherwise.
    """
    import glob  # stdlib; local import keeps this module's header unchanged

    cl_input = geninfo_lookup('CL_INPUT')
    if not cl_input:
        error_info.append(("No custom logs file path",))
        return ERR_CL_INPUT
    # cl_input is a list of absolute paths, possibly containing wildcards.
    for path in cl_input:
        # Skip malformed entries that don't look like valid file paths
        if not path or not path.startswith('/'):
            continue
        try:
            # Expand wildcards the same way the fluentbit Path directive
            # does; an empty match means the input file doesn't exist.
            # This replaces shelling out to `ls` and matching the English
            # 'No such file or directory' text, which failed to detect
            # missing files under non-English locales.
            if not glob.glob(path):
                error_info.append(("{0}: No such file or directory".format(path),))
                return ERR_CL_INPUT
        except Exception as e:
            error_info.append((e,))
            return ERR_CL_INPUT
    return NO_ERROR
# Error and warning codes shared by every AMA troubleshooter module.
# Grouping scheme: 0-1 success/exit, 10-14 warnings, 100s general/install,
# 200s onboarding, 300s CPU/memory, 400s syslog, 500s custom logs.

# General Errors
NO_ERROR = 0
USER_EXIT = 1
ERR_SUDO_PERMS = 100
ERR_FOUND = 101

# Warnings (non-fatal; treated as success by is_error())
WARN_INTERNET_CONN = 10
WARN_INTERNET = 11
WARN_OPENSSL_PROXY = 12
WARN_MDSD_ERR_FILE = 13
WARN_RESTART_LOOP = 14

# Installation Errors
ERR_BITS = 102
ERR_OS_VER = 103
ERR_OS = 104
ERR_FINDING_OS = 105
ERR_FREE_SPACE = 106
ERR_PKG_MANAGER = 107
ERR_SUBCOMPONENT_INSTALL = 108
ERR_MULTIPLE_AMA = 109
ERR_AMA_INSTALL = 110
ERR_LOG_DAEMON = 111
ERR_SYSLOG_USER = 112
ERR_OLD_AMA_VER = 113
ERR_GETTING_AMA_VER = 114
ERR_COUNTER_FILE_MISSING = 115

# Onboarding Errors
ERR_AMA_PARAMETERS = 200
ERR_NO_DCR = 201
ERR_INFO_MISSING = 202
ERR_ENDPT = 203
ERR_SUBCOMPONENT_STATUS = 204
ERR_CHECK_STATUS = 205
ERR_RESOLVE_IP = 206
ERR_IMDS_METADATA = 207
ERR_ACCESS_TOKEN = 208
ERR_ENDPT_PROXY = 209
ERR_DCE = 210

# CPU/Memory Errors
ERR_FILE_MISSING = 300
ERR_LOGROTATE_SIZE = 301
WARN_LOGROTATE = 302
ERR_FILE_ACCESS = 303

# Syslog Errors
ERR_SYSLOG = 400
ERR_SERVICE_STATUS = 401
ERR_FILE_EMPTY = 402
ERR_CONF_FILE_PERMISSION = 403

# Custom Logs Errors
ERR_CL_CONF = 500
ERR_CL_INPUT = 501
To see all supported Operating "\ "Systems, please go to:\n"\ "\n https://docs.microsoft.com/en-us/azure/azure-monitor/agents/agents-overview#linux\n", ERR_FINDING_OS : "Coudln't determine Operating System. To see all supported Operating "\ "Systems, please go to:\n"\ "\n https://docs.microsoft.com/en-us/azure/azure-monitor/agents/agents-overview#linux\n" \ "\n\nError Details: \n{0}", ERR_FREE_SPACE : "There isn't enough space in directory {0} to install AMA - there needs to be at least 500MB free, "\ "but {0} has {1}MB free. Please free up some space and try installing again.", ERR_PKG_MANAGER : "This system does not have a supported package manager. Please install 'dpkg' or 'rpm' "\ "and run this troubleshooter again.", ERR_MULTIPLE_AMA : "There is more than one instance of AMA installed, please remove the extra AMA packages.", ERR_AMA_INSTALL : "AMA package isn't installed correctly.\n\nError Details: \n{0}", ERR_SUBCOMPONENT_INSTALL : "Subcomponents(s) {0} not installed correctly.", ERR_LOG_DAEMON : "No logging daemon found. Please install rsyslog or syslog-ng.", ERR_SYSLOG_USER : "Syslog user is not created successfully.", ERR_OLD_AMA_VER : "You are currently running AMA Version {0}. This troubleshooter only "\ "supports versions 1.9 and newer. Please upgrade to the newest version. You can find "\ "more information at the link below:\n"\ "\n https://docs.microsoft.com/en-us/azure/azure-monitor/agents/azure-monitor-agent-manage\n", ERR_GETTING_AMA_VER : "Couldn't get most current released version of AMA.\n\nError Details: \n{0}", ERR_COUNTER_FILE_MISSING : "metricCounters.json file is not found. Please check your perf counters configuration.", ERR_AMA_PARAMETERS : "Couldn't read and parse AMA configuration in /etc/default/azuremonitoragent.\n\nError Details:\n{0}", ERR_NO_DCR : "Couldn't parse DCR information on this VM. Please check your DCR configuration.\n\nError Details:{0}", ERR_INFO_MISSING: "NO DCR workspace id or region is found. 
Please check if DCR is configured correctly and match the information in"\ "/etc/opt/microsoft/azuremonitoragent/config-cache/configchunks.*.json", ERR_ENDPT : "Machine couldn't connect to {0}: curl/openssl command failed. "\ "\n\nError Details:\n $ {1} \n\n{2}", ERR_SUBCOMPONENT_STATUS : "Subcomponent {0} has not been started. Status details: {1}", ERR_CHECK_STATUS : "Couldn't get the status of subcomponents.\n\nError Details:{0}", ERR_RESOLVE_IP : "The endpoint {0} cannot be resolved. Please run the command below for more information on the failure:\n\n $ {1}", ERR_IMDS_METADATA : "Couldn't access {0} Instance Metadata Service when executing command\n $ {1}\n\nError Details:\n{2}", ERR_ACCESS_TOKEN : "Couldn't use managed identities to acquire an access token when executing command\n $ {0}\n\nError Details:\n{1}", ERR_ENDPT_PROXY : "Machine couldn't connect to {0} with proxy: curl/openssl command failed. Please check your proxy configuration."\ "\n\nError Details:\n $ {1} \n\n{2}", ERR_DCE : "Couldn't parse DCE information on this VM. Please check your DCE configuration.\n\nError Details:{0}", WARN_OPENSSL_PROXY : "Skip SSL handshake checks because AMA is configured with authenticated proxy.", WARN_MDSD_ERR_FILE : "Found errors in log file {0}, displaying last few lines of error messages:\n {1}", WARN_RESTART_LOOP : "Subcomponents might be in a restart loop. Details:\n\n{0}", ERR_FILE_MISSING : "{0} {1} doesn't exist.", ERR_LOGROTATE_SIZE : "Logrotate size limit for log {0} has invalid formatting. Please see {1} for more "\ "information.", WARN_LOGROTATE : "Logrotate isn't rotating log {0}: its current size is {1}, and it should have "\ "been rotated at {2}. Please see {3} for more information.", ERR_FILE_ACCESS : "Couldn't access or run {0} due to the following reason: {1}.", ERR_SYSLOG : "Couldn't find either 'rsyslog' or 'syslog-ng' on machine. 
# check if either has no error or is warning
def is_error(err_code):
    """Return True when err_code is a real error (neither NO_ERROR nor a warning)."""
    return err_code != NO_ERROR and err_code not in warnings
def print_errors(err_code):
    """Format the message(s) for err_code and record them in err_summary.

    Consumes every pending tuple in the module-global error_info queue,
    substituting each one into the code's message template. Prints a
    one-line summary and returns NO_ERROR for warnings (and pass-through
    success codes) or ERR_FOUND for real errors.
    """
    not_errors = set([NO_ERROR, USER_EXIT])
    if (err_code in not_errors):
        return err_code
    warning = False
    if (err_code in warnings):
        warning = True
    err_string = error_messages[err_code]
    # no formatting needed: nothing queued in error_info, template used as-is
    if (error_info == []):
        err_string = "ERROR FOUND: {0}".format(err_string)
        err_summary.append(err_string)
    # needs input: drain error_info, formatting the template with each tuple
    else:
        while (len(error_info) > 0):
            tup = error_info.pop(0)
            temp_err_string = err_string.format(*tup)
            if (warning):
                final_err_string = "WARNING FOUND: {0}".format(temp_err_string)
            else:
                final_err_string = "ERROR FOUND: {0}".format(temp_err_string)
            err_summary.append(final_err_string)
    if (warning):
        print("WARNING(S) FOUND.")
        return NO_ERROR
    else:
        print("ERROR(S) FOUND.")
        return ERR_FOUND
def check_err_file():
    """ output mdsd.err contents if the file is not empty

    Inspects the last 50 lines of mdsd.err, ignoring blank lines and
    daemon start/exit entries. Returns WARN_MDSD_ERR_FILE when real error
    lines are present, NO_ERROR otherwise (including when the file is
    absent, which just means mdsd has logged no errors).
    """
    from collections import deque  # stdlib; efficient tail of a file

    tail_size = 50
    pattern = ' [DAEMON] '
    # Previously a missing mdsd.err crashed the whole health check.
    if not os.path.isfile(ERR_FILE_PATH):
        return NO_ERROR
    err_logs = []
    with open(ERR_FILE_PATH) as f:
        # BUGFIX: readlines(10000) returned the FIRST ~10KB of the file, so
        # on a large log the "tail" was actually the oldest lines. deque with
        # maxlen keeps exactly the last 50 lines regardless of file size.
        lines = deque(f, maxlen=tail_size)
    for line in lines:
        line = line.rstrip('\n')
        # skip empty lines, daemon start/exit logs
        if line == '' or pattern in line:
            continue
        err_logs.append(line)
    if len(err_logs) > 0:
        err_logs_str = '\n' + ('\n'.join(err_logs))
        error_info.append((ERR_FILE_PATH, err_logs_str))
        return WARN_MDSD_ERR_FILE
    return NO_ERROR
def get_input(question, check_ans=None, no_fit=None):
    """Prompt the user for input, optionally validating the answer.

    When check_ans/no_fit are omitted the raw answer is returned as-is;
    otherwise the prompt repeats until check_ans(answer.lower()) is truthy,
    printing no_fit as guidance after each rejected answer.
    """
    if check_ans == None and no_fit == None:
        return input(question)
    answer = input(" {0}: ".format(question))
    while (not check_ans(answer.lower())):
        print("Unclear input. {0}".format(no_fit))
        answer = input(" {0}: ".format(question))
    return answer

def is_arc_installed():
    """
    Check if this is an Arc machine
    """
    # Using systemctl to check this since Arc only supports VMs that have systemd
    check_arc = os.system('systemctl status himdsd 1>/dev/null 2>&1')
    return check_arc == 0

def find_vm_bits():
    """Return the CPU op-mode suffix (e.g. '64-bit') reported by lscpu."""
    cpu_info = subprocess.check_output(['lscpu'], universal_newlines=True)
    # assumes the second lscpu line is "CPU op-mode(s): 32-bit, 64-bit"
    # and takes its last 6 chars -- TODO confirm on non-English locales
    cpu_opmodes = (cpu_info.split('\n'))[1]
    cpu_bits = cpu_opmodes[-6:]
    return cpu_bits

def find_vm_distro():
    """
    Finds the Linux Distribution this vm is running on.

    Returns (dist, version, None) on success or (None, None, error) when the
    OS cannot be determined.
    """
    vm_dist = vm_id = vm_ver = None
    parse_manually = False
    try:
        vm_dist, vm_ver, vm_id = platform.linux_distribution()
    except AttributeError:
        try:
            vm_dist, vm_ver, vm_id = platform.dist()
        except AttributeError:
            # Falling back to /etc/os-release distribution parsing
            pass
    # Some python versions *IF BUILT LOCALLY* (ex 3.5) give string responses
    # (ex. 'bullseye/sid') to platform.dist(). That breaks the int() parse
    # below, so switch to manual /etc/os-release parsing in that case.
    try:
        temp_vm_ver = int(vm_ver.split('.')[0])
    except:
        parse_manually = True
    if (not vm_dist and not vm_ver) or parse_manually:
        # SLES 15 and others
        try:
            with open('/etc/os-release', 'r') as fp:
                for line in fp:
                    if line.startswith('ID='):
                        vm_dist = line.split('=')[1]
                        vm_dist = vm_dist.split('-')[0]
                        vm_dist = vm_dist.replace('\"', '').replace('\n', '')
                        vm_dist = vm_dist.lower()
                    elif line.startswith('VERSION_ID='):
                        vm_ver = line.split('=')[1]
                        vm_ver = vm_ver.replace('\"', '').replace('\n', '')
                        vm_ver = vm_ver.lower()
        except (FileNotFoundError, AttributeError) as e:
            # indeterminate OS
            return (None, None, e)
    return (vm_dist, vm_ver, None)

def find_package_manager():
    """
    Checks which package manager is on the system and caches it in
    general_info['PKG_MANAGER']. Returns 'dpkg', 'rpm', or '' if neither
    is found.
    """
    global general_info
    pkg_manager = ""
    # check if debian system
    if (os.path.isfile("/etc/debian_version")):
        try:
            subprocess.check_output("command -v dpkg", shell=True)
            pkg_manager = "dpkg"
        except subprocess.CalledProcessError:
            pass
    # check if redhat system
    # NOTE(review): the conventional marker file is /etc/redhat-release, not
    # /etc/redhat_version -- the generic fallback below still finds rpm, so
    # behavior is preserved; confirm before changing.
    elif (os.path.isfile("/etc/redhat_version")):
        try:
            subprocess.check_output("command -v rpm", shell=True)
            pkg_manager = "rpm"
        except subprocess.CalledProcessError:
            pass
    # likely SUSE or modified VM, just check dpkg and rpm
    if (pkg_manager == ""):
        try:
            subprocess.check_output("command -v dpkg", shell=True)
            pkg_manager = "dpkg"
        except subprocess.CalledProcessError:
            try:
                subprocess.check_output("command -v rpm", shell=True)
                pkg_manager = "rpm"
            except subprocess.CalledProcessError:
                pass
    general_info['PKG_MANAGER'] = pkg_manager
    return pkg_manager

def get_package_version(pkg):
    """Return (version, error) for pkg using the cached package manager."""
    pkg_mngr = geninfo_lookup('PKG_MANAGER')
    # dpkg
    if (pkg_mngr == 'dpkg'):
        return get_dpkg_pkg_version(pkg)
    # rpm
    elif (pkg_mngr == 'rpm'):
        return get_rpm_pkg_version(pkg)
    else:
        return (None, None)

# Package Info

def get_dpkg_pkg_version(pkg):
    """Return (version, error) for pkg from 'dpkg -s'; (None, None) if absent."""
    try:
        dpkg_info = subprocess.check_output(['dpkg', '-s', pkg], universal_newlines=True,\
                        stderr=subprocess.STDOUT)
        dpkg_lines = dpkg_info.split('\n')
        for line in dpkg_lines:
            if (line.startswith('Package: ') and not line.endswith(pkg)):
                # wrong package
                return (None, None)
            if (line.startswith('Status: ') and not line.endswith('installed')):
                # not properly installed
                return (None, None)
            if (line.startswith('Version: ')):
                version = (line.split())[-1]
                return (version, None)
        return (None, None)
    except subprocess.CalledProcessError as e:
        return (None, e.output)

def get_rpm_pkg_version(pkg):
    """Return (version, error) for pkg from 'rpm -qi'; (None, None) if absent."""
    try:
        rpm_info = subprocess.check_output(['rpm', '-qi', pkg], universal_newlines=True,\
                        stderr=subprocess.STDOUT)
        if ("package {0} is not installed".format(pkg) in rpm_info):
            # didn't find package
            return (None, None)
        rpm_lines = rpm_info.split('\n')
        for line in rpm_lines:
            parsed_line = line.split()
            # skip blank or short lines (e.g. the Description body) which
            # previously raised IndexError on parsed_line[0]/[2]
            if (len(parsed_line) < 3):
                continue
            if (parsed_line[0] == 'Name'):
                # ['Name', ':', name]
                name = parsed_line[2]
                if (name != pkg):
                    # wrong package
                    return (None, None)
            if (parsed_line[0] == 'Version'):
                # ['Version', ':', version]
                version = parsed_line[2]
                return (version, None)
        return (None, None)
    except subprocess.CalledProcessError as e:
        return (None, e.output)

def find_ama_version():
    """
    Gets a list of all AMA versions installed on the VM.
    Returns (versions, None) or (None, error) if waagent's dir is missing.
    """
    try:
        config_dirs = filter((lambda x : x.startswith("Microsoft.Azure.Monitor.AzureMonitorLinuxAgent-")), os.listdir("/var/lib/waagent"))
        ama_vers = list(map((lambda x : (x.split('-'))[-1]), config_dirs))
    except FileNotFoundError as e:
        return (None, e)
    return (ama_vers, None)

def check_ama_installed(ama_vers):
    """
    Checks to verify AMA is installed and only has one version installed at
    a time. Returns (ama_exists, ama_unique).
    """
    ama_exists = ((ama_vers != None) and (len(ama_vers) > 0))
    ama_unique = (ama_exists and (len(ama_vers) == 1))
    return (ama_exists, ama_unique)

def run_cmd_output(cmd):
    """
    Common logic to run any command and check/get its output for further use.
    On failure returns the captured output of the failing command.
    """
    try:
        out = subprocess.check_output(cmd, shell=True, universal_newlines=True, stderr=subprocess.STDOUT)
        return out
    except subprocess.CalledProcessError as e:
        return (e.output)
def find_dcr_workspace():
    """
    Parse DCR configuration files to find workspace IDs and regions.

    Returns (workspace_ids, regions, agent_settings, error); the first three
    are only meaningful when error is None. Results are cached in
    general_info so repeated calls skip the directory scan.
    """
    global general_info
    # cache hit: return the same 4-tuple shape as a fresh computation
    # (previously returned a 3-tuple, breaking 4-way unpacking by callers)
    if 'DCR_WORKSPACE_ID' in general_info and 'DCR_REGION' in general_info:
        return (general_info['DCR_WORKSPACE_ID'], general_info['DCR_REGION'],
                general_info.get('AGENT_SETTINGS', {}), None)
    dcr_workspace = set()
    dcr_region = set()
    me_region = set()
    agent_settings = {}
    general_info['URL_SUFFIX'] = '.com'
    try:
        for file in os.listdir(CONFIG_DIR):
            file_path = CONFIG_DIR + "/" + file
            with open(file_path) as f:
                result = json.load(f)
            # AgentSettings DCRs carry settings but no channels: parse the
            # settings, then skip the channel scan (mirrors find_dce)
            if 'kind' in result and result['kind'] == 'AgentSettings' and 'channels' not in result:
                if 'settings' in result:
                    settings_str = result['settings']
                    try:
                        # The settings field is a JSON string, so parse it
                        if isinstance(settings_str, str):
                            settings_list = json.loads(settings_str)
                        else:
                            settings_list = settings_str
                        # Process each setting
                        for setting in settings_list:
                            name = setting['name']
                            value = setting['value']
                            if name:
                                agent_settings[name] = value
                    except (json.JSONDecodeError, TypeError) as e:
                        # If parsing fails, skip this AgentSettings DCR
                        print("Error parsing settings key in AgentSettings DCR")
                        continue
                # no channels in an AgentSettings DCR -- without this skip the
                # result['channels'] lookup below aborts the whole scan
                continue
            channels = result['channels']
            for channel in channels:
                if channel['protocol'] == 'ods':
                    # parse dcr workspace id (host part of the ODS endpoint)
                    endpoint_url = channel['endpoint']
                    workspace_id = endpoint_url.split('https://')[1].split('.ods')[0]
                    dcr_workspace.add(workspace_id)
                    # parse dcr region (query string of the token endpoint)
                    token_endpoint_uri = channel['tokenEndpointUri']
                    region = token_endpoint_uri.split('Location=')[1].split('&')[0]
                    dcr_region.add(region)
                    # parse url suffix (sovereign clouds)
                    if '.us' in endpoint_url:
                        general_info['URL_SUFFIX'] = '.us'
                    if '.cn' in endpoint_url:
                        general_info['URL_SUFFIX'] = '.cn'
                if channel['protocol'] == 'me':
                    # parse ME region
                    endpoint_url = channel['endpoint']
                    region = endpoint_url.split('https://')[1].split('.monitoring')[0]
                    me_region.add(region)
    except Exception as e:
        return (None, None, None, e)
    general_info['DCR_WORKSPACE_ID'] = dcr_workspace
    general_info['DCR_REGION'] = dcr_region
    general_info['ME_REGION'] = me_region
    general_info['AGENT_SETTINGS'] = agent_settings
    return (dcr_workspace, dcr_region, agent_settings, None)

def find_dce():
    """
    Parse DCR configuration files to find Data Collection Endpoints (DCE).

    Returns (endpoints, None) on success or (None, error) on failure; the
    endpoint set is also cached in general_info['DCE'].
    """
    global general_info
    dce = set()
    try:
        for file in os.listdir(CONFIG_DIR):
            file_path = CONFIG_DIR + "/" + file
            with open(file_path) as f:
                result = json.load(f)
            # Check if this is an AgentSettings DCR, if so skip it
            if 'kind' in result and result['kind'] == 'AgentSettings' and 'channels' not in result:
                continue
            channels = result['channels']
            for channel in channels:
                if channel['protocol'] == 'gig':
                    # parse dce logs ingestion endpoint
                    ingest_endpoint_url = channel['endpointUriTemplate']
                    ingest_endpoint = ingest_endpoint_url.split('https://')[1].split('/')[0]
                    dce.add(ingest_endpoint)
                    # parse dce configuration access endpoint
                    configuration_endpoint_url = channel['tokenEndpointUri']
                    configuration_endpoint = configuration_endpoint_url.split('https://')[1].split('/')[0]
                    dce.add(configuration_endpoint)
    except Exception as e:
        # return a 2-tuple to match the success path (was a 3-tuple)
        return (None, e)
    general_info['DCE'] = dce
    return (dce, None)

def is_metrics_configured():
    """Return True when metric counters are configured; cached in general_info."""
    global general_info
    if 'metrics' in general_info:
        return general_info['metrics']
    # only the first two bytes are needed: an unconfigured file is exactly '[]'
    with open(METRICS_FILE) as f:
        output = f.read(2)
    if output != '[]':
        general_info['metrics'] = True
    else:
        general_info['metrics'] = False
    return general_info['metrics']
hr_size[:-1] hr_units = hr_size[-1] if (hr_digits.isdigit()): # kilobytes if (hr_units == 'k'): return int(hr_digits) * 1000 # megabytes elif (hr_units == 'M'): return int(hr_digits) * 1000000 # gigabytes elif (hr_units == 'G'): return int(hr_digits) * 1000000000 # wrong formatting return None def check_size_config(logrotate_configs): for k in list(logrotate_configs.keys()): # grab size limit if exists size_config = next((x for x in logrotate_configs[k] if x.startswith('size ')), None) if (size_config == None): continue size_limit = hr2bytes(size_config.split()[1]) if (size_limit == None): error_info.append((k, LR_CONFIG_PATH)) return ERR_LOGROTATE_SIZE # get current size of file try: size_curr = os.path.getsize(k) if (size_curr > size_limit): error_info.append((k, size_curr, size_limit, LR_CONFIG_PATH)) return WARN_LOGROTATE # couldn't get current size of file except os.error as e: if (e.errno == errno.EACCES): error_info.append((k,)) return ERR_SUDO_PERMS elif (e.errno == errno.ENOENT): if ('missingok' in logrotate_configs[k]): continue else: error_info.append(('log file', k)) return ERR_FILE_MISSING else: error_info.append((k, e.strerror)) return ERR_FILE_ACCESS return NO_ERROR def check_log_rotation(): # check logrotate config file exists if (not os.path.isfile(LR_CONFIG_PATH)): error_info.append(('logrotate config file', LR_CONFIG_PATH)) return ERR_FILE_MISSING # go through logrotate config file logrotate_configs = dict() with open(LR_CONFIG_PATH, 'r') as f: lr_lines = f.readlines() in_file = None for lr_line in lr_lines: lr_line = lr_line.rstrip('\n') # start of log rotation config lr_start = re.match("^/(\S+)", lr_line) if (lr_start != None): in_file = lr_start.group() logrotate_configs[in_file] = set() continue # log rotation config info elif (in_file != None): logrotate_configs[in_file].add(lr_line.lstrip()) continue # end of log rotation config elif (lr_line == '}'): in_file = None continue # check size rotation working checked_size_config = 
check_size_config(logrotate_configs) if (checked_size_config != NO_ERROR): return checked_size_config return NO_ERROR ================================================ FILE: AzureMonitorAgent/ama_tst/modules/high_cpu_mem/check_usage.py ================================================ import time import subprocess from error_codes import * from errors import error_info from helpers import get_input, run_cmd_output def find_mdsd_pid(): try: status = run_cmd_output('systemctl status azuremonitoragent') status_lines = status.split('\n') for line in status_lines: line = line.strip() if line.startswith('Main PID:'): pid = line.split()[2] return (pid, None) except subprocess.CalledProcessError as e: return (None, e) def check_usage(interactive): if interactive: print("Checking CPU/memory usage of AMA subcomponents...") result = get_input("Do you want to monitor the CPU/memory usage of AMA in 5 minutes? (YES/no)", \ (lambda x : x.lower() in ['y','yes','n','no', '']),\ "Please enter 'y'/'yes' to run this check, 'n'/'no' to skip this check. 
\n") if result.lower() in ['n', 'no']: return NO_ERROR mdsd_pid, e = find_mdsd_pid() if e != None: error_info.append((e,)) return ERR_CHECK_STATUS cmd = "top -b -n1 | grep {0}".format(mdsd_pid) cpu = [] mem = [] # run 5 minutes to collect min/max/avg usage for i in range(0, 30): output = run_cmd_output(cmd) values = list(filter(None, output.strip().split(" "))) cpu.append(float(values[8])) mem.append(float(values[9])) time.sleep(10) max_cpu = max(cpu) min_cpu = min(cpu) avg_cpu = sum(cpu)/len(cpu) max_mem = max(mem) min_mem = min(mem) avg_mem = sum(mem)/len(mem) print("--------------------------------------------------------------------------------") print("CPU usage in the last 5 minutes (%CPU)") print("Max: ", max_cpu, "Min: ", min_cpu, "Avg: ", "%.1f" % avg_cpu) print("Memory usage in the last 5 minutes (%MEM)") print("Max: ", max_mem, "Min: ", min_mem, "Avg: ", "%.1f" % avg_mem) return NO_ERROR ================================================ FILE: AzureMonitorAgent/ama_tst/modules/high_cpu_mem/high_cpu_mem.py ================================================ from error_codes import * from errors import is_error, print_errors from .check_logrot import check_log_rotation from .check_usage import check_usage def check_high_cpu_memory(interactive, prev_success=NO_ERROR): print("CHECKING FOR HIGH CPU / MEMORY USAGE...") success = prev_success # check log rotation print("Checking if log rotation is working correctly...") checked_logrot = check_log_rotation() if (is_error(checked_logrot)): return print_errors(checked_logrot) else: success = print_errors(checked_logrot) # check AMA CPU/memory usage checked_usage = check_usage(interactive) if (is_error(checked_usage)): return print_errors(checked_usage) else: success = print_errors(checked_usage) return success ================================================ FILE: AzureMonitorAgent/ama_tst/modules/install/__init__.py ================================================ # Install check helper script for AMA 
import re
import sys
import socket
import xml.dom.minidom

if sys.version_info[0] == 3:
    import urllib.request as urllib
    import urllib.error as urlerror
elif sys.version_info[0] == 2:
    import urllib2 as urllib
    import urllib2 as urlerror

try:
    import requests
except ImportError:
    pass

AMA_URL = 'https://docs.microsoft.com/en-us/azure/azure-monitor/agents/azure-monitor-agent-extension-versions'
# Timeout for fetching latest AMA version (in seconds)
AMA_VERSION_FETCH_TIMEOUT = 60

def get_latest_ama_version(curr_version):
    """Scrape the AMA docs page for the newest Linux agent version.

    Returns (latest, None) when a version newer than curr_version exists,
    (None, None) when curr_version is already current, and (None, err_msg)
    on any fetch/parse failure. Python 2 and 3 compatible.
    """
    # Set timeout to prevent hanging
    timeout = AMA_VERSION_FETCH_TIMEOUT
    try:
        if sys.version_info[0] == 3:
            # Python 3 - try urllib first, then requests as fallback
            try:
                r = urllib.urlopen(AMA_URL, timeout=timeout).read()
            except AttributeError:
                # If urllib doesn't work, try requests
                r = requests.get(AMA_URL, timeout=timeout).text
        else:
            # Python 2 - use urllib2 which supports timeout
            r = urllib.urlopen(AMA_URL, timeout=timeout).read()
    except socket.timeout:
        return None, "Connection timed out after {0} seconds while trying to fetch latest AMA version from {1}. Please check your network connectivity and firewall settings.".format(timeout, AMA_URL)
    except Exception as e:
        # More specific timeout detection
        error_str = str(e).lower()
        error_type = type(e).__name__
        # Check for various timeout conditions
        if (error_type == 'timeout' or 'timeout' in error_str or
                'timed out' in error_str or 'read timeout' in error_str or
                'connect timeout' in error_str):
            return None, "Request timed out after {0} seconds while trying to fetch latest AMA version from {1}. This may be due to network connectivity issues or firewall restrictions.".format(timeout, AMA_URL)
        # Handle HTTP and URL errors
        if hasattr(e, 'code'):
            return None, "HTTP error {0} while trying to fetch latest AMA version from {1}. The documentation server may be temporarily unavailable.".format(e.code, AMA_URL)
        elif 'urlerror' in error_type.lower() or 'httperror' in error_type.lower():
            return None, "Network error while trying to fetch latest AMA version from {1}: {0}".format(str(e), AMA_URL)
        elif 'name or service not known' in error_str:
            return None, "DNS resolution failed for {1}. Please check the URL and your network settings: {0}".format(str(e), AMA_URL)
        elif 'connection refused' in error_str:
            return None, "Connection refused while trying to connect to {1}. The server may be down: {0}".format(str(e), AMA_URL)
        elif 'network is unreachable' in error_str:
            return None, "Network is unreachable while trying to connect to {1}. Please check your network configuration: {0}".format(str(e), AMA_URL)
        else:
            return None, "Unexpected error while trying to fetch latest AMA version from {1}: {0}".format(str(e), AMA_URL)
    try:
        # Ensure we have a string for both Python 2 and 3 compatibility
        if sys.version_info[0] == 3 and isinstance(r, bytes):
            # Python 3: convert bytes to string
            r = r.decode('utf-8')
        # NOTE(review): the HTML tag patterns below were reconstructed (the
        # literal tags were stripped in transit); they extract the <tbody>,
        # its <tr> rows and <td> cells -- confirm against the live page markup.
        tbody_pattern = r'<tbody[^>]*>(.*?)</tbody>'
        tbody_match = re.search(tbody_pattern, r, re.DOTALL)
        if not tbody_match:
            return None, "Could not find version table in Microsoft documentation"
        tbody_content = tbody_match.group(1)
        # Find all table rows
        row_pattern = r'<tr[^>]*>(.*?)</tr>'
        rows = re.findall(row_pattern, tbody_content, re.DOTALL)
        latest_version = None
        # Rows are in chronological order (newest first): use the first row
        # whose 4th (Linux) column is non-empty
        for row in rows:
            # Extract all cells from this row
            cell_pattern = r'<td[^>]*>(.*?)</td>'
            cells = re.findall(cell_pattern, row, re.DOTALL)
            # Check if we have at least 4 columns and the 4th column (Linux) is not empty
            if len(cells) >= 4:
                linux_cell = cells[3]  # 4th column (index 3)
                # Replace <br> tags with spaces to avoid concatenation
                clean_content = re.sub(r'<br[^>]*>', ' ', linux_cell)
                # Remove all other HTML tags (including superscript)
                clean_content = re.sub(r'<[^>]+>', '', clean_content)
                # Normalize whitespace
                clean_content = re.sub(r'\s+', ' ', clean_content).strip()
                # Skip empty cells
                if not clean_content or clean_content.lower() in ['', 'none', 'n/a']:
                    continue  # Go to next row
                # Handle version ranges like "1.26.2-1.26.5": replace the
                # hyphen with a comma for easier parsing
                clean_content = re.sub(r'(\d+\.\d+\.\d+(?:\.\d+)?)\s*-\s*(\d+\.\d+\.\d+(?:\.\d+)?)', r'\1, \2', clean_content)
                # Find all version numbers in this cell (handles multiple versions)
                version_matches = re.findall(r'(\d+\.\d+\.\d+(?:\.\d+)?)', clean_content)
                if version_matches:
                    # If multiple versions found, take the highest one from this cell
                    cell_latest = None
                    for version in version_matches:
                        if cell_latest is None or not comp_versions_ge(cell_latest, version):
                            cell_latest = version
                    # First non-empty row found: this is the latest version
                    latest_version = cell_latest
                    break
        if not latest_version:
            return None, "No version numbers found in Linux columns of Microsoft documentation"
        # Compare with current version
        if comp_versions_ge(curr_version, latest_version):
            return None, None  # Current version is up to date
        else:
            return latest_version, None  # New version available
    except Exception as e:
        return None, "Error parsing version information from Microsoft documentation: {0}".format(str(e))
    return None, None

def comp_versions_ge(version1, version2):
    """
    compare two versions, see if the first is newer than / the same as the
    second; missing components are treated as 0 (so '1.2' == '1.2.0')
    """
    versions1 = [int(v) for v in version1.split(".")]
    versions2 = [int(v) for v in version2.split(".")]
    for i in range(max(len(versions1), len(versions2))):
        v1 = versions1[i] if i < len(versions1) else 0
        v2 = versions2[i] if i < len(versions2) else 0
        if v1 > v2:
            return True
        elif v1 < v2:
            return False
    return True

def ask_update_old_version(ama_version, curr_ama_version):
    """Offer the user a chance to stop the run and upgrade to the newer AMA."""
    print("--------------------------------------------------------------------------------")
    print("You are currently running AMA Version {0}. There is a newer version\n"\
          "available which may fix your issue (version {1}).".format(ama_version, curr_ama_version))
    answer = get_input("Do you want to update? (y/n)", (lambda x : x.lower() in ['y','yes','n','no']),\
                       "Please type either 'y'/'yes' or 'n'/'no' to proceed.")
    # user does want to update
    if (answer.lower() in ['y', 'yes']):
        print("--------------------------------------------------------------------------------")
        print("Please follow the instructions given here:")
        print("\n https://docs.microsoft.com/en-us/azure/azure-monitor/agents/azure-monitor-agent-manage\n")
        return USER_EXIT
    # user doesn't want to update
    elif (answer.lower() in ['n', 'no']):
        print("Continuing on with troubleshooter...")
        print("--------------------------------------------------------------------------------")
        return NO_ERROR

def check_ama(interactive):
    """Verify AMA is installed and new enough, then check for a newer release."""
    (ama_version, e) = get_package_version('azuremonitoragent')
    if e is not None:
        error_info.append((e,))
        return ERR_AMA_INSTALL
    ama_version = ama_version.split('-')[0]
    # versions below 1.21.0 are out of support
    if not comp_versions_ge(ama_version, '1.21.0'):
        error_info.append((ama_version,))
        return ERR_OLD_AMA_VER
    print("Current AMA version: {0}".format(ama_version))
    (newer_ama_version, e) = get_latest_ama_version(ama_version)
    if newer_ama_version is None:
        if e is None:
            # No error and no newer version found - current version is up to date
            print("AMA version is up to date (latest version)")
            return NO_ERROR
        else:
            # There was an error fetching the latest version
            print("Unable to determine latest AMA version")
            print("Error: {0}".format(e))
            # Add error details to error_info for reporting
            error_info.append((e,))
            # Check if we have general internet connectivity
            checked_internet = check_internet_connect()
            if checked_internet != NO_ERROR:
                # No internet connectivity - this is a broader issue
                print("Internet connectivity test also failed. Skipping version check...")
                print("This may indicate broader network connectivity issues.")
                print("--------------------------------------------------------------------------------")
                return ERR_GETTING_AMA_VER  # Return error code for version check failure
            else:
                # Internet works but the docs fetch failed - likely specific to the documentation site
                print("Internet connectivity is working, but unable to access AMA documentation.")
                print("This could be due to firewall restrictions or temporary server issues.")
                print("The troubleshooter will continue, but version information may be outdated.")
                print("--------------------------------------------------------------------------------")
                return ERR_GETTING_AMA_VER  # Return error code for version check failure
    else:
        # Found a newer version available
        print("Update available: {0} -> {1}".format(ama_version, newer_ama_version))
        if interactive:
            if ask_update_old_version(ama_version, newer_ama_version) == USER_EXIT:
                return USER_EXIT
        return NO_ERROR
def format_alternate_versions(supported_dist, versions):
    """Format a list of supported versions as 'a, b or c' for warning output.

    Fixed to not mutate the caller's list: the original popped from
    `versions`, which corrupted the shared supported_dists tables.
    """
    last = versions[-1]
    rest = versions[:-1]
    if (rest == []):
        s = "{0}".format(last)
    else:
        s = "{0} or {1}".format(', '.join(rest), last)
    return s

def check_vm_supported(vm_dist, vm_ver):
    """Check the detected distro/version pair against the supported tables."""
    if platform.machine() == 'aarch64':
        supported_dists = supported_distros.supported_dists_aarch64
    else:
        supported_dists = supported_distros.supported_dists_x86_64
    vm_supported = False
    # find VM distribution in supported list
    vm_supported_dist = None
    for supported_dist in (supported_dists.keys()):
        if (not vm_dist.lower().startswith(supported_dist)):
            continue
        vm_supported_dist = supported_dist
        # check if version is supported
        vm_ver_split = vm_ver.split('.')
        for supported_ver in (supported_dists[supported_dist]):
            supported_ver_split = supported_ver.split('.')
            vm_ver_match = True
            # try matching VM version with supported version
            for (idx, supported_ver_num) in enumerate(supported_ver_split):
                try:
                    supported_ver_num = int(supported_ver_num)
                    vm_ver_num = int(vm_ver_split[idx])
                    # value comparison, not identity: 'is not' on ints only
                    # worked by accident via CPython small-int caching
                    if (vm_ver_num != supported_ver_num):
                        vm_ver_match = False
                        break
                except (IndexError, ValueError) as e:
                    vm_ver_match = False
                    break
            # check if successful in matching
            if (vm_ver_match):
                vm_supported = True
                break
        # check if any version successful in matching
        if (vm_supported):
            return NO_ERROR
    # VM distribution is supported, but not current version
    if (vm_supported_dist != None):
        versions = supported_dists[vm_supported_dist]
        alt_vers = format_alternate_versions(vm_supported_dist, versions)
        error_info.append((vm_dist, vm_ver, alt_vers))
        return ERR_OS_VER
    # VM distribution isn't supported
    else:
        error_info.append((vm_dist,))
        return ERR_OS

def check_os():
    """Confirm the VM's architecture and OS version are supported by AMA."""
    if platform.machine() == 'x86_64':
        cpu_bits = find_vm_bits()
        if (not cpu_bits == '64-bit'):
            return ERR_BITS
    # get OS version
    (vm_dist, vm_ver, e) = find_vm_distro()
    if (vm_dist == None or vm_ver == None):
        error_info.append((e,))
        return ERR_FINDING_OS
    # check if OS version is supported
    return check_vm_supported(vm_dist, vm_ver)

# (sic) historical misspelling of METRICS_FILE kept; renaming would be an
# interface change for any other importer of this module
METRICS_FIILE = "/etc/opt/microsoft/azuremonitoragent/config-cache/metricCounters.json"

def check_packages():
    """Verify the AMA package and its subcomponent binaries are installed."""
    # check azuremonitoragent rpm/dpkg
    (ama_vers, e) = find_ama_version()
    if (ama_vers == None):
        error_info.append((e,))
        return ERR_AMA_INSTALL
    if (len(ama_vers) > 1):
        return ERR_MULTIPLE_AMA
    # find subcomponent binaries
    subcomponents = ['mdsd', 'agentlauncher', 'amacoreagent', 'fluent-bit']
    if not os.path.isfile(METRICS_FIILE):
        return ERR_COUNTER_FILE_MISSING
    # metrics subcomponents are only required when metrics are configured
    if is_metrics_configured():
        subcomponents.append('MetricsExtension')
        subcomponents.append('telegraf')
    missed_subcomponent = []
    for subcomponent in subcomponents:
        bin_file = '/opt/microsoft/azuremonitoragent/bin/{0}'.format(subcomponent)
        if (not os.path.isfile(bin_file)):
            missed_subcomponent.append(subcomponent)
    if len(missed_subcomponent) > 0:
        error_info.append((', '.join(missed_subcomponent),))
        return ERR_SUBCOMPONENT_INSTALL
    return NO_ERROR

def check_syslog():
    """Confirm a syslog daemon package (rsyslog or syslog-ng) is installed."""
    for pkg in ('rsyslog', 'syslog-ng', 'syslog-ng-core'):
        pkg_version, e = get_package_version(pkg)
        if (pkg_version != None):
            return NO_ERROR
    return ERR_LOG_DAEMON
def check_space():
    """
    check space in MB for each main directory
    """
    # any of the agent's main directories dropping below 500 MB free is an error
    for mount_dir in ("/etc", "/opt", "/var"):
        stats = os.statvfs(mount_dir)
        free_space = stats.f_bavail * stats.f_frsize / 1024 / 1024
        if free_space < 500:
            error_info.append((mount_dir, free_space))
            return ERR_FREE_SPACE
    return NO_ERROR

def check_pkg_manager():
    """Verify a supported package manager (dpkg or rpm) was detected."""
    detected = find_package_manager()
    return ERR_PKG_MANAGER if detected == "" else NO_ERROR

def check_syslog_user():
    """Confirm the 'syslog' user exists in /etc/passwd."""
    with open('/etc/passwd', 'r') as passwd_file:
        if any(entry.startswith('syslog:') for entry in passwd_file):
            return NO_ERROR
    return ERR_SYSLOG_USER

def check_installation(interactive, err_codes=True, prev_success=NO_ERROR):
    """
    check all packages are installed
    """
    print("CHECKING INSTALLATION...")
    success = prev_success
    # run each check in order; a hard error aborts immediately with its result,
    # otherwise the (possibly warning) result becomes the running success value
    checks = (
        ("Checking if running a supported OS version...", lambda: check_os()),
        ("Checking if enough disk space is available...", lambda: check_space()),
        ("Checking if machine has a supported package manager...", lambda: check_pkg_manager()),
        ("Checking if packages and subcomponents are installed correctly...", lambda: check_packages()),
        ("Checking if running a supported version of AMA...", lambda: check_ama(interactive)),
        ("Checking if rsyslog or syslog-ng exists...", lambda: check_syslog()),
        ("Checking if syslog user exists...", lambda: check_syslog_user()),
    )
    for message, run_check in checks:
        print(message)
        outcome = run_check()
        if is_error(outcome):
            return print_errors(outcome)
        success = print_errors(outcome)
    print("============================================")
    return success

# Distro/version support tables keyed by lowercase distro ID prefix.
supported_dists_x86_64 = {'redhat' : ['7', '8', '9', '10'], # Rhel
                          'centos' : ['7', '8'], # CentOS
                          'oracle' : ['7', '8', '9'], # Oracle
                          'ol' : ['7', '8', '9'], # Oracle Linux
                          'debian' : ['9', '10', '11', '12', '13'], # Debian
                          'ubuntu' : ['16.04', '18.04', '20.04', '22.04', '24.04'], # Ubuntu
                          'suse' : ['12', '15', '16'],
                          'sles' : ['12', '15', '16'], # SLES
                          'mariner' : ['2'], # Mariner
                          'azurelinux' : ['3'], # Azure Linux / Mariner 3
                          'rocky' : ['8', '9'], # Rocky
                          'alma' : ['8', '9'], # Alma
                          'opensuse' : ['15'], # openSUSE
                          'amzn' : ['2', '2023'] # Amazon Linux 2
}

supported_dists_aarch64 = {'redhat' : ['8', '9', '10'], # Rhel
                           'ubuntu' : ['18.04', '20.04', '22.04', '24.04'], # Ubuntu
                           'alma' : ['8'], # Alma
                           'centos' : ['7'], # CentOS
                           'mariner' : ['2'], # Mariner 2
                           'azurelinux' : ['3'], # Azure Linux / Mariner 3
                           'sles' : ['15', '16'], # SLES
                           'debian' : ['11', '12', '13'], # Debian
                           'rocky linux' : ['8', '9'], # Rocky
                           'rocky' : ['8', '9'] # Rocky
}
connect.check_imds import check_metadata from metrics_troubleshooter.metrics_troubleshooter import run_metrics_troubleshooter DPKG_CMD = "dpkg -s azuremonitoragent" RPM_CMD = "rpm -qi azuremonitoragent" PS_CMD = "ps -ef | grep {0} | grep -v grep" OPENSSL_CMD = "echo | openssl s_client -connect {0}:443 -brief" SYSTEMCTL_CMD = "systemctl status {0} --no-pager" JOURNALCTL_CMD = "journalctl -u {0} --no-pager --since \"30 days ago\" > {1}" PS_CMD_CPU = "ps aux --sort=-pcpu | head -10" PS_CMD_RSS = "ps aux --sort -rss | head -10" PS_CMD_VSZ = "ps aux --sort -vsz | head -10" DU_CMD = "du -h -d 1 {0} /var/opt/microsoft/azuremonitoragent/events" VAR_DU_CMD = "du -h -d 1 {0} /var" LS_CMD = "ls -al {0}" NAMEI_CMD = "namei -om {0}" TAIL_SYSLOG_CMD = "tail -10000 /var/log/{0} > {1}" ArcSettingsFile = '/var/opt/azcmagent/localconfig.json' PERMISSION_CHECK_FILES = ["/etc/opt/microsoft/azuremonitoragent/config-cache", "/etc/opt/microsoft/azuremonitoragent", "/var/opt/microsoft/azuremonitoragent", "/var/run/azuremonitoragent", "/opt/microsoft/azuremonitoragent", "/run/azuremonitoragent", "/var/lib/waagent/Microsoft.Azure.Monitor.AzureMonitorLinuxAgent-*"] # File copying functions def copy_file(src, dst): if (os.path.isfile(src)): print("Copying file {0}".format(src)) try: if (not os.path.isdir(dst)): os.mkdir(dst) shutil.copy2(src, dst) except Exception as e: print("ERROR: Could not copy {0}: {1}".format(src, e)) print("Skipping over file {0}".format(src)) else: print("File {0} doesn't exist, skipping".format(src)) return def copy_dircontents(src, dst): if (os.path.isdir(src)): print("Copying contents of directory {0}".format(src)) try: shutil.copytree(src, dst) auth_token_path = os.path.join(dst, "metrics_configs", "AuthToken-MSI.json") if (os.path.isfile(auth_token_path)): print("Found AuthToken-MSI.json") try: with open(auth_token_path, 'r') as auth_token: auth_token_json = json.load(auth_token) if (auth_token_json and "access_token" in auth_token_json): print("Removing 
access_token value from AuthToken-MSI.json") auth_token_json["access_token"] = "" with open(auth_token_path, 'w') as auth_token: json.dump(auth_token_json, auth_token, indent=4) print("Successfully removed access_token value from AuthToken-MSI.json") except Exception as e: print("ERROR: Could not decode JSON from {0}: {1}".format(auth_token_path, e)) except Exception as e: print("ERROR: Could not copy {0}: {1}".format(src, e)) print("Skipping over contents of directory {0}".format(src)) else: print("Directory {0} doesn't exist, skipping".format(src)) return # Log collecting functions def collect_process_environ(output_dirpath, process_name, outfile_handle=None): """ Collect environment variables for a specific process. If outfile_handle is provided, writes to that file handle (for main log). If outfile_handle is None, creates a separate file in the process directory. """ if outfile_handle is None: # Create separate file mode process_dir = os.path.join(output_dirpath, process_name) if not os.path.isdir(process_dir): os.makedirs(process_dir) environ_file_path = os.path.join(process_dir, "{0}_environ.txt".format(process_name)) try: with open(environ_file_path, 'w') as environ_file: _write_process_environ_data(environ_file, process_name, separate_file=True) print("{0} environment variables saved to {1}".format(process_name.upper(), environ_file_path)) except Exception as e: print("ERROR: Could not create {0} environment variables file: {1}".format(process_name, e)) else: # Write to existing file handle mode (for main log) _write_process_environ_data(outfile_handle, process_name, separate_file=False) def _write_process_environ_data(file_handle, process_name, separate_file=True): """Helper function to write process environment data to a file handle""" if separate_file: # Format for separate file file_handle.write("{0} Environment Variables Collection\n".format(process_name.upper())) file_handle.write("=====================================\n") file_handle.write("Collected 
on: {0}\n\n".format(datetime.datetime.utcnow().isoformat())) else: # Format for main log file file_handle.write("{0} Environment Variables:\n".format(process_name.upper())) file_handle.write("========================================\n") # Get all process PIDs process_pids_output = helpers.run_cmd_output("pidof {0}".format(process_name)) if process_pids_output.strip(): process_pids = process_pids_output.strip().split() for pid in process_pids: file_handle.write("PID: {0}\n".format(pid)) environ_path = "/proc/{0}/environ".format(pid) if os.path.isfile(environ_path): try: with open(environ_path, 'rb') as proc_environ_file: environ_data = proc_environ_file.read() # Convert null-separated variables to readable format # Use try/except for Python 2/3 compatibility with decode errors parameter try: environ_vars = environ_data.decode('utf-8', errors='replace').replace('\x00', '') except TypeError: # Python 2.6 doesn't support errors parameter environ_vars = environ_data.decode('utf-8').replace('\x00', '') file_handle.write("{0}\n".format(environ_vars)) except Exception as e: file_handle.write("Error reading environment variables for PID {0}: {1}\n".format(pid, e)) else: file_handle.write("Environment file not found for PID {0}\n".format(pid)) file_handle.write("=====================================\n") else: file_handle.write("No {0} processes found\n".format(process_name)) if not separate_file: # Add separator for main log file file_handle.write("--------------------------------------------------------------------------------\n") def collect_logs(output_dirpath, pkg_manager): # collect MDSD information copy_file("/etc/default/azuremonitoragent", os.path.join(output_dirpath,"mdsd")) copy_file("/var/opt/microsoft/azuremonitoragent/events/taskstate.json", os.path.join(output_dirpath,"mdsd")) copy_dircontents("/var/opt/microsoft/azuremonitoragent/log", os.path.join(output_dirpath,"mdsd","logs")) # collect MDSD environment variables collect_process_environ(output_dirpath, 
"mdsd") # collect AMA Core Agent environment variables collect_process_environ(output_dirpath, "amacoreagent") # collect AMA DCR copy_dircontents("/etc/opt/microsoft/azuremonitoragent", os.path.join(output_dirpath,"DCR")) # get all AzureMonitorLinuxAgent-* directory names for config_dir in filter((lambda x : x.startswith("Microsoft.Azure.Monitor.AzureMonitorLinuxAgent-")), os.listdir("/var/lib/waagent")): # collect AMA config and status information for all AzureMonitorLinuxAgent-* directories ver = (config_dir.split('-'))[-1] copy_dircontents(os.path.join("/var/lib/waagent",config_dir,"status"), os.path.join(output_dirpath,ver+"-status")) copy_dircontents(os.path.join("/var/lib/waagent",config_dir,"config"), os.path.join(output_dirpath,ver+"-config")) # collect system logs system_logs = "" if (pkg_manager == "dpkg"): system_logs = "syslog" elif (pkg_manager == "rpm"): system_logs = "messages" if (system_logs != ""): for systemlog_file in filter((lambda x : x.startswith(system_logs)), os.listdir("/var/log")): helpers.run_cmd_output(TAIL_SYSLOG_CMD.format(systemlog_file, os.path.join(output_dirpath,"system_logs"))) # collect rsyslog information (if present) copy_file("/etc/rsyslog.conf", os.path.join(output_dirpath,"rsyslog")) copy_dircontents("/etc/rsyslog.d", os.path.join(output_dirpath,"rsyslog","rsyslog.d")) if (os.path.isfile("/etc/rsyslog.conf")): helpers.run_cmd_output(JOURNALCTL_CMD.format("rsyslog", os.path.join(output_dirpath,"rsyslog","journalctl_output.log"))) # collect syslog-ng information (if present) copy_dircontents("/etc/syslog-ng", os.path.join(output_dirpath,"syslog-ng")) return def collect_arc_logs(output_dirpath, pkg_manager): # collect GC Extension logs copy_dircontents("/var/lib/GuestConfig/ext_mgr_logs", os.path.join(output_dirpath,"GC_Extension")) # collect AMA Extension logs for config_dir in filter((lambda x : x.startswith("Microsoft.Azure.Monitor.AzureMonitorLinuxAgent-")), os.listdir("/var/lib/GuestConfig/extension_logs")): # collect AMA 
config and status information for all AzureMonitorLinuxAgent-* directories ver = (config_dir.split('-'))[-1] copy_dircontents(os.path.join("/var/lib/GuestConfig/extension_logs",config_dir), os.path.join(output_dirpath,ver+"-extension_logs")) copy_file(ArcSettingsFile, os.path.join(output_dirpath,"Arc")) # collect logs same to both Arc + Azure VM collect_logs(output_dirpath, pkg_manager) print("Arc logs collected") return def collect_azurevm_logs(output_dirpath, pkg_manager): # collect waagent logs for waagent_file in filter((lambda x : x.startswith("waagent.log")), os.listdir("/var/log")): copy_file(os.path.join("/var/log",waagent_file), os.path.join(output_dirpath,"waagent")) # collect AMA Extension logs copy_dircontents("/var/log/azure/Microsoft.Azure.Monitor.AzureMonitorLinuxAgent", os.path.join(output_dirpath,"Microsoft.Azure.Monitor.AzureMonitorLinuxAgent")) # collect logs same to both Arc + Azure VM collect_logs(output_dirpath, pkg_manager) print("Azure VM logs collected") return def collect_metrics_logs(output_dirpath): """ Run the metrics troubleshooter and collect any MdmDataCollectionOutput_*.tar.gz files. 
def create_outfile(output_dirpath, logs_date, pkg_manager):
    """
    Write 'amalinux.out' into output_dirpath: a quick-check summary of the
    host (OS, package manager, uname, Python version), IMDS metadata, AMA
    install state, workspace info, and the output of diagnostic commands
    (package queries, ps, systemctl, du, permission checks).

    Parameters:
        output_dirpath -- directory the log collector is assembling
        logs_date      -- collection start timestamp (string)
        pkg_manager    -- "dpkg", "rpm", or "" when indeterminate
    """
    # separators reused throughout the report
    section_break = "--------------------------------------------------------------------------------\n"
    cmd_header = "========================================\n"
    with open(os.path.join(output_dirpath, "amalinux.out"), 'w') as outfile:
        outfile.write("Log Collection Start Time: {0}\n".format(logs_date))
        outfile.write(section_break)
        # detected OS + version
        vm_dist, vm_ver, _ = helpers.find_vm_distro()
        if (vm_dist and vm_ver):
            outfile.write("Linux OS detected: {0}\n".format(vm_dist))
            outfile.write("Linux OS version detected: {0}\n".format(vm_ver))
        else:
            outfile.write("Indeterminate OS.\n")
        # detected package manager
        if (pkg_manager != ""):
            outfile.write("Package manager detected: {0}\n".format(pkg_manager))
        else:
            outfile.write("Indeterminate package manager.\n")
        outfile.write(section_break)
        # uname info
        os_uname = os.uname()
        outfile.write("Hostname: {0}\n".format(os_uname[1]))
        outfile.write("Release Version: {0}\n".format(os_uname[2]))
        outfile.write("Linux UName: {0}\n".format(os_uname[3]))
        outfile.write("Machine Type: {0}\n".format(os_uname[4]))
        outfile.write(section_break)
        # python version
        outfile.write("Python Version: {0}\n".format(platform.python_version()))
        outfile.write(section_break)
        # /etc/os-release
        if (os.path.isfile("/etc/os-release")):
            outfile.write("Contents of /etc/os-release:\n")
            with open("/etc/os-release", 'r') as os_info:
                for line in os_info:
                    outfile.write(line)
            outfile.write(section_break)
        # VM Metadata; a missing attribute is retried once after refreshing via IMDS
        attributes = ['azEnvironment', 'resourceId', 'location']
        outfile.write("VM Metadata from IMDS:\n")
        for attr in attributes:
            attr_result = helpers.geninfo_lookup(attr)
            if (not attr_result) and (check_metadata() == NO_ERROR):
                attr_result = helpers.geninfo_lookup(attr)
            if (attr_result != None):
                outfile.write("{0}: {1}\n".format(attr, attr_result))
        outfile.write(section_break)
        outfile.write(section_break)
        # AMA install status
        (ama_vers, _) = helpers.find_ama_version()
        (ama_installed, ama_unique) = helpers.check_ama_installed(ama_vers)
        outfile.write("AMA Install Status: {0}\n".format("installed" if ama_installed else "not installed"))
        if (ama_installed):
            if (not ama_unique):
                outfile.write("Multiple AMA versions detected: {0}\n".format(', '.join(ama_vers)))
            else:
                outfile.write("AMA Version: {0}\n".format(ama_vers[0]))
        outfile.write(section_break)
        # connection to endpoints
        wkspc_id, wkspc_region, agent_settings, e = helpers.find_dcr_workspace()
        if e == None:
            outfile.write("Workspace ID: {0}\n".format(str(wkspc_id)))
            outfile.write("Workspace region: {0}\n".format(str(wkspc_region)))
            outfile.write(section_break)
            if agent_settings != {}:
                # fixed typo in report text: was "AgentSettinsgs"
                outfile.write("AgentSettings file found: {0}\n".format(str(agent_settings)))
        # AMA package info (dpkg/rpm)
        if (pkg_manager == "dpkg"):
            outfile.write("Output of command: {0}\n".format(DPKG_CMD))
            outfile.write(cmd_header)
            outfile.write(helpers.run_cmd_output(DPKG_CMD))
            outfile.write(section_break)
        elif (pkg_manager == "rpm"):
            outfile.write("Output of command: {0}\n".format(RPM_CMD))
            outfile.write(cmd_header)
            outfile.write(helpers.run_cmd_output(RPM_CMD))
            outfile.write(section_break)
        outfile.write(section_break)
        # ps -ef output for the agent processes
        for process in ["azuremonitoragent", "mdsd", "telegraf"]:
            ps_process_cmd = PS_CMD.format(process)
            outfile.write("Output of command: {0}\n".format(ps_process_cmd))
            outfile.write(cmd_header)
            outfile.write(helpers.run_cmd_output(ps_process_cmd))
            outfile.write(section_break)
        # process environment variables output (appended to this report)
        collect_process_environ(output_dirpath, "mdsd", outfile)
        collect_process_environ(output_dirpath, "amacoreagent", outfile)
        # rsyslog / syslog-ng status via systemctl
        for syslogd in ["rsyslog", "syslog-ng"]:
            systemctl_cmd = SYSTEMCTL_CMD.format(syslogd)
            outfile.write("Output of command: {0}\n".format(systemctl_cmd))
            outfile.write(cmd_header)
            outfile.write(helpers.run_cmd_output(systemctl_cmd))
            outfile.write(section_break)
        outfile.write(section_break)
        # ps aux output, top consumers by cpu / rss / vsz
        for cmd in [PS_CMD_CPU, PS_CMD_RSS, PS_CMD_VSZ]:
            outfile.write("Output of command: {0}\n".format(cmd))
            outfile.write(cmd_header)
            outfile.write(helpers.run_cmd_output(cmd))
            outfile.write(section_break)
        outfile.write(section_break)
        # du output on events folder (real and apparent sizes)
        for flag in ["", "--apparent-size"]:
            du_full_cmd = DU_CMD.format(flag)
            outfile.write("Output of command: {0}\n".format(du_full_cmd))
            outfile.write(cmd_header)
            outfile.write(helpers.run_cmd_output(du_full_cmd))
            outfile.write(section_break)
        outfile.write(section_break)
        # du output on /var folder
        for flag in ["", "--apparent-size"]:
            du_full_cmd = VAR_DU_CMD.format(flag)
            outfile.write("Output of command: {0}\n".format(du_full_cmd))
            outfile.write(cmd_header)
            outfile.write(helpers.run_cmd_output(du_full_cmd))
            outfile.write(section_break)
        # file permission check
        for file in PERMISSION_CHECK_FILES:
            file_permission_cmd = LS_CMD.format(file)
            outfile.write("Output of command: {0}\n".format(file_permission_cmd))
            outfile.write(cmd_header)
            outfile.write(helpers.run_cmd_output(file_permission_cmd))
            outfile.write(section_break)
        outfile.write(section_break)
        # parent directory permission check
        for file in PERMISSION_CHECK_FILES:
            dir_permission_cmd = NAMEI_CMD.format(file)
            outfile.write("Output of command: {0}\n".format(dir_permission_cmd))
            outfile.write(cmd_header)
            outfile.write(helpers.run_cmd_output(dir_permission_cmd))
            outfile.write(section_break)
        outfile.write(section_break)
print("--------------------------------------------------------------------------------") # zip up logs print("Zipping up logs and removing temporary output directory") tgz_filename = "{0}.tgz".format(output_dirname) tgz_filepath = os.path.join(output_location, tgz_filename) print("--------------------------------------------------------------------------------") print(helpers.run_cmd_output("cd {0}; tar -zcf {1} {2}".format(output_location, tgz_filename, output_dirname))) # This makes archive not readable by anyone else but the user who created it print("Setting permissions on the archive to 600 so only the user who created it can read it") print("--------------------------------------------------------------------------------") os.chmod(tgz_filepath, 0o600) shutil.rmtree(output_dirpath, ignore_errors=True) print("--------------------------------------------------------------------------------") print("You can find the AMA logs at the following location: {0}".format(tgz_filepath)) return ================================================ FILE: AzureMonitorAgent/ama_tst/modules/main.py ================================================ import os import sys from helpers import get_input from logcollector import run_logcollector from error_codes import * from errors import get_input, is_error, err_summary from install.install import check_installation from connect.connect import check_connection from general_health.general_health import check_general_health from high_cpu_mem.high_cpu_mem import check_high_cpu_memory from syslog_tst.syslog import check_syslog from custom_logs.custom_logs import check_custom_logs from metrics_troubleshooter.metrics_troubleshooter import run_metrics_troubleshooter # check to make sure the user is running as root def check_sudo(): if (os.geteuid() != 0): print("The troubleshooter is not currently being run as root. 
def check_all(interactive):
    """
    Run all troubleshooter checks, continuing even if errors occur.
    Collects all results and reports the most severe issue at the end.

    Parameters:
        interactive -- forwarded to each check; True enables prompts
    Returns the most severe status code encountered (NO_ERROR if clean).
    """
    checks = [
        ("Installation", check_installation),
        ("Connection", check_connection),
        ("General Health", check_general_health),
        ("High CPU/Memory Usage", check_high_cpu_memory),
        ("Syslog", check_syslog),
        ("Custom logs", check_custom_logs),
        ("Metrics", run_metrics_troubleshooter),
    ]
    results = []
    overall_status = NO_ERROR
    # derive the total from the list instead of hard-coding "7" so the
    # banner stays correct when checks are added or removed
    total_checks = len(checks)
    for i, (check_name, check_func) in enumerate(checks, 1):
        print("================================================================================")
        print("Running check {0}/{1}: {2}...".format(i, total_checks, check_name))
        try:
            result = check_func(interactive)
            results.append((check_name, result))
            # Track the most severe error (higher error codes are more severe)
            if is_error(result) and result > overall_status:
                overall_status = result
            elif not is_error(result) and result > overall_status and overall_status == NO_ERROR:
                overall_status = result
            # Print immediate result for this check
            if is_error(result):
                print("[ERROR] {0}: ERROR (code {1})".format(check_name, result))
            elif result != NO_ERROR:
                print("[WARN] {0}: WARNING (code {1})".format(check_name, result))
            else:
                print("[OK] {0}: OK".format(check_name))
        except Exception as e:
            # a crashing check must not stop the remaining checks
            print("[EXCEPTION] {0}: EXCEPTION - {1}".format(check_name, str(e)))
            results.append((check_name, "EXCEPTION: {0}".format(str(e))))
            overall_status = ERR_FOUND  # Set a generic error code
    # Summary of all results
    print("\n================================================================================")
    print("SUMMARY OF ALL CHECKS:")
    print("================================================================================")
    for check_name, result in results:
        if isinstance(result, str) and result.startswith("EXCEPTION"):
            print("[EXCEPTION] {0}: {1}".format(check_name, result))
        elif is_error(result):
            print("[ERROR] {0}: ERROR (code {1})".format(check_name, result))
        elif result != NO_ERROR:
            print("[WARN] {0}: WARNING (code {1})".format(check_name, result))
        else:
            print("[OK] {0}: OK".format(check_name))
    return overall_status
def run_troubleshooter():
    """
    Interactive entry point: menu-driven scenario selection, or batch mode
    via '-A' (run all checks) / '-L' (collect logs) command-line flags.
    """
    # check if running as sudo
    if (not check_sudo()):
        return

    # run all checks from command line
    if len(sys.argv) > 1 and sys.argv[1] == '-A':
        success = check_all(False)
        print_results(success)
        print_next_steps()
        return

    # run log collector from command line
    if len(sys.argv) > 1 and sys.argv[1] == '-L':
        collect_logs()
        return

    # check if want to run again
    run_again = True

    print("Welcome to the Azure Monitor Linux Agent Troubleshooter! What is your issue?\n")
    while (run_again):
        # TODO: come up with scenarios
        print("================================================================================\n"\
              "1: Installation failures. \n"\
              "2: Agent doesn't start or cannot connect to Log Analytics service.\n"\
              "3: Agent in unhealthy state. \n"\
              "4: Agent consuming high CPU/memory. \n"\
              "5: Syslog not flowing. \n"\
              "6: Custom logs not flowing. \n"\
              "7: Metrics not flowing.\n"\
              "================================================================================\n"\
              "A: Run through all scenarios.\n"\
              "L: Collect the logs for AMA.\n"\
              "Q: Press 'Q' to quit.\n"\
              "================================================================================")
        switcher = {
            '1': check_installation,
            '2': check_connection,
            '3': check_general_health,
            '4': check_high_cpu_memory,
            '5': check_syslog,
            '6': check_custom_logs,
            '7': run_metrics_troubleshooter,
            'A': check_all
        }
        # help text fixed: there are seven scenarios, not six
        issue = get_input("Please select an option",\
                          (lambda x : x.lower() in ['1','2','3','4','5','6','7','q','quit','l','a']),\
                          "Please enter an integer corresponding with your issue (1-7) to\n"\
                          "continue, 'A' to run through all scenarios, 'L' to run the log collector, or 'Q' to quit.")
        # quit troubleshooter
        if (issue.lower() in ['q','quit']):
            print("Exiting the troubleshooter...")
            return
        # collect logs
        if (issue.lower() == 'l'):
            collect_logs()
            return

        # silent vs interactive mode
        print("--------------------------------------------------------------------------------")
        print("The troubleshooter can be run in two different modes.\n"\
              " - Silent Mode runs through with no input required\n"\
              " - Interactive Mode includes extra checks that require input")
        mode = get_input("Do you want to run the troubleshooter in silent (s) or interactive (i) mode?",\
                         (lambda x : x.lower() in ['s','silent','i','interactive','q','quit']),\
                         "Please enter 's'/'silent' to run silent mode, 'i'/'interactive' to run \n"\
                         "interactive mode, or 'q'/'quit' to quit.")
        if (mode.lower() in ['q','quit']):
            print("Exiting the troubleshooter...")
            return
        elif (mode.lower() in ['s','silent']):
            print("Running troubleshooter in silent mode...")
            interactive_mode = False
        elif (mode.lower() in ['i','interactive']):
            print("Running troubleshooter in interactive mode...")
            interactive_mode = True

        # run troubleshooter; the default is unreachable because get_input
        # already validated the selection against the switcher keys
        section = switcher.get(issue.upper(), lambda: "Invalid input")
        print("================================================================================")
        success = section(interactive=interactive_mode)
        print_results(success)

        # if user ran single scenario, ask if they want to run again
        if (issue in ['1', '2', '3', '4', '5', '6', '7']):
            run_again = get_input("Do you want to run another scenario? (y/n)",\
                                  (lambda x : x.lower() in ['y','yes','n','no']),\
                                  "Please type either 'y'/'yes' or 'n'/'no' to proceed.")
            if (run_again.lower() in ['y', 'yes']):
                print("Please select another scenario below:")
            elif (run_again.lower() in ['n', 'no']):
                run_again = False
        else:
            run_again = False
    print_next_steps()
    return
""" if not os.path.exists(TROUBLESHOOTER_FILE): print("Metrics Troubleshooter script not found at: {}".format(TROUBLESHOOTER_FILE)) return ERR_FOUND status = None if interactive: print("================================================================================") print("Metrics Troubleshooter does not support interactive mode yet.") print("The troubleshooter produces `MdmDataCollectionOutput_.*tar.gz`, which is required for investigating the issue.") try: proc = subprocess.Popen( ["/bin/sh", TROUBLESHOOTER_FILE], stdout=subprocess.PIPE, stderr=subprocess.PIPE ) stdout, stderr = proc.communicate() status = proc.returncode if status != 0: print("Error ({}): {}".format(status, stderr.strip())) # raise Exception or return False here if needed else: print("Troubleshooter output: {}".format(stdout.strip())) except Exception as e: print("Unexpected error: {}".format(str(e))) return NO_ERROR ================================================ FILE: AzureMonitorAgent/ama_tst/modules/syslog_tst/__init__.py ================================================ # Syslog check helper script for AMA ================================================ FILE: AzureMonitorAgent/ama_tst/modules/syslog_tst/check_conf.py ================================================ import os from error_codes import * from errors import error_info from helpers import geninfo_lookup, run_cmd_output CONF_ACCESS_CMD = 'sudo -u syslog test -r {0}; echo "$?"' SOCKET_ACCESS_CMD = 'sudo -u syslog test -{0} {1}; echo "$?"' AMA_SOCKET = "/run/azuremonitoragent/default_syslog.socket" def check_conf_files(): # update syslog destination path with correct location syslog_dest = geninfo_lookup('SYSLOG_DEST') if (syslog_dest == None): return ERR_SYSLOG # verify syslog destination exists / not empty / accessible by syslog user if (not os.path.isfile(syslog_dest)): error_info.append(('file', syslog_dest)) return ERR_FILE_MISSING if (os.stat(syslog_dest).st_size == 0): error_info.append((syslog_dest,)) return 
ERR_FILE_EMPTY if (run_cmd_output(CONF_ACCESS_CMD.format(syslog_dest)).strip() != '0'): error_info.append(('file', syslog_dest, 'read')) return ERR_CONF_FILE_PERMISSION return NO_ERROR def check_socket(): if (not os.path.exists(AMA_SOCKET)): error_info.append(('socket', AMA_SOCKET)) return ERR_FILE_MISSING if (run_cmd_output(SOCKET_ACCESS_CMD.format('r', AMA_SOCKET)).strip() != '0'): error_info.append(('socket', AMA_SOCKET, 'read')) return ERR_CONF_FILE_PERMISSION if (run_cmd_output(SOCKET_ACCESS_CMD.format('w', AMA_SOCKET)).strip() != '0'): error_info.append(('socket', AMA_SOCKET, 'write')) return ERR_CONF_FILE_PERMISSION return NO_ERROR ================================================ FILE: AzureMonitorAgent/ama_tst/modules/syslog_tst/check_rsysng.py ================================================ import subprocess from error_codes import * from errors import error_info from helpers import general_info RSYSLOG_CONF = "/etc/rsyslog.d/10-azuremonitoragent-omfwd.conf" SYSLOG_NG_CONF = "/etc/syslog-ng/conf.d/azuremonitoragent-tcp.conf" # check syslog with systemctl def check_sys_systemctl(service): try: sys_status = subprocess.check_output(['systemctl', 'status', service], \ universal_newlines=True, stderr=subprocess.STDOUT) sys_lines = sys_status.split('\n') for line in sys_lines: line = line.strip() if line.startswith('Active: '): stripped_line = line.lstrip('Active: ') # exists and running correctly if stripped_line.startswith('active (running) since '): return NO_ERROR # exists but not running correctly else: error_info.append((service, stripped_line, 'systemctl')) return ERR_SERVICE_STATUS except subprocess.CalledProcessError as e: # service not on machine if (e.returncode == 4): return ERR_SYSLOG else: error_info.append((service, e.output, 'systemctl')) return ERR_SERVICE_STATUS def check_services(): global general_info checked_rsyslog = check_sys_systemctl('rsyslog') # rsyslog successful if (checked_rsyslog == NO_ERROR): general_info['SYSLOG_DEST'] = 
RSYSLOG_CONF return NO_ERROR checked_syslog_ng = check_sys_systemctl('syslog-ng') # syslog-ng successful if (checked_syslog_ng == NO_ERROR): general_info['SYSLOG_DEST'] = SYSLOG_NG_CONF return NO_ERROR # ran into error trying to get syslog if ((checked_rsyslog==ERR_SERVICE_STATUS) or (checked_syslog_ng==ERR_SERVICE_STATUS)): return ERR_SERVICE_STATUS return ERR_SYSLOG ================================================ FILE: AzureMonitorAgent/ama_tst/modules/syslog_tst/syslog.py ================================================ from error_codes import * from errors import is_error, print_errors from .check_conf import check_conf_files, check_socket from .check_rsysng import check_services def check_syslog(interactive, prev_success=NO_ERROR): print("CHECKING FOR SYSLOG ISSUES...") success = prev_success # check rsyslog / syslogng running print("Checking if machine has rsyslog or syslog-ng running...") checked_services = check_services() if (is_error(checked_services)): return print_errors(checked_services) else: success = print_errors(checked_services) # check for rsyslog / syslog-ng configuration files print("Checking for syslog configuration files...") checked_conf_files = check_conf_files() if (is_error(checked_conf_files)): return print_errors(checked_conf_files) else: success = print_errors(checked_conf_files) # check for syslog socket existence and permissions print("Checking for syslog socket...") checked_socket = check_socket() if (is_error(checked_socket)): return print_errors(checked_socket) else: success = print_errors(checked_socket) return success ================================================ FILE: AzureMonitorAgent/apply_version.sh ================================================ #! 
/bin/bash source ./agent.version echo "AGENT_VERSION=$AGENT_VERSION" echo "MDSD_DEB_PACKAGE_NAME=$MDSD_DEB_PACKAGE_NAME" echo "MDSD_RPM_PACKAGE_NAME=$MDSD_RPM_PACKAGE_NAME" # updating HandlerManifest.json # check for "version": "x.x.x", sed -i "s/\"version\".*$/\"version\": \"$AGENT_VERSION\",/g" HandlerManifest.json # updating agent.py sed -i "s/^BundleFileNameDeb = .*$/BundleFileNameDeb = '$MDSD_DEB_PACKAGE_NAME'/" agent.py sed -i "s/^BundleFileNameRpm = .*$/BundleFileNameRpm = '$MDSD_RPM_PACKAGE_NAME'/" agent.py sed -i "s/AMA_VERSION/$AGENT_VERSION/" services/metrics-extension-otlp.service sed -i "s/AMA_VERSION/$AGENT_VERSION/" services/metrics-extension-cmv2.service # updating manifest.xml # check ... sed -i -e "s|[0-9a-z.]\{1,\}|$AGENT_VERSION|g" manifest.xml ================================================ FILE: AzureMonitorAgent/azuremonitoragentextension.logrotate ================================================ /var/log/azure/Microsoft.Azure.Monitor.AzureMonitorLinuxAgent/extension.log { copytruncate rotate 7 daily missingok notifempty delaycompress compress size 10M } /var/log/azure/Microsoft.Azure.Monitor.AzureMonitorLinuxAgent/CommandExecution.log { copytruncate rotate 7 daily missingok notifempty delaycompress compress size 10M } /var/log/azure/Microsoft.Azure.Monitor.AzureMonitorLinuxAgent/telegraf.log { copytruncate rotate 7 daily missingok notifempty delaycompress compress size 10M } ================================================ FILE: AzureMonitorAgent/manifest.xml ================================================ Microsoft.Azure.Monitor AzureMonitorLinuxAgent 1.5.124 VmRole Microsoft Azure Monitoring Agent for Linux true https://docs.microsoft.com/en-us/azure/azure-monitor/learn/quick-collect-linux-computer http://www.microsoft.com/privacystatement/en-us/OnlineServices/Default.aspx https://msazure.visualstudio.com/DefaultCollection/One/_git/Compute-Runtime-Tux true Linux Microsoft ================================================ FILE: 
AzureMonitorAgent/packaging.sh ================================================ #! /bin/bash set -e source agent.version usage() { local basename=`basename $0` echo "usage: ./$basename .{.deb, .rpm}> [path for zip output]" } input_path=$1 output_path=$2 PACKAGE_NAME="azuremonitor$AGENT_VERSION.zip" if [[ "$1" == "--help" ]]; then usage exit 0 elif [[ ! -d $input_path ]]; then echo "DEB/RPM files path '$input_path' not found" usage exit 1 fi if [[ "$output_path" == "" ]]; then output_path="../" fi # Packaging starts here cp -r ../Utils . cp ../Common/WALinuxAgent-2.0.16/waagent . cp -r ../LAD-AMA-Common/metrics_ext_utils . cp -r ../LAD-AMA-Common/telegraf_utils . cp -f ../Diagnostic/services/metrics-sourcer.service services/metrics-sourcer.service # cleanup packages, ext rm -rf packages MetricsExtensionBin azureotelcollector amaCoreAgentBin AstExtensionBin agentLauncherBin mdsdBin fluentBitBin tmp mkdir -p packages MetricsExtensionBin azureotelcollector amaCoreAgentBin AstExtensionBin agentLauncherBin mdsdBin fluentBitBin # copy shell bundle to packages/ cp $input_path/azuremonitoragent_$AGENT_VERSION* packages/ cp $input_path/azuremonitoragent-$AGENT_VERSION* packages/ # remove dynamic ssl packages rm -f packages/*dynamicssl* # validate HandlerManifest.json syntax jq empty < HandlerManifest.json mkdir -p tmp cp $input_path/azuremonitoragent_$AGENT_VERSION*dynamicssl_x86_64.deb tmp/ AMA_DEB_PACKAGE_NAME=$(find tmp/ -type f -name "azuremonitoragent_*x86_64.deb" -printf "%f\\n" | head -n 1) ar vx tmp/$AMA_DEB_PACKAGE_NAME --output=tmp tar xvf tmp/data.tar.gz -C tmp cp tmp/opt/microsoft/azuremonitoragent/bin/mdsd mdsdBin/mdsd_x86_64 cp tmp/opt/microsoft/azuremonitoragent/bin/mdsdmgr mdsdBin/mdsdmgr_x86_64 cp tmp/opt/microsoft/azuremonitoragent/bin/fluent-bit fluentBitBin/fluent-bit_x86_64 rm -rf tmp/ mkdir -p tmp cp $input_path/azuremonitoragent_$AGENT_VERSION*dynamicssl_aarch64.deb tmp/ AMA_DEB_PACKAGE_NAME=$(find tmp/ -type f -name "azuremonitoragent_*aarch64.deb" 
-printf "%f\\n" | head -n 1) ar vx tmp/$AMA_DEB_PACKAGE_NAME --output=tmp tar xvf tmp/data.tar.gz -C tmp cp tmp/opt/microsoft/azuremonitoragent/bin/mdsd mdsdBin/mdsd_aarch64 cp tmp/opt/microsoft/azuremonitoragent/bin/mdsdmgr mdsdBin/mdsdmgr_aarch64 cp tmp/opt/microsoft/azuremonitoragent/bin/fluent-bit fluentBitBin/fluent-bit_aarch64 rm -rf tmp/ cp $input_path/x86_64/metricsextension MetricsExtensionBin/metricsextension_x86_64 cp $input_path/aarch64/metricsextension MetricsExtensionBin/metricsextension_aarch64 cp $input_path/azureotelcollector/* azureotelcollector/ cp -r $input_path/AstExtension/* AstExtensionBin/ cp $input_path/x86_64/amacoreagent amaCoreAgentBin/amacoreagent_x86_64 cp $input_path/x86_64/liblz4x64.so amaCoreAgentBin/ #cp $input_path/x86_64/libgrpc_csharp_ext.x64.so amaCoreAgentBin/ cp $input_path/x86_64/agentlauncher agentLauncherBin/agentlauncher_x86_64 cp $input_path/metrics_troubleshooter.sh ama_tst/ cp $input_path/aarch64/amacoreagent amaCoreAgentBin/amacoreagent_aarch64 #cp $input_path/aarch64/libgrpc_csharp_ext.arm64.so amaCoreAgentBin/ cp $input_path/aarch64/agentlauncher agentLauncherBin/agentlauncher_aarch64 # make the shim.sh file executable chmod +x shim.sh # sync the file copy sync if [[ -f $output_path/$PACKAGE_NAME ]]; then echo "Removing existing $PACKAGE_NAME ..." 
rm -f $output_path/$PACKAGE_NAME fi echo "Packaging extension $PACKAGE_NAME to $output_path" excluded_files="agent.version packaging.sh apply_version.sh update_version.sh" zip -r $output_path/$PACKAGE_NAME * -x $excluded_files "./test/*" "./extension-test/*" "./references" "./tmp" # validate package size is within limits; these limits come from arc, ideally they are removed in the future max_uncompressed_size=$((1000 * 1024 * 1024)) max_compressed_size=$((500 * 1024 * 1024)) # easiest to validate by immediately unzipping versus trying to `du` with various exclusions unzip -d $output_path/unzipped $output_path/$PACKAGE_NAME uncompressed_size=$(du -sb $output_path/unzipped | cut -f1) compressed_size=$(du -sb $output_path/$PACKAGE_NAME | cut -f1) rm -rf $output_path/unzipped if [[ $uncompressed_size -gt $max_uncompressed_size ]]; then echo "Uncompressed size of $PACKAGE_NAME is $uncompressed_size bytes, which exceeds the limit of $max_uncompressed_size bytes" exit 1 fi if [[ $compressed_size -gt $max_compressed_size ]]; then echo "Compressed size of $PACKAGE_NAME is $compressed_size bytes, which exceeds the limit of $max_compressed_size bytes" exit 1 fi # cleanup newly added dir or files rm -rf Utils/ waagent ================================================ FILE: AzureMonitorAgent/references ================================================ Utils/ ================================================ FILE: AzureMonitorAgent/services/metrics-extension-cmv1.service ================================================ [Unit] Description=Metrics Extension service for Linux Agent metrics sourcing After=network.target [Service] ExecStart=%ME_BIN% -TokenSource MSI -Input influxdb_local -InfluxDbSocketPath %ME_INFLUX_SOCKET_FILE_PATH% -DataDirectory %ME_DATA_DIRECTORY% -LocalControlChannel -MonitoringAccount %ME_MONITORING_ACCOUNT% -LogLevel Error ExecReload=/bin/kill -HUP $MAINPID KillMode=control-group [Install] WantedBy=multi-user.target 
================================================ FILE: AzureMonitorAgent/services/metrics-extension-cmv2.service ================================================ [Unit] Description=Metrics Extension service for Linux Agent metrics sourcing After=network.target [Service] Environment="OTLP_GRPC_HOST=127.0.0.1" Environment="OTLP_GRPC_PORT=4317" Environment="OTLP_GRPC_PROM_HOST=127.0.0.1" Environment="OTLP_GRPC_PROM_PORT=4316" EnvironmentFile=-/etc/metrics-extension.d/options.conf ExecStart=%ME_BIN% -TokenSource AMCS -ManagedIdentity %ME_MANAGED_IDENTITY% -DataDirectory %ME_DATA_DIRECTORY% -Input influxdb_local,otlp_grpc,otlp_grpc_prom -InfluxDbSocketPath %ME_INFLUX_SOCKET_FILE_PATH% -LogLevel Info -Logger Console -OperationEnvironment AMA-Linux/AMA_VERSION -ConfigOverrides "{\"otlp\":{\"endpoints\":[\"${OTLP_GRPC_PROM_HOST}:${OTLP_GRPC_PROM_PORT}\"]}}" ExecReload=/bin/kill -HUP $MAINPID KillMode=control-group User=azuremetricsext Group=azuremonitoragent RuntimeDirectory=azureotelcollector azuremetricsext RuntimeDirectoryMode=0755 [Install] WantedBy=multi-user.target ================================================ FILE: AzureMonitorAgent/services/metrics-extension-otlp.service ================================================ [Unit] Description=Metrics Extension service for Linux Agent metrics sourcing After=network.target [Service] Environment="OTLP_GRPC_HOST=127.0.0.1" Environment="OTLP_GRPC_PORT=4317" EnvironmentFile=-/etc/metrics-extension.d/options.conf ExecStart=%ME_BIN% -TokenSource AMCS -ManagedIdentity %ME_MANAGED_IDENTITY% -Input influxdb_local,otlp_grpc -InfluxDbSocketPath %ME_INFLUX_SOCKET_FILE_PATH% -LogLevel Info -Logger Console -OperationEnvironment AMA-Linux/AMA_VERSION ExecReload=/bin/kill -HUP $MAINPID KillMode=control-group [Install] WantedBy=multi-user.target ================================================ FILE: AzureMonitorAgent/shim.sh ================================================ #!/usr/bin/env bash # This is the main driver file for AMA 
extension. This file first checks if Python 3 or 2 is available on the VM
# and if yes then uses that Python (if both are available then, default is set to python3) to run extension operations in agent.py
# Control arguments passed to the shim are redirected to agent.py without validation.

COMMAND="./agent.py"
PYTHON=""

# Locate a usable python interpreter, preferring python3, then python2,
# then RHEL 8's /usr/libexec/platform-python. The chosen command is
# assigned (via eval) to the variable whose *name* is passed as $1.
function find_python() {
    local python_exec_command=$1

    if command -v python3 >/dev/null 2>&1 ; then
        eval ${python_exec_command}="python3"
    elif command -v python2 >/dev/null 2>&1 ; then
        eval ${python_exec_command}="python2"
    elif command -v /usr/libexec/platform-python >/dev/null 2>&1 ; then
        # If a user-installed python isn't available, check for a platform-python. This is typically only used in RHEL 8.0.
        echo "User-installed python not found. Using /usr/libexec/platform-python as the python interpreter."
        eval ${python_exec_command}="/usr/libexec/platform-python"
    fi
}

find_python PYTHON

if [ -z "$PYTHON" ] # If python is not installed, we will fail the install with the following error, requiring cx to have python pre-installed
then
    echo "No Python interpreter found, which is an AMA extension dependency. Please install Python 3, or Python 2 if the former is unavailable." >&2
    exit 52 # Missing Dependency
else
    ${PYTHON} --version 2>&1
fi

export NO_PROXY="169.254.169.254"

# BUGFIX: the arguments were previously flattened with ARG="$@" and then
# expanded unquoted as ${ARG}, which re-splits on whitespace and corrupts
# any argument containing spaces. Forward the original argument vector
# intact with "$@".
PYTHONPATH=${PYTHONPATH} ${PYTHON} ${COMMAND} "$@"
exit $?

================================================ FILE: AzureMonitorAgent/update_version.sh ================================================
#!
/bin/bash
set -x

# Show usage when asked for help.
if [[ "$1" == "--help" ]]; then
    echo "update_version.sh "
    exit 0
fi

UPDATE_DATE=$(date +%Y%m%d)

AGENT_VERSION=$1
MDSD_DEB_PACKAGE_NAME=$2
MDSD_RPM_PACKAGE_NAME=$3

# All three positional arguments are mandatory; bail out on any missing one.
if [[ -z "$AGENT_VERSION" ]]; then
    echo "AGENT_VERSION version is empty"
    exit 1
fi

if [[ -z "$MDSD_DEB_PACKAGE_NAME" ]]; then
    echo "MDSD_DEB_PACKAGE_NAME is empty"
    exit 1
fi

if [[ -z "$MDSD_RPM_PACKAGE_NAME" ]]; then
    echo "MDSD_RPM_PACKAGE_NAME is empty"
    exit 1
fi

# Rewrite the corresponding KEY=value lines of agent.version in place.
sed -i "s/^AGENT_VERSION=.*$/AGENT_VERSION=$AGENT_VERSION/" agent.version
sed -i "s/^MDSD_DEB_PACKAGE_NAME=.*$/MDSD_DEB_PACKAGE_NAME=$MDSD_DEB_PACKAGE_NAME/" agent.version
sed -i "s/^MDSD_RPM_PACKAGE_NAME=.*$/MDSD_RPM_PACKAGE_NAME=$MDSD_RPM_PACKAGE_NAME/" agent.version
sed -i "s/^AGENT_VERSION_DATE=.*$/AGENT_VERSION_DATE=$UPDATE_DATE/" agent.version

================================================ FILE: CODEOWNERS ================================================
# See https://help.github.com/articles/about-codeowners/
# for more info about CODEOWNERS file
# It uses the same pattern rule for gitignore file
# https://git-scm.com/docs/gitignore#_pattern_format

# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
# the following owners will be requested for
# review when someone opens a pull request.
* @nkuchta @Azure/azure-agent-extensions # Azure Monitor Agent Extension /AzureMonitorAgent/ @Azure/geneva-linux-agents # CustomScript Extension /CustomScript/ @D1v38om83r @nkuchta @Azure/azure-agent-extensions # Diagnostics (LAD) Extension /Diagnostic/ @Azure/geneva-linux-agents # Utils for LAD/AMA metrics /LAD-AMA-Common/ @Azure/geneva-linux-agents # DSCForLinux Extension /DSC/ @Bhargava-Chary-Chollaty # OMS Agent Extension /OmsAgent/ @Azure/geneva-linux-agents # OpenCensus Translator Extension /opencensus-service/ @Azure/geneva-linux-agents # VMAccess Extension /VMAccess/ @D1v38om83r @nkuchta @Azure/azure-agent-extensions # VMBackup Extension /VMBackup/ @vityagi @mearvind @arisettisanjana @deveshjagwani # VMEncryption Extension /VMEncryption/ @vimish @ejarvi # WALinuxAgent /Common/ @D1v38om83r @nkuchta @Azure/walinuxagent @Azure/azure-agent-extensions /Utils/ @D1v38om83r @nkuchta @Azure/walinuxagent @Azure/azure-agent-extensions # Abandoned? # /AzureEnhancedMonitor/ # /OSPatching/ # /RDMAUpdate/ # /SampleExtension/ # /TestHandlerLinux/ # /docs/ # /registration-scripts/ # /script/ # /ui-extension-packages/ ================================================ FILE: Common/WALinuxAgent-2.0.14/waagent ================================================ #!/usr/bin/env python # # Windows Azure Linux Agent # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# # Implements parts of RFC 2131, 1541, 1497 and # http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx # http://msdn.microsoft.com/en-us/library/cc227259%28PROT.13%29.aspx import array import base64 import httplib import os import os.path import platform import pwd import re import shutil import socket import SocketServer import struct import string import subprocess import sys import tempfile import textwrap import threading import time import traceback import xml.dom.minidom import fcntl import inspect import zipfile import json import datetime import xml.sax.saxutils if not hasattr(subprocess,'check_output'): def check_output(*popenargs, **kwargs): r"""Backport from subprocess module from python 2.7""" if 'stdout' in kwargs: raise ValueError('stdout argument not allowed, it will be overridden.') process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) output, unused_err = process.communicate() retcode = process.poll() if retcode: cmd = kwargs.get("args") if cmd is None: cmd = popenargs[0] raise subprocess.CalledProcessError(retcode, cmd, output=output) return output # Exception classes used by this module. class CalledProcessError(Exception): def __init__(self, returncode, cmd, output=None): self.returncode = returncode self.cmd = cmd self.output = output def __str__(self): return "Command '%s' returned non-zero exit status %d" % (self.cmd, self.returncode) subprocess.check_output=check_output subprocess.CalledProcessError=CalledProcessError GuestAgentName = "WALinuxAgent" GuestAgentLongName = "Windows Azure Linux Agent" GuestAgentVersion = "WALinuxAgent-2.0.14" ProtocolVersion = "2012-11-30" #WARNING this value is used to confirm the correct fabric protocol. 
Config = None WaAgent = None DiskActivated = False Openssl = "openssl" Children = [] ExtensionChildren = [] VMM_STARTUP_SCRIPT_NAME='install' VMM_CONFIG_FILE_NAME='linuxosconfiguration.xml' global RulesFiles RulesFiles = [ "/lib/udev/rules.d/75-persistent-net-generator.rules", "/etc/udev/rules.d/70-persistent-net.rules" ] VarLibDhcpDirectories = ["/var/lib/dhclient", "/var/lib/dhcpcd", "/var/lib/dhcp"] EtcDhcpClientConfFiles = ["/etc/dhcp/dhclient.conf", "/etc/dhcp3/dhclient.conf"] global LibDir LibDir = "/var/lib/waagent" global provisioned provisioned=False global provisionError provisionError=None HandlerStatusToAggStatus = {"installed":"Installing", "enabled":"Ready", "unintalled":"NotReady", "disabled":"NotReady"} WaagentConf = """\ # # Windows Azure Linux Agent Configuration # Role.StateConsumer=None # Specified program is invoked with the argument "Ready" when we report ready status # to the endpoint server. Role.ConfigurationConsumer=None # Specified program is invoked with XML file argument specifying role configuration. Role.TopologyConsumer=None # Specified program is invoked with XML file argument specifying role topology. Provisioning.Enabled=y # Provisioning.DeleteRootPassword=y # Password authentication for root account will be unavailable. Provisioning.RegenerateSshHostKeyPair=y # Generate fresh host key pair. Provisioning.SshHostKeyPairType=rsa # Supported values are "rsa", "dsa" and "ecdsa". Provisioning.MonitorHostName=y # Monitor host name changes and publish changes via DHCP requests. ResourceDisk.Format=y # Format if unformatted. If 'n', resource disk will not be mounted. ResourceDisk.Filesystem=ext4 # Typically ext3 or ext4. FreeBSD images should use 'ufs2' here. ResourceDisk.MountPoint=/mnt/resource # ResourceDisk.EnableSwap=n # Create and use swapfile on resource disk. ResourceDisk.SwapSizeMB=0 # Size of the swapfile. LBProbeResponder=y # Respond to load balancer probes if requested by Windows Azure. 
Logs.Verbose=n # Enable verbose logs OS.RootDeviceScsiTimeout=300 # Root device timeout in seconds. OS.OpensslPath=None # If "None", the system default version is used. """ README_FILENAME="DATALOSS_WARNING_README.txt" README_FILECONTENT="""\ WARNING: THIS IS A TEMPORARY DISK. Any data stored on this drive is SUBJECT TO LOSS and THERE IS NO WAY TO RECOVER IT. Please do not use this disk for storing any personal or application data. For additional details to please refer to the MSDN documentation at : http://msdn.microsoft.com/en-us/library/windowsazure/jj672979.aspx """ ############################################################ # BEGIN DISTRO CLASS DEFS ############################################################ ############################################################ # AbstractDistro ############################################################ class AbstractDistro(object): """ AbstractDistro defines a skeleton neccesary for a concrete Distro class. Generic methods and attributes are kept here, distribution specific attributes and behavior are to be placed in the concrete child named distroDistro, where distro is the string returned by calling python platform.linux_distribution()[0]. So for CentOS the derived class is called 'centosDistro'. """ def __init__(self): """ Generic Attributes go here. These are based on 'majority rules'. This __init__() may be called or overriden by the child. 
""" self.agent_service_name = os.path.basename(sys.argv[0]) self.selinux=None self.service_cmd='/usr/sbin/service' self.ssh_service_restart_option='restart' self.ssh_service_name='ssh' self.ssh_config_file='/etc/ssh/sshd_config' self.hostname_file_path='/etc/hostname' self.dhcp_client_name='dhclient' self.requiredDeps = [ 'route', 'shutdown', 'ssh-keygen', 'useradd', 'openssl', 'sfdisk', 'fdisk', 'mkfs', 'chpasswd', 'sed', 'grep', 'sudo', 'parted' ] self.init_script_file='/etc/init.d/waagent' self.agent_package_name='WALinuxAgent' self.fileBlackList = [ "/root/.bash_history", "/var/log/waagent.log",'/etc/resolv.conf' ] self.agent_files_to_uninstall = ["/etc/waagent.conf", "/etc/logrotate.d/waagent"] self.grubKernelBootOptionsFile = '/etc/default/grub' self.grubKernelBootOptionsLine = 'GRUB_CMDLINE_LINUX_DEFAULT=' self.getpidcmd = 'pidof' self.mount_dvd_cmd = 'mount' self.sudoers_dir_base = '/etc' self.waagent_conf_file = WaagentConf self.shadow_file_mode=0600 self.dhcp_enabled = False def isSelinuxSystem(self): """ Checks and sets self.selinux = True if SELinux is available on system. """ if self.selinux == None: if Run("which getenforce",chk_err=False): self.selinux = False else: self.selinux = True return self.selinux def isSelinuxRunning(self): """ Calls shell command 'getenforce' and returns True if 'Enforcing'. """ if self.isSelinuxSystem(): return RunGetOutput("getenforce")[1].startswith("Enforcing") else: return False def setSelinuxEnforce(self,state): """ Calls shell command 'setenforce' with 'state' and returns resulting exit code. """ if self.isSelinuxSystem(): if state: s = '1' else: s='0' return Run("setenforce "+s) def setSelinuxContext(self,path,cn): """ Calls shell 'chcon' with 'path' and 'cn' context. Returns exit result. """ if self.isSelinuxSystem(): return Run('chcon ' + cn + ' ' + path) def setHostname(self,name): """ Shell call to hostname. Returns resulting exit code. 
""" return Run('hostname ' + name) def publishHostname(self,name): """ Set the contents of the hostname file to 'name'. Return 1 on failure. """ try: r=SetFileContents(self.hostname_file_path, name) for f in EtcDhcpClientConfFiles: if os.path.exists(f) and FindStringInFile(f,r'^[^#]*?send\s*host-name.*?(|gethostname[(,)])') == None : r=ReplaceFileContentsAtomic('/etc/dhcp/dhclient.conf', "send host-name \"" + name + "\";\n" + "\n".join(filter(lambda a: not a.startswith("send host-name"), GetFileContents('/etc/dhcp/dhclient.conf').split('\n')))) except: return 1 return r def installAgentServiceScriptFiles(self): """ Create the waagent support files for service installation. Called by registerAgentService() Abstract Virtual Function. Over-ridden in concrete Distro classes. """ pass def registerAgentService(self): """ Calls installAgentService to create service files. Shell exec service registration commands. (e.g. chkconfig --add waagent) Abstract Virtual Function. Over-ridden in concrete Distro classes. """ pass def uninstallAgentService(self): """ Call service subsystem to remove waagent script. Abstract Virtual Function. Over-ridden in concrete Distro classes. 
""" pass def unregisterAgentService(self): """ Calls self.stopAgentService and call self.uninstallAgentService() """ self.stopAgentService() self.uninstallAgentService() def startAgentService(self): """ Service call to start the Agent service """ return Run(self.service_cmd + ' ' + self.agent_service_name + ' start') def stopAgentService(self): """ Service call to stop the Agent service """ return Run(self.service_cmd + ' ' + self.agent_service_name + ' stop',False) def restartSshService(self): """ Service call to re(start) the SSH service """ sshRestartCmd = self.service_cmd + " " + self.ssh_service_name + " " + self.ssh_service_restart_option retcode = Run(sshRestartCmd) if retcode > 0: Error("Failed to restart SSH service with return code:" + str(retcode)) return retcode def sshDeployPublicKey(self,fprint,path): """ Generic sshDeployPublicKey - over-ridden in some concrete Distro classes due to minor differences in openssl packages deployed """ error=0 SshPubKey = OvfEnv().OpensslToSsh(fprint) if SshPubKey != None: AppendFileContents(path, SshPubKey) else: Error("Failed: " + fprint + ".crt -> " + path) error = 1 return error def checkPackageInstalled(self,p): """ Query package database for prescence of an installed package. Abstract Virtual Function. Over-ridden in concrete Distro classes. """ pass def checkPackageUpdateable(self,p): """ Online check if updated package of walinuxagent is available. Abstract Virtual Function. Over-ridden in concrete Distro classes. """ pass def deleteRootPassword(self): """ Generic root password removal. 
""" filepath="/etc/shadow" ReplaceFileContentsAtomic(filepath,"root:*LOCK*:14600::::::\n" + "\n".join(filter(lambda a: not a.startswith("root:"),GetFileContents(filepath).split('\n')))) os.chmod(filepath,self.shadow_file_mode) if self.isSelinuxSystem(): self.setSelinuxContext(filepath,'system_u:object_r:shadow_t:s0') Log("Root password deleted.") return 0 def changePass(self,user,password): return RunSendStdin("chpasswd",(user + ":" + password + "\n"),use_shell=False) def load_ata_piix(self): return WaAgent.TryLoadAtapiix() def unload_ata_piix(self): """ Generic function to remove ata_piix.ko. """ return WaAgent.TryUnloadAtapiix() def deprovisionWarnUser(self): """ Generic user warnings used at deprovision. """ print("WARNING! Nameserver configuration in /etc/resolv.conf will be deleted.") def deprovisionDeleteFiles(self): """ Files to delete when VM is deprovisioned """ for a in VarLibDhcpDirectories: Run("rm -f " + a + "/*") # Clear LibDir, remove nameserver and root bash history for f in os.listdir(LibDir) + self.fileBlackList: try: os.remove(f) except: pass return 0 def uninstallDeleteFiles(self): """ Files to delete when agent is uninstalled. """ for f in self.agent_files_to_uninstall: try: os.remove(f) except: pass return 0 def checkDependencies(self): """ Generic dependency check. Return 1 unless all dependencies are satisfied. """ if self.checkPackageInstalled('NetworkManager'): Error(GuestAgentLongName + " is not compatible with network-manager.") return 1 try: m= __import__('pyasn1') except ImportError: Error(GuestAgentLongName + " requires python-pyasn1 for your Linux distribution.") return 1 for a in self.requiredDeps: if Run("which " + a + " > /dev/null 2>&1",chk_err=False): Error("Missing required dependency: " + a) return 1 return 0 def packagedInstall(self,buildroot): """ Called from setup.py for use by RPM. Copies generated files waagent.conf, under the buildroot. 
""" if not os.path.exists(buildroot+'/etc'): os.mkdir(buildroot+'/etc') SetFileContents(buildroot+'/etc/waagent.conf', MyDistro.waagent_conf_file) if not os.path.exists(buildroot+'/etc/logrotate.d'): os.mkdir(buildroot+'/etc/logrotate.d') SetFileContents(buildroot+'/etc/logrotate.d/waagent', WaagentLogrotate) self.init_script_file=buildroot+self.init_script_file # this allows us to call installAgentServiceScriptFiles() if not os.path.exists(os.path.dirname(self.init_script_file)): os.mkdir(os.path.dirname(self.init_script_file)) self.installAgentServiceScriptFiles() def GetIpv4Address(self): """ Return the ip of the first active non-loopback interface. """ addr='' iface,addr=GetFirstActiveNetworkInterfaceNonLoopback() return addr def GetMacAddress(self): return GetMacAddress() def GetInterfaceName(self): return GetFirstActiveNetworkInterfaceNonLoopback()[0] def RestartInterface(self, iface): Run("ifdown " + iface + " && ifup " + iface) def CreateAccount(self,user, password, expiration, thumbprint): return CreateAccount(user, password, expiration, thumbprint) def DeleteAccount(self,user): return DeleteAccount(user) def ActivateResourceDisk(self): """ Format, mount, and if specified in the configuration set resource disk as swap. 
""" global DiskActivated format = Config.get("ResourceDisk.Format") if format == None or format.lower().startswith("n"): DiskActivated = True return device = DeviceForIdePort(1) if device == None: Error("ActivateResourceDisk: Unable to detect disk topology.") return device = "/dev/" + device mountlist = RunGetOutput("mount")[1] mountpoint = GetMountPoint(mountlist, device) if(mountpoint): Log("ActivateResourceDisk: " + device + "1 is already mounted.") else: mountpoint = Config.get("ResourceDisk.MountPoint") if mountpoint == None: mountpoint = "/mnt/resource" CreateDir(mountpoint, "root", 0755) fs = Config.get("ResourceDisk.Filesystem") if fs == None: fs = "ext3" partition = device + "1" #Check partition type Log("Detect GPT...") ret = RunGetOutput("parted {0} print".format(device)) if ret[0] == 0 and "gpt" in ret[1]: Log("GPT detected.") #GPT(Guid Partition Table) is used. #Get partitions. parts = filter(lambda x : re.match("^\s*[0-9]+", x), ret[1].split("\n")) #If there are more than 1 partitions, remove all partitions #and create a new one using the entire disk space. if len(parts) > 1: for i in range(1, len(parts) + 1): Run("parted {0} rm {1}".format(device, i)) Run("parted {0} mkpart primary 0% 100%".format(device)) Run("mkfs." + fs + " " + partition + " -F") else: existingFS = RunGetOutput("sfdisk -q -c " + device + " 1", chk_err=False)[1].rstrip() if existingFS == "7" and fs != "ntfs": Run("sfdisk -c " + device + " 1 83") Run("mkfs." + fs + " " + partition) if Run("mount " + partition + " " + mountpoint, chk_err=False): #If mount failed, try to format the partition and mount again Warn("Failed to mount resource disk. Retry mounting.") Run("mkfs." 
+ fs + " " + partition + " -F") if Run("mount " + partition + " " + mountpoint): Error("ActivateResourceDisk: Failed to mount resource disk (" + partition + ").") return Log("Resource disk (" + partition + ") is mounted at " + mountpoint + " with fstype " + fs) #Create README file under the root of resource disk SetFileContents(os.path.join(mountpoint,README_FILENAME), README_FILECONTENT) DiskActivated = True #Create swap space swap = Config.get("ResourceDisk.EnableSwap") if swap == None or swap.lower().startswith("n"): return sizeKB = int(Config.get("ResourceDisk.SwapSizeMB")) * 1024 if os.path.isfile(mountpoint + "/swapfile") and os.path.getsize(mountpoint + "/swapfile") != (sizeKB * 1024): os.remove(mountpoint + "/swapfile") if not os.path.isfile(mountpoint + "/swapfile"): Run("dd if=/dev/zero of=" + mountpoint + "/swapfile bs=1024 count=" + str(sizeKB)) Run("mkswap " + mountpoint + "/swapfile") if not Run("swapon " + mountpoint + "/swapfile"): Log("Enabled " + str(sizeKB) + " KB of swap at " + mountpoint + "/swapfile") else: Error("ActivateResourceDisk: Failed to activate swap at " + mountpoint + "/swapfile") def Install(self): return Install() def mediaHasFilesystem(self,dsk): if len(dsk) == 0 : return False if Run("LC_ALL=C fdisk -l " + dsk + " | grep Disk"): return False return True def mountDVD(self,dvd,location): return RunGetOutput(self.mount_dvd_cmd + ' ' + dvd + ' ' + location) def GetHome(self): return GetHome() def getDhcpClientName(self): return self.dhcp_client_name def initScsiDiskTimeout(self): """ Set the SCSI disk timeout when the agent starts running """ self.setScsiDiskTimeout() def setScsiDiskTimeout(self): """ Iterate all SCSI disks(include hot-add) and set their timeout if their value are different from the OS.RootDeviceScsiTimeout """ try: scsiTimeout = Config.get("OS.RootDeviceScsiTimeout") for diskName in [disk for disk in os.listdir("/sys/block") if disk.startswith("sd")]: self.setBlockDeviceTimeout(diskName, scsiTimeout) except: pass 
    def setBlockDeviceTimeout(self, device, timeout):
        """
        Set SCSI disk timeout by set /sys/block/sd*/device/timeout
        """
        # timeout is compared/written as a string (sysfs file contents).
        if timeout != None and device:
            filePath = "/sys/block/" + device + "/device/timeout"
            # Only write when the current value differs, to avoid needless churn.
            if(GetFileContents(filePath).splitlines()[0].rstrip() != timeout):
                SetFileContents(filePath,timeout)
                Log("SetBlockDeviceTimeout: Update the device " + device + " with timeout " + timeout)

    def waitForSshHostKey(self, path):
        """
        Provide a dummy waiting, since by default, ssh host key is
        created by waagent and the key should already been created.
        """
        if(os.path.isfile(path)):
            return True
        else:
            Error("Can't find host key: {0}".format(path))
            return False

    def isDHCPEnabled(self):
        return self.dhcp_enabled

    def stopDHCP(self):
        """
        Stop the system DHCP client so that the agent can bind on its port. If
        the distro has set dhcp_enabled to True, it will need to provide an
        implementation of this method.
        """
        raise NotImplementedError('stopDHCP method missing')

    def startDHCP(self):
        """
        Start the system DHCP client. If the distro has set dhcp_enabled to
        True, it will need to provide an implementation of this method.
        """
        raise NotImplementedError('startDHCP method missing')

    def translateCustomData(self, data):
        """
        Translate the custom data from a Base64 encoding. Default to no-op.
        """
        decodeCustomData = Config.get("Provisioning.DecodeCustomData")
        if decodeCustomData != None and decodeCustomData.lower().startswith("y"):
            return base64.b64decode(data)
        return data

    def getConfigurationPath(self):
        return "/etc/waagent.conf"

    def getProcessorCores(self):
        # Counts "processor : N" stanzas in /proc/cpuinfo.
        return int(RunGetOutput("grep 'processor.*:' /proc/cpuinfo |wc -l")[1])

    def getTotalMemory(self):
        # /proc/meminfo reports kB; result is MB (py2 integer division).
        return int(RunGetOutput("grep MemTotal /proc/meminfo |awk '{print $2}'")[1])/1024

    def getInterfaceNameByMac(self, mac):
        # Scans `ifconfig -a` output for an ethN interface carrying the given
        # MAC; handles both "HWaddr" (net-tools) and "ether" output styles.
        ret, output = RunGetOutput("ifconfig -a")
        if ret != 0:
            raise Exception("Failed to get network interface info")
        output = output.replace('\n', '')
        match = re.search(r"(eth\d).*(HWaddr|ether) {0}".format(mac),
                          output, re.IGNORECASE)
        if match is None:
            raise Exception("Failed to get ifname with mac: {0}".format(mac))
        output = match.group(0)
        eths = re.findall(r"eth\d", output)
        if eths is None or len(eths) == 0:
            raise Exception("Failed to get ifname with mac: {0}".format(mac))
        # The match may span several interfaces; the last ethN before the MAC
        # is the owning interface.
        return eths[-1]

    def configIpV4(self, ifName, addr, netmask=24):
        ret, output = RunGetOutput("ifconfig {0} up".format(ifName))
        if ret != 0:
            raise Exception("Failed to bring up {0}: {1}".format(ifName, output))
        ret, output = RunGetOutput("ifconfig {0} {1}/{2}".format(ifName, addr, netmask))
        if ret != 0:
            raise Exception("Failed to config ipv4 for {0}: {1}".format(ifName, output))

############################################################
#	GentooDistro
############################################################
# OpenRC runscript installed as the agent's init script on Gentoo.
gentoo_init_file = """\
#!/sbin/runscript

command=/usr/sbin/waagent
pidfile=/var/run/waagent.pid
command_args=-daemon
command_background=true
name="Windows Azure Linux Agent"

depend()
{
	need localmount
	use logger network
	after bootmisc modules
}
"""

class gentooDistro(AbstractDistro):
    """
    Gentoo distro concrete class
    """
    def __init__(self):
        # super(gentooDistro,self).__init__()
        self.service_cmd='/sbin/service'
        self.ssh_service_name='sshd'
        self.hostname_file_path='/etc/conf.d/hostname'
self.dhcp_client_name='dhcpcd' self.shadow_file_mode=0640 self.init_file=gentoo_init_file def publishHostname(self,name): try: if (os.path.isfile(self.hostname_file_path)): r=ReplaceFileContentsAtomic(self.hostname_file_path, "hostname=\"" + name + "\"\n" + "\n".join(filter(lambda a: not a.startswith("hostname="), GetFileContents(self.hostname_file_path).split("\n")))) except: return 1 return r def installAgentServiceScriptFiles(self): SetFileContents(self.init_script_file, self.init_file) os.chmod(self.init_script_file, 0755) def registerAgentService(self): self.installAgentServiceScriptFiles() return Run('rc-update add ' + self.agent_service_name + ' default') def uninstallAgentService(self): return Run('rc-update del ' + self.agent_service_name + ' default') def unregisterAgentService(self): self.stopAgentService() return self.uninstallAgentService() def checkPackageInstalled(self,p): if Run('eix -I ^' + p + '$',chk_err=False): return 0 else: return 1 def checkPackageUpdateable(self,p): if Run('eix -u ^' + p + '$',chk_err=False): return 0 else: return 1 def RestartInterface(self, iface): Run("/etc/init.d/net." + iface + " restart") ############################################################ # SuSEDistro ############################################################ suse_init_file = """\ #! /bin/sh # # Windows Azure Linux Agent sysV init script # # Copyright 2013 Microsoft Corporation # Copyright SUSE LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#
# /etc/init.d/waagent
#
# and symbolic link
#
# /usr/sbin/rcwaagent
#
# System startup script for the waagent
#
### BEGIN INIT INFO
# Provides: WindowsAzureLinuxAgent
# Required-Start: $network sshd
# Required-Stop: $network sshd
# Default-Start: 3 5
# Default-Stop: 0 1 2 6
# Description: Start the WindowsAzureLinuxAgent
### END INIT INFO

PYTHON=/usr/bin/python
WAZD_BIN=/usr/sbin/waagent
WAZD_CONF=/etc/waagent.conf
WAZD_PIDFILE=/var/run/waagent.pid

test -x "$WAZD_BIN" || { echo "$WAZD_BIN not installed"; exit 5; }
test -e "$WAZD_CONF" || { echo "$WAZD_CONF not found"; exit 6; }

. /etc/rc.status

# First reset status of this service
rc_reset

# Return values acc. to LSB for all commands but status:
# 0 - success
# 1 - misc error
# 2 - invalid or excess args
# 3 - unimplemented feature (e.g. reload)
# 4 - insufficient privilege
# 5 - program not installed
# 6 - program not configured
#
# Note that starting an already running service, stopping
# or restarting a not-running service as well as the restart
# with force-reload (in case signalling is not supported) are
# considered a success.

case "$1" in
    start)
        echo -n "Starting WindowsAzureLinuxAgent"
        ## Start daemon with startproc(8). If this fails
        ## the echo return value is set appropriate.
        startproc -f ${PYTHON} ${WAZD_BIN} -daemon
        rc_status -v
        ;;
    stop)
        echo -n "Shutting down WindowsAzureLinuxAgent"
        ## Stop daemon with killproc(8) and if this fails
        ## set echo the echo return value.
        killproc -p ${WAZD_PIDFILE} ${PYTHON} ${WAZD_BIN}
        rc_status -v
        ;;
    try-restart)
        ## Stop the service and if this succeeds (i.e. the
        ## service was running before), start it again.
        $0 status >/dev/null && $0 restart
        rc_status
        ;;
    restart)
        ## Stop the service and regardless of whether it was
        ## running or not, start it again.
        $0 stop
        sleep 1
        $0 start
        rc_status
        ;;
    force-reload|reload)
        rc_status
        ;;
    status)
        echo -n "Checking for service WindowsAzureLinuxAgent "
        ## Check status with checkproc(8), if process is running
        ## checkproc will return with exit status 0.
        checkproc -p ${WAZD_PIDFILE} ${PYTHON} ${WAZD_BIN}
        rc_status -v
        ;;
    probe)
        ;;
    *)
        echo "Usage: $0 {start|stop|status|try-restart|restart|force-reload|reload}"
        exit 1
        ;;
esac
rc_exit
"""

class SuSEDistro(AbstractDistro):
    """
    SuSE Distro concrete class
    Put SuSE specific behavior here...
    """
    def __init__(self):
        super(SuSEDistro,self).__init__()
        self.service_cmd='/sbin/service'
        self.ssh_service_name='sshd'
        self.kernel_boot_options_file='/boot/grub/menu.lst'
        self.hostname_file_path='/etc/HOSTNAME'
        self.requiredDeps += [ "/sbin/insserv" ]
        self.init_file=suse_init_file
        self.dhcp_client_name='dhcpcd'
        # SLES >= 12 and openSUSE >= 13.2 switched from dhcpcd to wicked.
        # NOTE(review): string comparison of versions ('12', '13.2') — works
        # for these values but is not a true numeric compare.
        if ((DistInfo(fullname=1)[0] == 'SUSE Linux Enterprise Server' and DistInfo()[1] >= '12') or \
            (DistInfo(fullname=1)[0] == 'openSUSE' and DistInfo()[1] >= '13.2')):
            self.dhcp_client_name='wickedd-dhcp4'
        self.grubKernelBootOptionsFile = '/boot/grub/menu.lst'
        self.grubKernelBootOptionsLine = 'kernel'
        self.getpidcmd='pidof '
        self.dhcp_enabled=True

    def checkPackageInstalled(self,p):
        # rpm -q exits nonzero when not installed; inverted to 1 == installed.
        if Run("rpm -q " + p,chk_err=False):
            return 0
        else :
            return 1

    def checkPackageUpdateable(self,p):
        if Run("zypper list-updates | grep " + p,chk_err=False):
            return 1
        else :
            return 0

    def installAgentServiceScriptFiles(self):
        # Best-effort write of the sysV init script; failures are swallowed.
        try:
            SetFileContents(self.init_script_file, self.init_file)
            os.chmod(self.init_script_file, 0744)
        except:
            pass

    def registerAgentService(self):
        self.installAgentServiceScriptFiles()
        return Run('insserv ' + self.agent_service_name)

    def uninstallAgentService(self):
        return Run('insserv -r ' + self.agent_service_name)

    def unregisterAgentService(self):
        self.stopAgentService()
        return self.uninstallAgentService()

    def startDHCP(self):
        Run("service " + self.dhcp_client_name + " start", chk_err=False)

    def stopDHCP(self):
        Run("service " + self.dhcp_client_name + " 
stop", chk_err=False) ############################################################ # redhatDistro ############################################################ redhat_init_file= """\ #!/bin/bash # # Init file for WindowsAzureLinuxAgent. # # chkconfig: 2345 60 80 # description: WindowsAzureLinuxAgent # # source function library . /etc/rc.d/init.d/functions RETVAL=0 FriendlyName="WindowsAzureLinuxAgent" WAZD_BIN=/usr/sbin/waagent start() { echo -n $"Starting $FriendlyName: " $WAZD_BIN -daemon & } stop() { echo -n $"Stopping $FriendlyName: " killproc -p /var/run/waagent.pid $WAZD_BIN RETVAL=$? echo return $RETVAL } case "$1" in start) start ;; stop) stop ;; restart) stop start ;; reload) ;; report) ;; status) status $WAZD_BIN RETVAL=$? ;; *) echo $"Usage: $0 {start|stop|restart|status}" RETVAL=1 esac exit $RETVAL """ class redhatDistro(AbstractDistro): """ Redhat Distro concrete class Put Redhat specific behavior here... """ def __init__(self): super(redhatDistro,self).__init__() self.service_cmd='/sbin/service' self.ssh_service_restart_option='condrestart' self.ssh_service_name='sshd' self.hostname_file_path= None if DistInfo()[1] < '7.0' else '/etc/hostname' self.init_file=redhat_init_file self.grubKernelBootOptionsFile = '/boot/grub/menu.lst' self.grubKernelBootOptionsLine = 'kernel' def publishHostname(self,name): super(redhatDistro,self).publishHostname(name) if DistInfo()[1] < '7.0' : filepath = "/etc/sysconfig/network" if os.path.isfile(filepath): ReplaceFileContentsAtomic(filepath, "HOSTNAME=" + name + "\n" + "\n".join(filter(lambda a: not a.startswith("HOSTNAME"), GetFileContents(filepath).split('\n')))) ethernetInterface = MyDistro.GetInterfaceName() filepath = "/etc/sysconfig/network-scripts/ifcfg-" + ethernetInterface if os.path.isfile(filepath): ReplaceFileContentsAtomic(filepath, "DHCP_HOSTNAME=" + name + "\n" + "\n".join(filter(lambda a: not a.startswith("DHCP_HOSTNAME"), GetFileContents(filepath).split('\n')))) return 0 def 
installAgentServiceScriptFiles(self): SetFileContents(self.init_script_file, self.init_file) os.chmod(self.init_script_file, 0744) return 0 def registerAgentService(self): self.installAgentServiceScriptFiles() return Run('chkconfig --add waagent') def uninstallAgentService(self): return Run('chkconfig --del ' + self.agent_service_name) def unregisterAgentService(self): self.stopAgentService() return self.uninstallAgentService() def checkPackageInstalled(self,p): if Run("yum list installed " + p,chk_err=False): return 0 else: return 1 def checkPackageUpdateable(self,p): if Run("yum check-update | grep "+ p,chk_err=False): return 1 else: return 0 ############################################################ # centosDistro ############################################################ class centosDistro(redhatDistro): """ CentOS Distro concrete class Put CentOS specific behavior here... """ def __init__(self): super(centosDistro,self).__init__() ############################################################ # CoreOSDistro ############################################################ class CoreOSDistro(AbstractDistro): """ CoreOS Distro concrete class Put CoreOS specific behavior here... 
""" CORE_UID = 500 def __init__(self): super(CoreOSDistro,self).__init__() self.requiredDeps += [ "/usr/bin/systemctl" ] self.agent_service_name = 'waagent' self.init_script_file='/etc/systemd/system/waagent.service' self.fileBlackList.append("/etc/machine-id") self.dhcp_client_name='systemd-networkd' self.getpidcmd='pidof ' self.shadow_file_mode=0640 self.waagent_path='/usr/share/oem/bin' self.python_path='/usr/share/oem/python/bin' self.dhcp_enabled=True if 'PATH' in os.environ: os.environ['PATH'] = "{0}:{1}".format(os.environ['PATH'], self.python_path) else: os.environ['PATH'] = self.python_path if 'PYTHONPATH' in os.environ: os.environ['PYTHONPATH'] = "{0}:{1}".format(os.environ['PYTHONPATH'], self.waagent_path) else: os.environ['PYTHONPATH'] = self.waagent_path def checkPackageInstalled(self,p): """ There is no package manager in CoreOS. Return 1 since it must be preinstalled. """ return 1 def checkDependencies(self): for a in self.requiredDeps: if Run("which " + a + " > /dev/null 2>&1",chk_err=False): Error("Missing required dependency: " + a) return 1 return 0 def checkPackageUpdateable(self,p): """ There is no package manager in CoreOS. Return 0 since it can't be updated via package. """ return 0 def startAgentService(self): return Run('systemctl start ' + self.agent_service_name) def stopAgentService(self): return Run('systemctl stop ' + self.agent_service_name) def restartSshService(self): """ SSH is socket activated on CoreOS. No need to restart it. """ return 0 def sshDeployPublicKey(self,fprint,path): """ We support PKCS8. """ if Run("ssh-keygen -i -m PKCS8 -f " + fprint + " >> " + path): return 1 else : return 0 def RestartInterface(self, iface): Run("systemctl restart systemd-networkd") def CreateAccount(self, user, password, expiration, thumbprint): """ Create a user account, with 'user', 'password', 'expiration', ssh keys and sudo permissions. Returns None if successful, error string on failure. 
""" userentry = None try: userentry = pwd.getpwnam(user) except: pass uidmin = None try: uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry != None and userentry[2] < uidmin and userentry[2] != self.CORE_UID: Error("CreateAccount: " + user + " is a system user. Will not set password.") return "Failed to set password for system user: " + user + " (0x06)." if userentry == None: command = "useradd --create-home --password '*' " + user if expiration != None: command += " --expiredate " + expiration.split('.')[0] if Run(command): Error("Failed to create user account: " + user) return "Failed to create user account: " + user + " (0x07)." else: Log("CreateAccount: " + user + " already exists. Will update password.") if password != None: RunSendStdin("chpasswd", user + ":" + password + "\n") try: if password == None: SetFileContents("/etc/sudoers.d/waagent", user + " ALL = (ALL) NOPASSWD: ALL\n") else: SetFileContents("/etc/sudoers.d/waagent", user + " ALL = (ALL) ALL\n") os.chmod("/etc/sudoers.d/waagent", 0440) except: Error("CreateAccount: Failed to configure sudo access for user.") return "Failed to configure sudo privileges (0x08)." 
home = MyDistro.GetHome() if thumbprint != None: dir = home + "/" + user + "/.ssh" CreateDir(dir, user, 0700) pub = dir + "/id_rsa.pub" prv = dir + "/id_rsa" Run("ssh-keygen -y -f " + thumbprint + ".prv > " + pub) SetFileContents(prv, GetFileContents(thumbprint + ".prv")) for f in [pub, prv]: os.chmod(f, 0600) ChangeOwner(f, user) SetFileContents(dir + "/authorized_keys", GetFileContents(pub)) ChangeOwner(dir + "/authorized_keys", user) Log("Created user account: " + user) return None def startDHCP(self): Run("systemctl start " + self.dhcp_client_name, chk_err=False) def stopDHCP(self): Run("systemctl stop " + self.dhcp_client_name, chk_err=False) def translateCustomData(self, data): return base64.b64decode(data) def getConfigurationPath(self): return "/usr/share/oem/waagent.conf" ############################################################ # debianDistro ############################################################ debian_init_file = """\ #!/bin/sh ### BEGIN INIT INFO # Provides: WindowsAzureLinuxAgent # Required-Start: $network $syslog # Required-Stop: $network $syslog # Should-Start: $network $syslog # Should-Stop: $network $syslog # Default-Start: 2 3 4 5 # Default-Stop: 0 1 6 # Short-Description: WindowsAzureLinuxAgent # Description: WindowsAzureLinuxAgent ### END INIT INFO . /lib/lsb/init-functions OPTIONS="-daemon" WAZD_BIN=/usr/sbin/waagent WAZD_PID=/var/run/waagent.pid case "$1" in start) log_begin_msg "Starting WindowsAzureLinuxAgent..." pid=$( pidofproc $WAZD_BIN ) if [ -n "$pid" ] ; then log_begin_msg "Already running." log_end_msg 0 exit 0 fi start-stop-daemon --start --quiet --oknodo --background --exec $WAZD_BIN -- $OPTIONS log_end_msg $? ;; stop) log_begin_msg "Stopping WindowsAzureLinuxAgent..." start-stop-daemon --stop --quiet --oknodo --pidfile $WAZD_PID ret=$? rm -f $WAZD_PID log_end_msg $ret ;; force-reload) $0 restart ;; restart) $0 stop $0 start ;; status) status_of_proc $WAZD_BIN && exit 0 || exit $? 
;; *) log_success_msg "Usage: /etc/init.d/waagent {start|stop|force-reload|restart|status}" exit 1 ;; esac exit 0 """ class debianDistro(AbstractDistro): """ debian Distro concrete class Put debian specific behavior here... """ def __init__(self): super(debianDistro,self).__init__() self.requiredDeps += [ "/usr/sbin/update-rc.d" ] self.init_file=debian_init_file self.agent_package_name='walinuxagent' self.dhcp_client_name='dhclient' self.getpidcmd='pidof ' self.shadow_file_mode=0640 def checkPackageInstalled(self,p): """ Check that the package is installed. Return 1 if installed, 0 if not installed. This method of using dpkg-query allows wildcards to be present in the package name. """ if not Run("dpkg-query -W -f='${Status}\n' '" + p + "' | grep ' installed' 2>&1",chk_err=False): return 1 else: return 0 def checkDependencies(self): """ Debian dependency check. python-pyasn1 is NOT needed. Return 1 unless all dependencies are satisfied. NOTE: using network*manager will catch either package name in Ubuntu or debian. """ if self.checkPackageInstalled('network*manager'): Error(GuestAgentLongName + " is not compatible with network-manager.") return 1 for a in self.requiredDeps: if Run("which " + a + " > /dev/null 2>&1",chk_err=False): Error("Missing required dependency: " + a) return 1 return 0 def checkPackageUpdateable(self,p): if Run("apt-get update ; apt-get upgrade -us | grep " + p,chk_err=False): return 1 else: return 0 def installAgentServiceScriptFiles(self): """ If we are packaged - the service name is walinuxagent, do nothing. 
""" if self.agent_service_name == 'walinuxagent': return 0 try: SetFileContents(self.init_script_file, self.init_file) os.chmod(self.init_script_file, 0744) except OSError, e: ErrorWithPrefix('installAgentServiceScriptFiles','Exception: '+str(e)+' occured creating ' + self.init_script_file) return 1 return 0 def registerAgentService(self): if self.installAgentServiceScriptFiles() == 0: return Run('update-rc.d waagent defaults') else : return 1 def uninstallAgentService(self): return Run('update-rc.d -f ' + self.agent_service_name + ' remove') def unregisterAgentService(self): self.stopAgentService() return self.uninstallAgentService() def sshDeployPublicKey(self,fprint,path): """ We support PKCS8. """ if Run("ssh-keygen -i -m PKCS8 -f " + fprint + " >> " + path): return 1 else : return 0 ############################################################ # KaliDistro - WIP # Functioning on Kali 1.1.0a so far ############################################################ class KaliDistro(debianDistro): """ Kali Distro concrete class Put Kali specific behavior here... """ def __init__(self): super(KaliDistro,self).__init__() ############################################################ # UbuntuDistro ############################################################ ubuntu_upstart_file = """\ #walinuxagent - start Windows Azure agent description "walinuxagent" author "Ben Howard " start on (filesystem and started rsyslog) pre-start script WALINUXAGENT_ENABLED=1 [ -r /etc/default/walinuxagent ] && . /etc/default/walinuxagent if [ "$WALINUXAGENT_ENABLED" != "1" ]; then exit 1 fi if [ ! -x /usr/sbin/waagent ]; then exit 1 fi #Load the udf module modprobe -b udf end script exec /usr/sbin/waagent -daemon """ class UbuntuDistro(debianDistro): """ Ubuntu Distro concrete class Put Ubuntu specific behavior here... 
""" def __init__(self): super(UbuntuDistro,self).__init__() self.init_script_file='/etc/init/waagent.conf' self.init_file=ubuntu_upstart_file self.fileBlackList = [ "/root/.bash_history", "/var/log/waagent.log"] self.dhcp_client_name=None self.getpidcmd='pidof ' def registerAgentService(self): return self.installAgentServiceScriptFiles() def uninstallAgentService(self): """ If we are packaged - the service name is walinuxagent, do nothing. """ if self.agent_service_name == 'walinuxagent': return 0 os.remove('/etc/init/' + self.agent_service_name + '.conf') def unregisterAgentService(self): """ If we are packaged - the service name is walinuxagent, do nothing. """ if self.agent_service_name == 'walinuxagent': return self.stopAgentService() return self.uninstallAgentService() def deprovisionWarnUser(self): """ Ubuntu specific warning string from Deprovision. """ print("WARNING! Nameserver configuration in /etc/resolvconf/resolv.conf.d/{tail,original} will be deleted.") def deprovisionDeleteFiles(self): """ Ubuntu uses resolv.conf by default, so removing /etc/resolv.conf will break resolvconf. Therefore, we check to see if resolvconf is in use, and if so, we remove the resolvconf artifacts. """ if os.path.realpath('/etc/resolv.conf') != '/run/resolvconf/resolv.conf': Log("resolvconf is not configured. Removing /etc/resolv.conf") self.fileBlackList.append('/etc/resolv.conf') else: Log("resolvconf is enabled; leaving /etc/resolv.conf intact") resolvConfD = '/etc/resolvconf/resolv.conf.d/' self.fileBlackList.extend([resolvConfD + 'tail', resolvConfD + 'original']) for f in os.listdir(LibDir)+self.fileBlackList: try: os.remove(f) except: pass return 0 def getDhcpClientName(self): if self.dhcp_client_name != None : return self.dhcp_client_name if DistInfo()[1] == '12.04' : self.dhcp_client_name='dhclient3' else : self.dhcp_client_name='dhclient' return self.dhcp_client_name def waitForSshHostKey(self, path): """ Wait until the ssh host key is generated by cloud init. 
""" for retry in range(0, 10): if(os.path.isfile(path)): return True time.sleep(1) Error("Can't find host key: {0}".format(path)) return False ############################################################ # LinuxMintDistro ############################################################ class LinuxMintDistro(UbuntuDistro): """ LinuxMint Distro concrete class Put LinuxMint specific behavior here... """ def __init__(self): super(LinuxMintDistro,self).__init__() ############################################################ # fedoraDistro ############################################################ fedora_systemd_service = """\ [Unit] Description=Windows Azure Linux Agent After=network.target After=sshd.service ConditionFileIsExecutable=/usr/sbin/waagent ConditionPathExists=/etc/waagent.conf [Service] Type=simple ExecStart=/usr/sbin/waagent -daemon [Install] WantedBy=multi-user.target """ class fedoraDistro(redhatDistro): """ FedoraDistro concrete class Put Fedora specific behavior here... """ def __init__(self): super(fedoraDistro,self).__init__() self.service_cmd = '/usr/bin/systemctl' self.hostname_file_path = '/etc/hostname' self.init_script_file = '/usr/lib/systemd/system/' + self.agent_service_name + '.service' self.init_file = fedora_systemd_service self.grubKernelBootOptionsFile = '/etc/default/grub' self.grubKernelBootOptionsLine = 'GRUB_CMDLINE_LINUX=' def publishHostname(self, name): SetFileContents(self.hostname_file_path, name + '\n') ethernetInterface = MyDistro.GetInterfaceName() filepath = "/etc/sysconfig/network-scripts/ifcfg-" + ethernetInterface if os.path.isfile(filepath): ReplaceFileContentsAtomic(filepath, "DHCP_HOSTNAME=" + name + "\n" + "\n".join(filter(lambda a: not a.startswith("DHCP_HOSTNAME"), GetFileContents(filepath).split('\n')))) return 0 def installAgentServiceScriptFiles(self): SetFileContents(self.init_script_file, self.init_file) os.chmod(self.init_script_file, 0644) return Run(self.service_cmd + ' daemon-reload') def 
registerAgentService(self): self.installAgentServiceScriptFiles() return Run(self.service_cmd + ' enable ' + self.agent_service_name) def uninstallAgentService(self): """ Call service subsystem to remove waagent script. """ return Run(self.service_cmd + ' disable ' + self.agent_service_name) def unregisterAgentService(self): """ Calls self.stopAgentService and call self.uninstallAgentService() """ self.stopAgentService() self.uninstallAgentService() def startAgentService(self): """ Service call to start the Agent service """ return Run(self.service_cmd + ' start ' + self.agent_service_name) def stopAgentService(self): """ Service call to stop the Agent service """ return Run(self.service_cmd + ' stop ' + self.agent_service_name, False) def restartSshService(self): """ Service call to re(start) the SSH service """ sshRestartCmd = self.service_cmd + " " + self.ssh_service_restart_option + " " + self.ssh_service_name retcode = Run(sshRestartCmd) if retcode > 0: Error("Failed to restart SSH service with return code:" + str(retcode)) return retcode def checkPackageInstalled(self, p): """ Query package database for prescence of an installed package. """ import rpm ts = rpm.TransactionSet() rpms = ts.dbMatch(rpm.RPMTAG_PROVIDES, p) return bool(len(rpms) > 0) def deleteRootPassword(self): return Run("/sbin/usermod root -p '!!'") def packagedInstall(self,buildroot): """ Called from setup.py for use by RPM. Copies generated files waagent.conf, under the buildroot. 
""" if not os.path.exists(buildroot+'/etc'): os.mkdir(buildroot+'/etc') SetFileContents(buildroot+'/etc/waagent.conf', MyDistro.waagent_conf_file) if not os.path.exists(buildroot+'/etc/logrotate.d'): os.mkdir(buildroot+'/etc/logrotate.d') SetFileContents(buildroot+'/etc/logrotate.d/WALinuxAgent', WaagentLogrotate) self.init_script_file=buildroot+self.init_script_file # this allows us to call installAgentServiceScriptFiles() if not os.path.exists(os.path.dirname(self.init_script_file)): os.mkdir(os.path.dirname(self.init_script_file)) self.installAgentServiceScriptFiles() def CreateAccount(self, user, password, expiration, thumbprint): super(fedoraDistro, self).CreateAccount(user, password, expiration, thumbprint) Run('/sbin/usermod ' + user + ' -G wheel') def DeleteAccount(self, user): Run('/sbin/usermod ' + user + ' -G ""') super(fedoraDistro, self).DeleteAccount(user) ############################################################ # FreeBSD ############################################################ FreeBSDWaagentConf = """\ # # Windows Azure Linux Agent Configuration # Role.StateConsumer=None # Specified program is invoked with the argument "Ready" when we report ready status # to the endpoint server. Role.ConfigurationConsumer=None # Specified program is invoked with XML file argument specifying role configuration. Role.TopologyConsumer=None # Specified program is invoked with XML file argument specifying role topology. Provisioning.Enabled=y # Provisioning.DeleteRootPassword=y # Password authentication for root account will be unavailable. Provisioning.RegenerateSshHostKeyPair=y # Generate fresh host key pair. Provisioning.SshHostKeyPairType=rsa # Supported values are "rsa", "dsa" and "ecdsa". Provisioning.MonitorHostName=y # Monitor host name changes and publish changes via DHCP requests. ResourceDisk.Format=y # Format if unformatted. If 'n', resource disk will not be mounted. 
ResourceDisk.Filesystem=ufs2 #
ResourceDisk.MountPoint=/mnt/resource #
ResourceDisk.EnableSwap=n # Create and use swapfile on resource disk.
ResourceDisk.SwapSizeMB=0 # Size of the swapfile.

LBProbeResponder=y # Respond to load balancer probes if requested by Windows Azure.
Logs.Verbose=n # Enable verbose logs
OS.RootDeviceScsiTimeout=300 # Root device timeout in seconds.
OS.OpensslPath=None # If "None", the system default version is used.
"""

# rc.d script installed as the agent's init script on FreeBSD.
bsd_init_file="""\
#! /bin/sh

# PROVIDE: waagent
# REQUIRE: DAEMON cleanvar sshd
# BEFORE: LOGIN
# KEYWORD: nojail

. /etc/rc.subr

export PATH=$PATH:/usr/local/bin
name="waagent"
rcvar="waagent_enable"
command="/usr/sbin/${name}"
command_interpreter="/usr/local/bin/python"
waagent_flags=" daemon &"

pidfile="/var/run/waagent.pid"

load_rc_config $name
run_rc_command "$1"
"""

# Standalone python script (written out verbatim on FreeBSD) that mounts the
# resource disk and sets up swap via mdconfig; imports waagent from /tmp.
bsd_activate_resource_disk_txt="""\
#!/usr/bin/env python

import os
import sys
import imp

# waagent has no '.py' therefore create waagent module import manually.
__name__='setupmain' #prevent waagent.__main__ from executing
waagent=imp.load_source('waagent','/tmp/waagent')
waagent.LoggerInit('/var/log/waagent.log','/dev/console')
from waagent import RunGetOutput,Run

Config=waagent.ConfigurationProvider()
format = Config.get("ResourceDisk.Format")
if format == None or format.lower().startswith("n"):
    sys.exit(0)

device_base = 'da1'
device = "/dev/" + device_base

for entry in RunGetOutput("mount")[1].split():
    if entry.startswith(device + "s1"):
        waagent.Log("ActivateResourceDisk: " + device + "s1 is already mounted.")
        sys.exit(0)

mountpoint = Config.get("ResourceDisk.MountPoint")
if mountpoint == None:
    mountpoint = "/mnt/resource"
waagent.CreateDir(mountpoint, "root", 0755)

fs = Config.get("ResourceDisk.Filesystem")

if waagent.FreeBSDDistro().mediaHasFilesystem(device) == False :
    Run("newfs " + device + "s1")

if Run("mount " + device + "s1 " + mountpoint):
    waagent.Error("ActivateResourceDisk: Failed to mount resource disk (" + device + "s1).")
    sys.exit(0)

waagent.Log("Resource disk (" + device + "s1) is mounted at " + mountpoint + " with fstype " + fs)
waagent.SetFileContents(os.path.join(mountpoint,waagent.README_FILENAME), waagent.README_FILECONTENT)

swap = Config.get("ResourceDisk.EnableSwap")
if swap == None or swap.lower().startswith("n"):
    sys.exit(0)

sizeKB = int(Config.get("ResourceDisk.SwapSizeMB")) * 1024
if os.path.isfile(mountpoint + "/swapfile") and os.path.getsize(mountpoint + "/swapfile") != (sizeKB * 1024):
    os.remove(mountpoint + "/swapfile")
if not os.path.isfile(mountpoint + "/swapfile"):
    Run("dd if=/dev/zero of=" + mountpoint + "/swapfile bs=1024 count=" + str(sizeKB))
if Run("mdconfig -a -t vnode -f " + mountpoint + "/swapfile -u 0"):
    waagent.Error("ActivateResourceDisk: Configuring swap - Failed to create md0")
if not Run("swapon /dev/md0"):
    waagent.Log("Enabled " + str(sizeKB) + " KB of swap at " + mountpoint + "/swapfile")
else:
    waagent.Error("ActivateResourceDisk: Failed to activate swap at " + mountpoint + "/swapfile")
"""

class FreeBSDDistro(AbstractDistro):
    """
    """
    def __init__(self):
        """
        Generic Attributes go here.  These are based on 'majority rules'.
        This __init__() may be called or overriden by the child.
""" super(FreeBSDDistro,self).__init__() self.agent_service_name = os.path.basename(sys.argv[0]) self.selinux=False self.ssh_service_name='sshd' self.ssh_config_file='/etc/ssh/sshd_config' self.hostname_file_path='/etc/hostname' self.dhcp_client_name='dhclient' self.requiredDeps = [ 'route', 'shutdown', 'ssh-keygen', 'pw' , 'openssl', 'fdisk', 'sed', 'grep' , 'sudo'] self.init_script_file='/etc/rc.d/waagent' self.init_file=bsd_init_file self.agent_package_name='WALinuxAgent' self.fileBlackList = [ "/root/.bash_history", "/var/log/waagent.log",'/etc/resolv.conf' ] self.agent_files_to_uninstall = ["/etc/waagent.conf"] self.grubKernelBootOptionsFile = '/boot/loader.conf' self.grubKernelBootOptionsLine = '' self.getpidcmd = 'pgrep -n' self.mount_dvd_cmd = 'dd bs=2048 count=33 skip=295 if=' # custom data max len is 64k self.sudoers_dir_base = '/usr/local/etc' self.waagent_conf_file = FreeBSDWaagentConf def installAgentServiceScriptFiles(self): SetFileContents(self.init_script_file, self.init_file) os.chmod(self.init_script_file, 0777) AppendFileContents("/etc/rc.conf","waagent_enable='YES'\n") return 0 def registerAgentService(self): self.installAgentServiceScriptFiles() return Run("services_mkdb " + self.init_script_file) def sshDeployPublicKey(self,fprint,path): """ We support PKCS8. """ if Run("ssh-keygen -i -m PKCS8 -f " + fprint + " >> " + path): return 1 else : return 0 def deleteRootPassword(self): """ BSD root password removal. 
""" filepath="/etc/master.passwd" ReplaceStringInFile(filepath,r'root:.*?:','root::') #ReplaceFileContentsAtomic(filepath,"root:*LOCK*:14600::::::\n" # + "\n".join(filter(lambda a: not a.startswith("root:"),GetFileContents(filepath).split('\n')))) os.chmod(filepath,self.shadow_file_mode) if self.isSelinuxSystem(): self.setSelinuxContext(filepath,'system_u:object_r:shadow_t:s0') RunGetOutput("pwd_mkdb -u root /etc/master.passwd") Log("Root password deleted.") return 0 def changePass(self,user,password): return RunSendStdin("pw usermod " + user + " -h 0 ",password) def load_ata_piix(self): return 0 def unload_ata_piix(self): return 0 def checkDependencies(self): """ FreeBSD dependency check. Return 1 unless all dependencies are satisfied. """ for a in self.requiredDeps: if Run("which " + a + " > /dev/null 2>&1",chk_err=False): Error("Missing required dependency: " + a) return 1 return 0 def packagedInstall(self,buildroot): pass def GetInterfaceName(self): """ Return the ip of the active ethernet interface. """ iface,inet,mac=self.GetFreeBSDEthernetInfo() return iface def RestartInterface(self, iface): Run("service netif restart") def GetIpv4Address(self): """ Return the ip of the active ethernet interface. """ iface,inet,mac=self.GetFreeBSDEthernetInfo() return inet def GetMacAddress(self): """ Return the ip of the active ethernet interface. """ iface,inet,mac=self.GetFreeBSDEthernetInfo() l=mac.split(':') r=[] for i in l: r.append(string.atoi(i,16)) return r def GetFreeBSDEthernetInfo(self): """ There is no SIOCGIFCONF on freeBSD - just parse ifconfig. Returns strings: iface, inet4_addr, and mac or 'None,None,None' if unable to parse. We will sleep and retry as the network must be up. 
""" code,output=RunGetOutput("ifconfig",chk_err=False) Log(output) retries=10 cmd='ifconfig | grep -A2 -B2 ether | grep -B3 inet | grep -A4 UP ' code=1 while code > 0 : if code > 0 and retries == 0: Error("GetFreeBSDEthernetInfo - Failed to detect ethernet interface") return None, None, None code,output=RunGetOutput(cmd,chk_err=False) retries-=1 if code > 0 and retries > 0 : Log("GetFreeBSDEthernetInfo - Error: retry ethernet detection " + str(retries)) if retries == 9 : c,o=RunGetOutput("ifconfig | grep -A1 -B2 ether",chk_err=False) if c == 0: t=o.replace('\n',' ') t=t.split() i=t[0][:-1] Log(RunGetOutput('id')[1]) Run('dhclient '+i) time.sleep(10) j=output.replace('\n',' ') j=j.split() iface=j[0][:-1] for i in range(len(j)): if j[i] == 'inet' : inet=j[i+1] elif j[i] == 'ether' : mac=j[i+1] return iface, inet, mac def CreateAccount(self,user, password, expiration, thumbprint): """ Create a user account, with 'user', 'password', 'expiration', ssh keys and sudo permissions. Returns None if successful, error string on failure. """ userentry = None try: userentry = pwd.getpwnam(user) except: pass uidmin = None try: if os.path.isfile("/etc/login.defs"): uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry != None and userentry[2] < uidmin: Error("CreateAccount: " + user + " is a system user. Will not set password.") return "Failed to set password for system user: " + user + " (0x06)." if userentry == None: command = "pw useradd " + user + " -m" if expiration != None: command += " -e " + expiration.split('.')[0] if Run(command): Error("Failed to create user account: " + user) return "Failed to create user account: " + user + " (0x07)." else: Log("CreateAccount: " + user + " already exists. 
Will update password.") if password != None: self.changePass(user,password) try: # for older distros create sudoers.d if not os.path.isdir(MyDistro.sudoers_dir_base+'/sudoers.d/'): # create the /etc/sudoers.d/ directory os.mkdir(MyDistro.sudoers_dir_base+'/sudoers.d') # add the include of sudoers.d to the /etc/sudoers SetFileContents(MyDistro.sudoers_dir_base+'/sudoers',GetFileContents(MyDistro.sudoers_dir_base+'/sudoers')+'\n#includedir ' + MyDistro.sudoers_dir_base + '/sudoers.d\n') if password == None: SetFileContents(MyDistro.sudoers_dir_base+"/sudoers.d/waagent", user + " ALL = (ALL) NOPASSWD: ALL\n") else: SetFileContents(MyDistro.sudoers_dir_base+"/sudoers.d/waagent", user + " ALL = (ALL) ALL\n") os.chmod(MyDistro.sudoers_dir_base+"/sudoers.d/waagent", 0440) except: Error("CreateAccount: Failed to configure sudo access for user.") return "Failed to configure sudo privileges (0x08)." home = MyDistro.GetHome() if thumbprint != None: dir = home + "/" + user + "/.ssh" CreateDir(dir, user, 0700) pub = dir + "/id_rsa.pub" prv = dir + "/id_rsa" Run("ssh-keygen -y -f " + thumbprint + ".prv > " + pub) SetFileContents(prv, GetFileContents(thumbprint + ".prv")) for f in [pub, prv]: os.chmod(f, 0600) ChangeOwner(f, user) SetFileContents(dir + "/authorized_keys", GetFileContents(pub)) ChangeOwner(dir + "/authorized_keys", user) Log("Created user account: " + user) return None def DeleteAccount(self,user): """ Delete the 'user'. Clear utmp first, to avoid error. Removes the /etc/sudoers.d/waagent file. """ userentry = None try: userentry = pwd.getpwnam(user) except: pass if userentry == None: Error("DeleteAccount: " + user + " not found.") return uidmin = None try: if os.path.isfile("/etc/login.defs"): uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry[2] < uidmin: Error("DeleteAccount: " + user + " is a system user. 
Will not delete account.") return Run("> /var/run/utmp") #Delete utmp to prevent error if we are the 'user' deleted pid = subprocess.Popen(['rmuser', '-y', user], stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE).pid try: os.remove(MyDistro.sudoers_dir_base+"/sudoers.d/waagent") except: pass return def ActivateResourceDiskNoThread(self): """ Format, mount, and if specified in the configuration set resource disk as swap. """ global DiskActivated Run('cp /usr/sbin/waagent /tmp/') SetFileContents('/tmp/bsd_activate_resource_disk.py',bsd_activate_resource_disk_txt) Run('chmod +x /tmp/bsd_activate_resource_disk.py') pid = subprocess.Popen(["/tmp/bsd_activate_resource_disk.py", ""]).pid Log("Spawning bsd_activate_resource_disk.py") DiskActivated = True return def Install(self): """ Install the agent service. Check dependencies. Create /etc/waagent.conf and move old version to /etc/waagent.conf.old Copy RulesFiles to /var/lib/waagent Create /etc/logrotate.d/waagent Set /etc/ssh/sshd_config ClientAliveInterval to 180 Call ApplyVNUMAWorkaround() """ if MyDistro.checkDependencies(): return 1 os.chmod(sys.argv[0], 0755) SwitchCwd() for a in RulesFiles: if os.path.isfile(a): if os.path.isfile(GetLastPathElement(a)): os.remove(GetLastPathElement(a)) shutil.move(a, ".") Warn("Moved " + a + " -> " + LibDir + "/" + GetLastPathElement(a) ) MyDistro.registerAgentService() if os.path.isfile("/etc/waagent.conf"): try: os.remove("/etc/waagent.conf.old") except: pass try: os.rename("/etc/waagent.conf", "/etc/waagent.conf.old") Warn("Existing /etc/waagent.conf has been renamed to /etc/waagent.conf.old") except: pass SetFileContents("/etc/waagent.conf", self.waagent_conf_file) if os.path.exists('/usr/local/etc/logrotate.d/'): SetFileContents("/usr/local/etc/logrotate.d/waagent", WaagentLogrotate) filepath = "/etc/ssh/sshd_config" ReplaceFileContentsAtomic(filepath, "\n".join(filter(lambda a: not a.startswith("ClientAliveInterval"), 
GetFileContents(filepath).split('\n'))) + "\nClientAliveInterval 180\n") Log("Configured SSH client probing to keep connections alive.") #ApplyVNUMAWorkaround() return 0 def mediaHasFilesystem(self,dsk): if Run('LC_ALL=C fdisk -p ' + dsk + ' | grep "invalid fdisk partition table found" ',False): return False return True def mountDVD(self,dvd,location): #At this point we cannot read a joliet option udf DVD in freebsd10 - so we 'dd' it into our location retcode,out = RunGetOutput(self.mount_dvd_cmd + dvd + ' of=' + location + '/ovf-env.xml') if retcode != 0: return retcode,out ovfxml = (GetFileContents(location+"/ovf-env.xml",asbin=False)) if ord(ovfxml[0]) > 128 and ord(ovfxml[1]) > 128 and ord(ovfxml[2]) > 128 : ovfxml = ovfxml[3:] # BOM is not stripped. First three bytes are > 128 and not unicode chars so we ignore them. ovfxml = ovfxml.strip(chr(0x00)) ovfxml = "".join(filter(lambda x: ord(x)<128, ovfxml)) ovfxml = re.sub(r'.*\Z','',ovfxml,0,re.DOTALL) ovfxml += '' SetFileContents(location+"/ovf-env.xml", ovfxml) return retcode,out def GetHome(self): return '/home' def initScsiDiskTimeout(self): """ Set the SCSI disk timeout by updating the kernal config """ timeout = Config.get("OS.RootDeviceScsiTimeout") if timeout: Run("sysctl kern.cam.da.default_timeout=" + timeout) def setScsiDiskTimeout(self): return def setBlockDeviceTimeout(self, device, timeout): return def getProcessorCores(self): return int(RunGetOutput("sysctl hw.ncpu | awk '{print $2}'")[1]) def getTotalMemory(self): return int(RunGetOutput("sysctl hw.realmem | awk '{print $2}'")[1])/1024 ############################################################ # END DISTRO CLASS DEFS ############################################################ # This lets us index into a string or an array of integers transparently. def Ord(a): """ Allows indexing into a string or an array of integers transparently. Generic utility function. 
""" if type(a) == type("a"): a = ord(a) return a def IsLinux(): """ Returns True if platform is Linux. Generic utility function. """ return (platform.uname()[0] == "Linux") def GetLastPathElement(path): """ Similar to basename. Generic utility function. """ return path.rsplit('/', 1)[1] def GetFileContents(filepath,asbin=False): """ Read and return contents of 'filepath'. """ mode='r' if asbin: mode+='b' c=None try: with open(filepath, mode) as F : c=F.read() except IOError, e: ErrorWithPrefix('GetFileContents','Reading from file ' + filepath + ' Exception is ' + str(e)) return None return c def SetFileContents(filepath, contents): """ Write 'contents' to 'filepath'. """ if type(contents) == str : contents=contents.encode('latin-1', 'ignore') try: with open(filepath, "wb+") as F : F.write(contents) except IOError, e: ErrorWithPrefix('SetFileContents','Writing to file ' + filepath + ' Exception is ' + str(e)) return None return 0 def AppendFileContents(filepath, contents): """ Append 'contents' to 'filepath'. """ if type(contents) == str : contents=contents.encode('latin-1') try: with open(filepath, "a+") as F : F.write(contents) except IOError, e: ErrorWithPrefix('AppendFileContents','Appending to file ' + filepath + ' Exception is ' + str(e)) return None return 0 def ReplaceFileContentsAtomic(filepath, contents): """ Write 'contents' to 'filepath' by creating a temp file, and replacing original. 
""" handle, temp = tempfile.mkstemp(dir = os.path.dirname(filepath)) if type(contents) == str : contents=contents.encode('latin-1') try: os.write(handle, contents) except IOError, e: ErrorWithPrefix('ReplaceFileContentsAtomic','Writing to file ' + filepath + ' Exception is ' + str(e)) return None finally: os.close(handle) try: os.rename(temp, filepath) return None except IOError, e: ErrorWithPrefix('ReplaceFileContentsAtomic','Renaming ' + temp+ ' to ' + filepath + ' Exception is ' + str(e)) try: os.remove(filepath) except IOError, e: ErrorWithPrefix('ReplaceFileContentsAtomic','Removing '+ filepath + ' Exception is ' + str(e)) try: os.rename(temp,filepath) except IOError, e: ErrorWithPrefix('ReplaceFileContentsAtomic','Removing '+ filepath + ' Exception is ' + str(e)) return 1 return 0 def GetLineStartingWith(prefix, filepath): """ Return line from 'filepath' if the line startswith 'prefix' """ for line in GetFileContents(filepath).split('\n'): if line.startswith(prefix): return line return None def Run(cmd,chk_err=True): """ Calls RunGetOutput on 'cmd', returning only the return code. If chk_err=True then errors will be reported in the log. If chk_err=False then errors will be suppressed from the log. """ retcode,out=RunGetOutput(cmd,chk_err) return retcode def RunGetOutput(cmd,chk_err=True): """ Wrapper for subprocess.check_output. Execute 'cmd'. Returns return code and STDOUT, trapping expected exceptions. Reports exceptions to Error if chk_err parameter is True """ LogIfVerbose(cmd) try: output=subprocess.check_output(cmd,stderr=subprocess.STDOUT,shell=True) except subprocess.CalledProcessError,e : if chk_err : Error('CalledProcessError. Error Code is ' + str(e.returncode) ) Error('CalledProcessError. Command string was ' + e.cmd ) Error('CalledProcessError. 
Command result was ' + (e.output[:-1]).decode('latin-1')) return e.returncode,e.output.decode('latin-1') return 0,output.decode('latin-1') def RunSendStdin(cmd,input,chk_err=True,use_shell=True): """ Wrapper for subprocess.Popen. Execute 'cmd', sending 'input' to STDIN of 'cmd'. Returns return code and STDOUT, trapping expected exceptions. Reports exceptions to Error if chk_err parameter is True """ LogIfVerbose(cmd+input) try: me=subprocess.Popen([cmd], shell=use_shell, stdin=subprocess.PIPE,stderr=subprocess.STDOUT,stdout=subprocess.PIPE) output=me.communicate(input) except OSError , e : if chk_err : Error('CalledProcessError. Error Code is ' + str(me.returncode) ) Error('CalledProcessError. Command string was ' + cmd ) Error('CalledProcessError. Command result was ' + output[0].decode('latin-1')) return 1,output[0].decode('latin-1') if me.returncode is not 0 and chk_err is True: Error('CalledProcessError. Error Code is ' + str(me.returncode) ) Error('CalledProcessError. Command string was ' + cmd ) Error('CalledProcessError. Command result was ' + output[0].decode('latin-1')) return me.returncode,output[0].decode('latin-1') def GetNodeTextData(a): """ Filter non-text nodes from DOM tree """ for b in a.childNodes: if b.nodeType == b.TEXT_NODE: return b.data def GetHome(): """ Attempt to guess the $HOME location. Return the path string. """ home = None try: home = GetLineStartingWith("HOME", "/etc/default/useradd").split('=')[1].strip() except: pass if (home == None) or (home.startswith("/") == False): home = "/home" return home def ChangeOwner(filepath, user): """ Lookup user. Attempt chown 'filepath' to 'user'. """ p = None try: p = pwd.getpwnam(user) except: pass if p != None: os.chown(filepath, p[2], p[3]) def CreateDir(dirpath, user, mode): """ Attempt os.makedirs, catch all exceptions. Call ChangeOwner afterwards. 
""" try: os.makedirs(dirpath, mode) except: pass ChangeOwner(dirpath, user) def CreateAccount(user, password, expiration, thumbprint): """ Create a user account, with 'user', 'password', 'expiration', ssh keys and sudo permissions. Returns None if successful, error string on failure. """ userentry = None try: userentry = pwd.getpwnam(user) except: pass uidmin = None try: uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry != None and userentry[2] < uidmin: Error("CreateAccount: " + user + " is a system user. Will not set password.") return "Failed to set password for system user: " + user + " (0x06)." if userentry == None: command = "useradd -m " + user if expiration != None: command += " -e " + expiration.split('.')[0] if Run(command): Error("Failed to create user account: " + user) return "Failed to create user account: " + user + " (0x07)." else: Log("CreateAccount: " + user + " already exists. Will update password.") if password != None: RunSendStdin("chpasswd",(user + ":" + password + "\n")) try: # for older distros create sudoers.d if not os.path.isdir('/etc/sudoers.d/'): # create the /etc/sudoers.d/ directory os.mkdir('/etc/sudoers.d/') # add the include of sudoers.d to the /etc/sudoers SetFileContents('/etc/sudoers',GetFileContents('/etc/sudoers')+'\n#includedir /etc/sudoers.d\n') if password == None: SetFileContents("/etc/sudoers.d/waagent", user + " ALL = (ALL) NOPASSWD: ALL\n") else: SetFileContents("/etc/sudoers.d/waagent", user + " ALL = (ALL) ALL\n") os.chmod("/etc/sudoers.d/waagent", 0440) except: Error("CreateAccount: Failed to configure sudo access for user.") return "Failed to configure sudo privileges (0x08)." 
home = MyDistro.GetHome() if thumbprint != None: dir = home + "/" + user + "/.ssh" CreateDir(dir, user, 0700) pub = dir + "/id_rsa.pub" prv = dir + "/id_rsa" Run("ssh-keygen -y -f " + thumbprint + ".prv > " + pub) SetFileContents(prv, GetFileContents(thumbprint + ".prv")) for f in [pub, prv]: os.chmod(f, 0600) ChangeOwner(f, user) SetFileContents(dir + "/authorized_keys", GetFileContents(pub)) ChangeOwner(dir + "/authorized_keys", user) Log("Created user account: " + user) return None def DeleteAccount(user): """ Delete the 'user'. Clear utmp first, to avoid error. Removes the /etc/sudoers.d/waagent file. """ userentry = None try: userentry = pwd.getpwnam(user) except: pass if userentry == None: Error("DeleteAccount: " + user + " not found.") return uidmin = None try: uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry[2] < uidmin: Error("DeleteAccount: " + user + " is a system user. Will not delete account.") return Run("> /var/run/utmp") #Delete utmp to prevent error if we are the 'user' deleted Run("userdel -f -r " + user) try: os.remove("/etc/sudoers.d/waagent") except: pass return def IsInRangeInclusive(a, low, high): """ Return True if 'a' in 'low' <= a >= 'high' """ return (a >= low and a <= high) def IsPrintable(ch): """ Return True if character is displayable. """ return IsInRangeInclusive(ch, Ord('A'), Ord('Z')) or IsInRangeInclusive(ch, Ord('a'), Ord('z')) or IsInRangeInclusive(ch, Ord('0'), Ord('9')) def HexDump(buffer, size): """ Return Hex formated dump of a 'buffer' of 'size'. 
""" if size < 0: size = len(buffer) result = "" for i in range(0, size): if (i % 16) == 0: result += "%06X: " % i byte = buffer[i] if type(byte) == str: byte = ord(byte.decode('latin1')) result += "%02X " % byte if (i & 15) == 7: result += " " if ((i + 1) % 16) == 0 or (i + 1) == size: j = i while ((j + 1) % 16) != 0: result += " " if (j & 7) == 7: result += " " j += 1 result += " " for j in range(i - (i % 16), i + 1): byte=buffer[j] if type(byte) == str: byte = ord(byte.decode('latin1')) k = '.' if IsPrintable(byte): k = chr(byte) result += k if (i + 1) != size: result += "\n" return result def SimpleLog(file_path,message): if not file_path or len(message) < 1: return t = time.localtime() t = "%04u/%02u/%02u %02u:%02u:%02u " % (t.tm_year, t.tm_mon, t.tm_mday, t.tm_hour, t.tm_min, t.tm_sec) lines=re.sub(re.compile(r'^(.)',re.MULTILINE),t+r'\1',message) with open(file_path, "a") as F : lines = filter(lambda x : x in string.printable, lines) F.write(lines.encode('ascii','ignore') + "\n") class Logger(object): """ The Agent's logging assumptions are: For Log, and LogWithPrefix all messages are logged to the self.file_path and to the self.con_path. Setting either path parameter to None skips that log. If Verbose is enabled, messages calling the LogIfVerbose method will be logged to file_path yet not to con_path. Error and Warn messages are normal log messages with the 'ERROR:' or 'WARNING:' prefix added. """ def __init__(self,filepath,conpath,verbose=False): """ Construct an instance of Logger. """ self.file_path=filepath self.con_path=conpath self.verbose=verbose def ThrottleLog(self,counter): """ Log everything up to 10, every 10 up to 100, then every 100. """ return (counter < 10) or ((counter < 100) and ((counter % 10) == 0)) or ((counter % 100) == 0) def LogToFile(self,message): """ Write 'message' to logfile. 
""" if self.file_path: try: with open(self.file_path, "a") as F : message = filter(lambda x : x in string.printable, message) F.write(message.encode('ascii','ignore') + "\n") except IOError, e: print e pass def LogToCon(self,message): """ Write 'message' to /dev/console. This supports serial port logging if the /dev/console is redirected to ttys0 in kernel boot options. """ if self.con_path: try: with open(self.con_path, "w") as C : message = filter(lambda x : x in string.printable, message) C.write(message.encode('ascii','ignore') + "\n") except IOError, e: print e pass def Log(self,message): """ Standard Log function. Logs to self.file_path, and con_path """ self.LogWithPrefix("", message) def LogWithPrefix(self,prefix, message): """ Prefix each line of 'message' with current time+'prefix'. """ t = time.localtime() t = "%04u/%02u/%02u %02u:%02u:%02u " % (t.tm_year, t.tm_mon, t.tm_mday, t.tm_hour, t.tm_min, t.tm_sec) t += prefix for line in message.split('\n'): line = t + line self.LogToFile(line) self.LogToCon(line) def NoLog(self,message): """ Don't Log. """ pass def LogIfVerbose(self,message): """ Only log 'message' if global Verbose is True. """ self.LogWithPrefixIfVerbose('',message) def LogWithPrefixIfVerbose(self,prefix, message): """ Only log 'message' if global Verbose is True. Prefix each line of 'message' with current time+'prefix'. """ if self.verbose == True: t = time.localtime() t = "%04u/%02u/%02u %02u:%02u:%02u " % (t.tm_year, t.tm_mon, t.tm_mday, t.tm_hour, t.tm_min, t.tm_sec) t += prefix for line in message.split('\n'): line = t + line self.LogToFile(line) self.LogToCon(line) def Warn(self,message): """ Prepend the text "WARNING:" to the prefix for each line in 'message'. """ self.LogWithPrefix("WARNING:", message) def Error(self,message): """ Call ErrorWithPrefix(message). """ ErrorWithPrefix("", message) def ErrorWithPrefix(self,prefix, message): """ Prepend the text "ERROR:" to the prefix for each line in 'message'. 
Errors written to logfile, and /dev/console """ self.LogWithPrefix("ERROR:", message) def LoggerInit(log_file_path,log_con_path,verbose=False): """ Create log object and export its methods to global scope. """ global Log,LogWithPrefix,LogIfVerbose,LogWithPrefixIfVerbose,Error,ErrorWithPrefix,Warn,NoLog,ThrottleLog,myLogger l=Logger(log_file_path,log_con_path,verbose) Log,LogWithPrefix,LogIfVerbose,LogWithPrefixIfVerbose,Error,ErrorWithPrefix,Warn,NoLog,ThrottleLog,myLogger = l.Log,l.LogWithPrefix,l.LogIfVerbose,l.LogWithPrefixIfVerbose,l.Error,l.ErrorWithPrefix,l.Warn,l.NoLog,l.ThrottleLog,l def Linux_ioctl_GetInterfaceMac(ifname): """ Return the mac-address bound to the socket. """ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) info = fcntl.ioctl(s.fileno(), 0x8927, struct.pack('256s', (ifname[:15]+('\0'*241)).encode('latin-1'))) return ''.join(['%02X' % Ord(char) for char in info[18:24]]) def GetFirstActiveNetworkInterfaceNonLoopback(): """ Return the interface name, and ip addr of the first active non-loopback interface. """ iface='' expected=16 # how many devices should I expect... is_64bits = sys.maxsize > 2**32 struct_size=40 if is_64bits else 32 # for 64bit the size is 40 bytes, for 32bits it is 32 bytes. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) buff=array.array('B', b'\0' * (expected*struct_size)) retsize=(struct.unpack('iL', fcntl.ioctl(s.fileno(), 0x8912, struct.pack('iL',expected*struct_size,buff.buffer_info()[0]))))[0] if retsize == (expected*struct_size) : Warn('SIOCGIFCONF returned more than ' + str(expected) + ' up network interfaces.') s=buff.tostring() preferred_nic = Config.get("Network.Interface") for i in range(0,struct_size*expected,struct_size): iface=s[i:i+16].split(b'\0', 1)[0] if iface == b'lo': continue elif preferred_nic is None: break elif iface == preferred_nic: break return iface.decode('latin-1'), socket.inet_ntoa(s[i+20:i+24]) def GetIpv4Address(): """ Return the ip of the first active non-loopback interface. 
""" iface,addr=GetFirstActiveNetworkInterfaceNonLoopback() return addr def HexStringToByteArray(a): """ Return hex string packed into a binary struct. """ b = b"" for c in range(0, len(a) // 2): b += struct.pack("B", int(a[c * 2:c * 2 + 2], 16)) return b def GetMacAddress(): """ Convienience function, returns mac addr bound to first non-loobback interface. """ ifname='' while len(ifname) < 2 : ifname=GetFirstActiveNetworkInterfaceNonLoopback()[0] a = Linux_ioctl_GetInterfaceMac(ifname) return HexStringToByteArray(a) def DeviceForIdePort(n): """ Return device name attached to ide port 'n'. """ if n > 3: return None g0 = "00000000" if n > 1: g0 = "00000001" n = n - 2 device = None path = "/sys/bus/vmbus/devices/" for vmbus in os.listdir(path): guid = GetFileContents(path + vmbus + "/device_id").lstrip('{').split('-') if guid[0] == g0 and guid[1] == "000" + str(n): for root, dirs, files in os.walk(path + vmbus): if root.endswith("/block"): device = dirs[0] break else : #older distros for d in dirs: if ':' in d and "block" == d.split(':')[0]: device = d.split(':')[1] break break return device class HttpResourceGoneError(Exception): pass class Util(object): """ Http communication class. Base of GoalState, and Agent classes. 
""" RetryWaitingInterval=10 def __init__(self): self.Endpoint = None def _ParseUrl(self, url): secure = False host = self.Endpoint path = url port = None #"http[s]://hostname[:port][/]" if url.startswith("http://"): url = url[7:] if "/" in url: host = url[0: url.index("/")] path = url[url.index("/"):] else: host = url path = "/" elif url.startswith("https://"): secure = True url = url[8:] if "/" in url: host = url[0: url.index("/")] path = url[url.index("/"):] else: host = url path = "/" if host is None: raise ValueError("Host is invalid:{0}".format(url)) if(":" in host): pos = host.rfind(":") port = int(host[pos + 1:]) host = host[0:pos] return host, port, secure, path def GetHttpProxy(self, secure): """ Get http_proxy and https_proxy from environment variables. Username and password is not supported now. """ host = Config.get("HttpProxy.Host") port = Config.get("HttpProxy.Port") return (host, port) def _HttpRequest(self, method, host, path, port=None, data=None, secure=False, headers=None, proxyHost=None, proxyPort=None): resp = None conn = None try: if secure: port = 443 if port is None else port if proxyHost is not None and proxyPort is not None: conn = httplib.HTTPSConnection(proxyHost, proxyPort) conn.set_tunnel(host, port) #If proxy is used, full url is needed. path = "https://{0}:{1}{2}".format(host, port, path) else: conn = httplib.HTTPSConnection(host, port) else: port = 80 if port is None else port if proxyHost is not None and proxyPort is not None: conn = httplib.HTTPConnection(proxyHost, proxyPort) #If proxy is used, full url is needed. 
path = "http://{0}:{1}{2}".format(host, port, path) else: conn = httplib.HTTPConnection(host, port) if headers == None: conn.request(method, path, data) else: conn.request(method, path, data, headers) resp = conn.getresponse() except httplib.HTTPException, e: Error('HTTPException {0}, args:{1}'.format(e, repr(e.args))) except IOError, e: Error('Socket IOError {0}, args:{1}'.format(e, repr(e.args))) return resp def HttpRequest(self, method, url, data=None, headers=None, maxRetry=3, chkProxy=False): """ Sending http request to server On error, sleep 10 and maxRetry times. Return the output buffer or None. """ LogIfVerbose("HTTP Req: {0} {1}".format(method, url)) LogIfVerbose("HTTP Req: Data={0}".format(data)) LogIfVerbose("HTTP Req: Header={0}".format(headers)) try: host, port, secure, path = self._ParseUrl(url) except ValueError, e: Error("Failed to parse url:{0}".format(url)) return None #Check proxy proxyHost, proxyPort = (None, None) if chkProxy: proxyHost, proxyPort = self.GetHttpProxy(secure) #If httplib module is not built with ssl support. Fallback to http if secure and not hasattr(httplib, "HTTPSConnection"): Warn("httplib is not built with ssl support") secure = False proxyHost, proxyPort = self.GetHttpProxy(secure) #If httplib module doesn't support https tunnelling. 
Fallback to http if secure and \ proxyHost is not None and \ proxyPort is not None and \ not hasattr(httplib.HTTPSConnection, "set_tunnel"): Warn("httplib doesn't support https tunnelling(new in python 2.7)") secure = False proxyHost, proxyPort = self.GetHttpProxy(secure) resp = self._HttpRequest(method, host, path, port=port, data=data, secure=secure, headers=headers, proxyHost=proxyHost, proxyPort=proxyPort) for retry in range(0, maxRetry): if resp is not None and \ (resp.status == httplib.OK or \ resp.status == httplib.CREATED or \ resp.status == httplib.ACCEPTED): return resp; if resp is not None and resp.status == httplib.GONE: raise HttpResourceGoneError("Http resource gone.") Error("Retry={0}".format(retry)) Error("HTTP Req: {0} {1}".format(method, url)) Error("HTTP Req: Data={0}".format(data)) Error("HTTP Req: Header={0}".format(headers)) if resp is None: Error("HTTP Err: response is empty.".format(retry)) else: Error("HTTP Err: Status={0}".format(resp.status)) Error("HTTP Err: Reason={0}".format(resp.reason)) Error("HTTP Err: Header={0}".format(resp.getheaders())) Error("HTTP Err: Body={0}".format(resp.read())) time.sleep(self.__class__.RetryWaitingInterval) resp = self._HttpRequest(method, host, path, port=port, data=data, secure=secure, headers=headers, proxyHost=proxyHost, proxyPort=proxyPort) return None def HttpGet(self, url, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("GET", url, headers=headers, maxRetry=maxRetry, chkProxy=chkProxy) def HttpHead(self, url, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("HEAD", url, headers=headers, maxRetry=maxRetry, chkProxy=chkProxy) def HttpPost(self, url, data, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("POST", url, data=data, headers=headers, maxRetry=maxRetry, chkProxy=chkProxy) def HttpPut(self, url, data, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("PUT", url, data=data, headers=headers, maxRetry=maxRetry, 
chkProxy=chkProxy) def HttpDelete(self, url, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("DELETE", url, headers=headers, maxRetry=maxRetry, chkProxy=chkProxy) def HttpGetWithoutHeaders(self, url, maxRetry=3, chkProxy=False): """ Return data from an HTTP get on 'url'. """ resp = self.HttpGet(url, headers=None, maxRetry=maxRetry, chkProxy=chkProxy) return resp.read() if resp is not None else None def HttpGetWithHeaders(self, url, maxRetry=3, chkProxy=False): """ Return data from an HTTP get on 'url' with x-ms-agent-name and x-ms-version headers. """ resp = self.HttpGet(url, headers={ "x-ms-agent-name": GuestAgentName, "x-ms-version": ProtocolVersion }, maxRetry=maxRetry, chkProxy=chkProxy) return resp.read() if resp is not None else None def HttpSecureGetWithHeaders(self, url, transportCert, maxRetry=3, chkProxy=False): """ Return output of get using ssl cert. """ resp = self.HttpGet(url, headers={ "x-ms-agent-name": GuestAgentName, "x-ms-version": ProtocolVersion, "x-ms-cipher-name": "DES_EDE3_CBC", "x-ms-guest-agent-public-x509-cert": transportCert }, maxRetry=maxRetry, chkProxy=chkProxy) return resp.read() if resp is not None else None def HttpPostWithHeaders(self, url, data, maxRetry=3, chkProxy=False): headers = { "x-ms-agent-name": GuestAgentName, "Content-Type": "text/xml; charset=utf-8", "x-ms-version": ProtocolVersion } return self.HttpPost(url, data=data, headers=headers, maxRetry=maxRetry, chkProxy=chkProxy) __StorageVersion="2014-02-14" def GetBlobType(url): restutil = Util() #Check blob type LogIfVerbose("Check blob type.") timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) blobPropResp = restutil.HttpHead(url, { "x-ms-date" : timestamp, 'x-ms-version' : __StorageVersion }, chkProxy=True); blobType = None if blobPropResp is None: Error("Can't get status blob type.") return None blobType = blobPropResp.getheader("x-ms-blob-type") LogIfVerbose("Blob type={0}".format(blobType)) return blobType def PutBlockBlob(url, data): 
restutil = Util() LogIfVerbose("Upload block blob") timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) ret = restutil.HttpPut(url, data, { "x-ms-date" : timestamp, "x-ms-blob-type" : "BlockBlob", "Content-Length": str(len(data)), "x-ms-version" : __StorageVersion }, chkProxy=True) if ret is None: Error("Failed to upload block blob for status.") def PutPageBlob(url, data): restutil = Util() LogIfVerbose("Replace old page blob") timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) #Align to 512 bytes pageBlobSize = ((len(data) + 511) / 512) * 512 ret = restutil.HttpPut(url, "", { "x-ms-date" : timestamp, "x-ms-blob-type" : "PageBlob", "Content-Length": "0", "x-ms-blob-content-length" : str(pageBlobSize), "x-ms-version" : __StorageVersion }, chkProxy=True) if ret is None: Error("Failed to clean up page blob for status") return if url.index('?') < 0: url = "{0}?comp=page".format(url) else: url = "{0}&comp=page".format(url) LogIfVerbose("Upload page blob") pageMax = 4 * 1024 * 1024 #Max page size: 4MB start = 0 end = 0 while end < len(data): end = min(len(data), start + pageMax) contentSize = end - start #Align to 512 bytes pageEnd = ((end + 511) / 512) * 512 bufSize = pageEnd - start buf = bytearray(bufSize) buf[0 : contentSize] = data[start : end] ret = restutil.HttpPut(url, buffer(buf), { "x-ms-date" : timestamp, "x-ms-range" : "bytes={0}-{1}".format(start, pageEnd - 1), "x-ms-page-write" : "update", "x-ms-version" : __StorageVersion, "Content-Length": str(pageEnd - start) }, chkProxy=True) if ret is None: Error("Failed to upload page blob for status") return start = end def UploadStatusBlob(url, data): LogIfVerbose("Upload status blob") LogIfVerbose("Status={0}".format(data)) blobType = GetBlobType(url) if blobType == "BlockBlob": PutBlockBlob(url, data) elif blobType == "PageBlob": PutPageBlob(url, data) else: Error("Unknown blob type: {0}".format(blobType)) return None class TCPHandler(SocketServer.BaseRequestHandler): """ Callback object for 
LoadBalancerProbeServer. Recv and send LB probe messages. """ def __init__(self,lb_probe): super(TCPHandler,self).__init__() self.lb_probe=lb_probe def GetHttpDateTimeNow(self): """ Return formatted gmtime "Date: Fri, 25 Mar 2011 04:53:10 GMT" """ return time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()) def handle(self): """ Log LB probe messages, read the socket buffer, send LB probe response back to server. """ self.lb_probe.ProbeCounter = (self.lb_probe.ProbeCounter + 1) % 1000000 log = [NoLog, LogIfVerbose][ThrottleLog(self.lb_probe.ProbeCounter)] strCounter = str(self.lb_probe.ProbeCounter) if self.lb_probe.ProbeCounter == 1: Log("Receiving LB probes.") log("Received LB probe # " + strCounter) self.request.recv(1024) self.request.send("HTTP/1.1 200 OK\r\nContent-Length: 2\r\nContent-Type: text/html\r\nDate: " + self.GetHttpDateTimeNow() + "\r\n\r\nOK") class LoadBalancerProbeServer(object): """ Threaded object to receive and send LB probe messages. Load Balancer messages but be recv'd by the load balancing server, or this node may be shut-down. 
    """
    def __init__(self, port):
        # Serve LB probes on a daemon thread so shutdown is not blocked.
        self.ProbeCounter = 0
        self.server = SocketServer.TCPServer((self.get_ip(), port), TCPHandler)
        self.server_thread = threading.Thread(target = self.server.serve_forever)
        self.server_thread.setDaemon(True)
        self.server_thread.start()

    def shutdown(self):
        # Stop serve_forever(); does not join the daemon thread.
        self.server.shutdown()

    def get_ip(self):
        # Poll the distro for an IPv4 address, up to 5 attempts, 10s apart.
        # Implicitly returns None when all attempts fail.
        for retry in range(1,6):
            ip = MyDistro.GetIpv4Address()
            if ip == None :
                Log("LoadBalancerProbeServer: GetIpv4Address() returned None, sleeping 10 before retry " + str(retry+1) )
                time.sleep(10)
            else:
                return ip

class ConfigurationProvider(object):
    """
    Parse and store key:values in waagent.conf
    """
    def __init__(self, walaConfigFile):
        self.values = dict()
        if 'MyDistro' not in globals():
            global MyDistro
            MyDistro = GetMyDistro()
        if walaConfigFile is None:
            walaConfigFile = MyDistro.getConfigurationPath()
        if os.path.isfile(walaConfigFile) == False:
            raise Exception("Missing configuration in {0}".format(walaConfigFile))
        try:
            for line in GetFileContents(walaConfigFile).split('\n'):
                if not line.startswith("#") and "=" in line:
                    # NOTE(review): split()[0] truncates at the first
                    # whitespace, so quoted values containing spaces are
                    # cut short -- confirm no such settings exist.
                    parts = line.split()[0].split('=')
                    value = parts[1].strip("\" ")
                    # The literal string "None" maps to Python None.
                    if value != "None":
                        self.values[parts[0]] = value
                    else:
                        self.values[parts[0]] = None
        except:
            Error("Unable to parse {0}".format(walaConfigFile))
            raise
        return

    def get(self, key):
        # Returns None for unknown keys (dict.get default).
        return self.values.get(key)

class EnvMonitor(object):
    """
    Monitor changes to dhcp and hostname.
    If dhcp client process re-start has occurred, reset routes, dhcp with fabric.
    """
    def __init__(self):
        # Background monitor runs as a daemon thread until ShutdownService().
        self.shutdown = False
        self.HostName = socket.gethostname()
        self.server_thread = threading.Thread(target = self.monitor)
        self.server_thread.setDaemon(True)
        self.server_thread.start()
        self.published = False

    def monitor(self):
        """
        Monitor dhcp client pid and hostname.
        If dhcp client process re-start has occurred, reset routes, dhcp with fabric.
        """
        publish = Config.get("Provisioning.MonitorHostName")
        dhcpcmd = MyDistro.getpidcmd+ ' ' + MyDistro.getDhcpClientName()
        dhcppid = RunGetOutput(dhcpcmd)[1]
        while not self.shutdown:
            # Move any udev rules files found in RulesFiles into LibDir ("."
            # is the agent working directory), replacing stale copies first.
            for a in RulesFiles:
                if os.path.isfile(a):
                    if os.path.isfile(GetLastPathElement(a)):
                        os.remove(GetLastPathElement(a))
                    shutil.move(a, ".")
                    Log("EnvMonitor: Moved " + a + " -> " + LibDir)
            MyDistro.setScsiDiskTimeout()
            if publish != None and publish.lower().startswith("y"):
                try:
                    if socket.gethostname() != self.HostName:
                        Log("EnvMonitor: Detected host name change: " + self.HostName + " -> " + socket.gethostname())
                        self.HostName = socket.gethostname()
                        WaAgent.UpdateAndPublishHostName(self.HostName)
                        # Hostname publish restarts dhcp; refresh the tracked pid.
                        dhcppid = RunGetOutput(dhcpcmd)[1]
                        self.published = True
                except:
                    pass
            else:
                self.published = True
            # Detect dhcp client restart: the old pid directory disappearing
            # plus a new, different pid means routes must be restored.
            pid = ""
            if not os.path.isdir("/proc/" + dhcppid.strip()):
                pid = RunGetOutput(dhcpcmd)[1]
            if pid != "" and pid != dhcppid:
                Log("EnvMonitor: Detected dhcp client restart. Restoring routing table.")
                WaAgent.RestoreRoutes()
                dhcppid = pid
            # Reap finished child processes tracked in the global Children list.
            for child in Children:
                if child.poll() != None:
                    Children.remove(child)
            time.sleep(5)

    def SetHostName(self, name):
        """
        Generic call to MyDistro.setHostname(name).
        Logs and returns an error string on failure, None on success.
        """
        if socket.gethostname() == name:
            self.published = True
        elif MyDistro.setHostname(name):
            # Non-zero return from setHostname indicates failure.
            Error("Error: SetHostName: Cannot set hostname to " + name)
            return ("Error: SetHostName: Cannot set hostname to " + name)

    def IsHostnamePublished(self):
        """
        Return self.published
        """
        return self.published

    def ShutdownService(self):
        """
        Stop server communication and join the thread to main thread.
        """
        self.shutdown = True
        self.server_thread.join()

class Certificates(object):
    """
    Object containing certificates of host and provisioned user.
    Parses and splits certificates into files.
    """
    # Sample payload (elided):
    # <CertificateFile>
    #   <Version>2010-12-15</Version>
    #   <Incarnation>2</Incarnation>
    #   <Format>Pkcs7BlobWithPfxContents</Format>
    #   <Data>MIILTAY...</Data>
    # </CertificateFile>
    #
    def __init__(self):
        self.reinitialize()

    def reinitialize(self):
        """
        Reset the Role, Incarnation
        """
        self.Incarnation = None
        self.Role = None

    def Parse(self, xmlText):
        """
        Parse multiple certificates into separate files.

        Decrypts the Pkcs7 blob with the transport key, splits the
        resulting PEM bundle into <n>.crt / <n>.prv files, then renames
        each by its certificate thumbprint.  Returns self, or None when a
        required element is missing.
        """
        self.reinitialize()
        SetFileContents("Certificates.xml", xmlText)
        dom = xml.dom.minidom.parseString(xmlText)
        for a in [ "CertificateFile", "Version", "Incarnation",
                   "Format", "Data", ]:
            if not dom.getElementsByTagName(a):
                Error("Certificates.Parse: Missing " + a)
                return None
        node = dom.childNodes[0]
        if node.localName != "CertificateFile":
            Error("Certificates.Parse: root not CertificateFile")
            return None
        # Wrap the base64 Data element as a SMIME (p7m) attachment so
        # openssl cms can decrypt it with the transport private key.
        SetFileContents("Certificates.p7m",
                        "MIME-Version: 1.0\n"
                        + "Content-Disposition: attachment; filename=\"Certificates.p7m\"\n"
                        + "Content-Type: application/x-pkcs7-mime; name=\"Certificates.p7m\"\n"
                        + "Content-Transfer-Encoding: base64\n\n"
                        + GetNodeTextData(dom.getElementsByTagName("Data")[0]))
        if Run(Openssl + " cms -decrypt -in Certificates.p7m -inkey TransportPrivate.pem -recip TransportCert.pem | " + Openssl + " pkcs12 -nodes -password pass: -out Certificates.pem"):
            Error("Certificates.Parse: Failed to extract certificates from CMS message.")
            return self
        # There may be multiple certificates in this package. Split them.
        file = open("Certificates.pem")
        pindex = 1
        cindex = 1
        output = open("temp.pem", "w")
        for line in file.readlines():
            output.write(line)
            # Close out the current temp file at every END KEY/CERTIFICATE
            # marker, renaming it to a numbered .prv or .crt file.
            if re.match(r'[-]+END .*?(KEY|CERTIFICATE)[-]+$',line):
                output.close()
                if re.match(r'[-]+END .*?KEY[-]+$',line):
                    os.rename("temp.pem", str(pindex) + ".prv")
                    pindex += 1
                else:
                    os.rename("temp.pem", str(cindex) + ".crt")
                    cindex += 1
                output = open("temp.pem", "w")
        output.close()
        os.remove("temp.pem")
        # Rename each certificate to <THUMBPRINT>.crt, remembering its
        # public key so the matching private key can be renamed to
        # <THUMBPRINT>.prv below.
        keys = dict()
        index = 1
        filename = str(index) + ".crt"
        while os.path.isfile(filename):
            thumbprint = (RunGetOutput(Openssl + " x509 -in " + filename + " -fingerprint -noout")[1]).rstrip().split('=')[1].replace(':', '').upper()
            pubkey=RunGetOutput(Openssl + " x509 -in " + filename + " -pubkey -noout")[1]
            keys[pubkey] = thumbprint
            os.rename(filename, thumbprint + ".crt")
            os.chmod(thumbprint + ".crt", 0600)
            MyDistro.setSelinuxContext(thumbprint + '.crt','unconfined_u:object_r:ssh_home_t:s0')
            index += 1
            filename = str(index) + ".crt"
        index = 1
        filename = str(index) + ".prv"
        while os.path.isfile(filename):
            pubkey = RunGetOutput(Openssl + " rsa -in " + filename + " -pubout 2> /dev/null ")[1]
            os.rename(filename, keys[pubkey] + ".prv")
            os.chmod(keys[pubkey] + ".prv", 0600)
            MyDistro.setSelinuxContext( keys[pubkey] + '.prv','unconfined_u:object_r:ssh_home_t:s0')
            index += 1
            filename = str(index) + ".prv"
        return self

class SharedConfig(object):
    """
    Parse role endpoint server and goal state config.
    """
    # Sample SharedConfig XML (elided in source).
    def __init__(self):
        self.reinitialize()

    def reinitialize(self):
        """
        Reset members.
        """
        self.RdmaMacAddress = None
        self.RdmaIPv4Address = None
        self.xmlText = None

    def Parse(self, xmlText):
        """
        Parse and write configuration to file SharedConfig.xml.
        """
        LogIfVerbose(xmlText)
        self.reinitialize()
        self.xmlText = xmlText
        dom = xml.dom.minidom.parseString(xmlText)
        for a in [ "SharedConfig", "Deployment", "Service",
                   "ServiceInstance", "Incarnation", "Role", ]:
            if not dom.getElementsByTagName(a):
                Error("SharedConfig.Parse: Missing " + a)
        node = dom.childNodes[0]
        if node.localName != "SharedConfig":
            Error("SharedConfig.Parse: root not SharedConfig")
        nodes = dom.getElementsByTagName("Instance")
        if nodes is not None and len(nodes) != 0:
            node = nodes[0]
            if node.hasAttribute("rdmaMacAddress"):
                # Re-insert ':' separators into the raw 12-hex-digit MAC.
                addr = node.getAttribute("rdmaMacAddress")
                self.RdmaMacAddress = addr[0:2]
                for i in range(1, 6):
                    self.RdmaMacAddress += ":" + addr[2 * i : 2 *i + 2]
            if node.hasAttribute("rdmaIPv4Address"):
                self.RdmaIPv4Address = node.getAttribute("rdmaIPv4Address")
        return self

    def Save(self):
        LogIfVerbose("Save SharedConfig.xml")
        SetFileContents("SharedConfig.xml", self.xmlText)

    def InvokeTopologyConsumer(self):
        # Launch the configured topology consumer program (if any) with the
        # path of SharedConfig.xml; track the child in the global Children list.
        program = Config.get("Role.TopologyConsumer")
        if program != None:
            try:
                Children.append(subprocess.Popen([program, LibDir + "/SharedConfig.xml"]))
            except OSError, e :
                ErrorWithPrefix('Agent.Run','Exception: '+ str(e) +' occured launching ' + program )

    def Process(self):
        # Configure rdma exactly once per agent lifetime (module-level flag).
        global rdma_configured
        if not rdma_configured and self.RdmaMacAddress is not None and self.RdmaIPv4Address is not None:
            handler = RdmaHandler(self.RdmaMacAddress, self.RdmaIPv4Address)
            handler.start()
            rdma_configured = True
        self.InvokeTopologyConsumer()

rdma_configured = False

class RdmaError(Exception):
    pass

class RdmaHandler(object):
    """
    Handle rdma configuration.
    """
    def __init__(self, mac, ip_addr, dev="/dev/hvnd_rdma",
                 dat_conf_files=['/etc/dat.conf',
                                 '/etc/rdma/dat.conf',
                                 '/usr/local/etc/dat.conf']):
        self.mac = mac
        self.ip_addr = ip_addr
        self.dev = dev
        self.dat_conf_files = dat_conf_files
        # Config string written verbatim to the rdma device node.
        self.data = ('rdmaMacAddress="{0}" rdmaIPv4Address="{1}"'
                     '').format(self.mac, self.ip_addr)

    def start(self):
        """
        Start a new thread to process rdma
        """
        threading.Thread(target=self.process).start()

    def process(self):
        try:
            self.set_dat_conf()
            self.set_rdma_dev()
            self.set_rdma_ip()
        except RdmaError as e:
            Error("Failed to config rdma device: {0}".format(e))

    def set_dat_conf(self):
        """
        Agent needs to search all possible locations for dat.conf
        """
        Log("Set dat.conf")
        for dat_conf_file in self.dat_conf_files:
            if not os.path.isfile(dat_conf_file):
                continue
            try:
                self.write_dat_conf(dat_conf_file)
            except IOError as e:
                raise RdmaError("Failed to write to dat.conf: {0}".format(e))

    def write_dat_conf(self, dat_conf_file):
        # Substitute the placeholder IP in the ofa-v2-ib0 provider line
        # with this instance's rdma IPv4 address.
        Log("Write config to {0}".format(dat_conf_file))
        old = ("ofa-v2-ib0 u2.0 nonthreadsafe default libdaplofa.so.2 "
               "dapl.2.0 \"\S+ 0\"")
        new = ("ofa-v2-ib0 u2.0 nonthreadsafe default libdaplofa.so.2 "
               "dapl.2.0 \"{0} 0\"").format(self.ip_addr)
        lines = GetFileContents(dat_conf_file)
        lines = re.sub(old, new, lines)
        SetFileContents(dat_conf_file, lines)

    def set_rdma_dev(self):
        """
        Write config string to /dev/hvnd_rdma
        """
        Log("Set /dev/hvnd_rdma")
        self.wait_rdma_dev()
        self.write_rdma_dev_conf()

    def write_rdma_dev_conf(self):
        Log("Write rdma config to {0}: {1}".format(self.dev, self.data))
        try:
            with open(self.dev, "w") as c:
                c.write(self.data)
        except IOError, e:
            raise RdmaError("Error writing {0}, {1}".format(self.dev, e))

    def wait_rdma_dev(self):
        # The device node appears asynchronously; poll for up to 2 minutes.
        Log("Wait for /dev/hvnd_rdma")
        retry = 0
        while retry < 120:
            if os.path.exists(self.dev):
                return
            time.sleep(1)
            retry += 1
        raise RdmaError("The device doesn't show up in 120 seconds")

    def set_rdma_ip(self):
        Log("Set ip addr for rdma")
        try:
            if_name = MyDistro.getInterfaceNameByMac(self.mac)
            #Azure is using 12 bits network mask for infiniband.
            MyDistro.configIpV4(if_name, self.ip_addr, 12)
        except Exception as e:
            raise RdmaError("Failed to config rdma device: {0}".format(e))

class ExtensionsConfig(object):
    """
    Parse ExtensionsConfig, downloading and unpacking them to /var/lib/waagent.
    Install if true, remove if it is set to false.
    """
    # Sample runtimeSettings comment (elided): a JSON document carrying
    # protectedSettingsCertThumbprint, base64 protectedSettings and
    # publicSettings, plus a SAS-signed StatusUploadBlob URL.
    def __init__(self):
        self.reinitialize()

    def reinitialize(self):
        """
        Reset members.
        """
        self.Extensions = None
        self.Plugins = None
        self.Util = None

    def Parse(self, xmlText):
        """
        Write configuration to file ExtensionsConfig.xml.
        Log plugin specific activity to
        /var/log/azure/<name>/<version>/CommandExecution.log.
        If state is enabled:
            if the plugin is installed:
                if the new plugin's version is higher
                    if DisallowMajorVersionUpgrade is false, or if true
                    the version is a minor version, do upgrade:
                        download the new archive
                        do the updateCommand.
                        disable the old plugin and remove
                        enable the new plugin
                if the new plugin's version is the same or lower:
                    create the new .settings file from the configuration received
                    do the enableCommand
            if the plugin is not installed:
                download/unpack archive and call the installCommand/Enable
        If state is disabled:
            call disableCommand
        If state is uninstall:
            call uninstallCommand
            remove old plugin directory.
        """
        self.reinitialize()
        self.Util=Util()
        dom = xml.dom.minidom.parseString(xmlText)
        LogIfVerbose(xmlText)
        self.plugin_log_dir='/var/log/azure'
        if not os.path.exists(self.plugin_log_dir):
            os.mkdir(self.plugin_log_dir)
        try:
            self.Extensions=dom.getElementsByTagName("Extensions")
            pg = dom.getElementsByTagName("Plugins")
            if len(pg) > 0:
                self.Plugins = pg[0].getElementsByTagName("Plugin")
            else:
                self.Plugins = []
            incarnation=self.Extensions[0].getAttribute("goalStateIncarnation")
            SetFileContents('ExtensionsConfig.'+incarnation+'.xml', xmlText)
        except Exception, e:
            Error('ERROR: Error parsing ExtensionsConfig: {0}.'.format(e))
            return None
        for p in self.Plugins:
            if len(p.getAttribute("location"))<1:
                # this plugin is inside the PluginSettings
                continue
            p.setAttribute('restricted','false')
            previous_version = None
            version=p.getAttribute("version")
            name=p.getAttribute("name")
            # Per-plugin command execution log under /var/log/azure/<name>/<version>.
            plog_dir=self.plugin_log_dir+'/'+name +'/'+ version
            if not os.path.exists(plog_dir):
                os.makedirs(plog_dir)
            p.plugin_log=plog_dir+'/CommandExecution.log'
            handler=name + '-' + version
            if p.getAttribute("isJson") != 'true':
                Error("Plugin " + name+" version: " +version+" is not a JSON Extension. Skipping.")
                continue
            Log("Found Plugin: " + name + ' version: ' + version)
            if p.getAttribute("state") == 'disabled' or p.getAttribute("state") == 'uninstall':
                #disable
                zip_dir=LibDir+"/" + name + '-' + version
                mfile=None
                for root, dirs, files in os.walk(zip_dir):
                    for f in files:
                        if f in ('HandlerManifest.json'):
                            mfile=os.path.join(root,f)
                    if mfile != None:
                        break
                if mfile == None :
                    Error('HandlerManifest.json not found.')
                    continue
                manifest = GetFileContents(mfile)
                p.setAttribute('manifestdata',manifest)
                if self.launchCommand(p.plugin_log,name,version,'disableCommand') == None :
                    self.SetHandlerState(handler, 'Enabled')
                    Error('Unable to disable '+name)
                    SimpleLog(p.plugin_log,'ERROR: Unable to disable '+name)
                else :
                    self.SetHandlerState(handler, 'Disabled')
                    Log(name+' is disabled')
                    SimpleLog(p.plugin_log,name+' is disabled')
                # uninstall if needed
                if p.getAttribute("state") == 'uninstall':
                    if self.launchCommand(p.plugin_log,name,version,'uninstallCommand') == None :
                        self.SetHandlerState(handler, 'Installed')
                        Error('Unable to uninstall '+name)
                        SimpleLog(p.plugin_log,'Unable to uninstall '+name)
                    else :
                        self.SetHandlerState(handler, 'NotInstalled')
                        Log(name+' uninstallCommand completed .')
                    # remove the plugin
                    Run('rm -rf ' + LibDir + '/' + name +'-'+ version + '*')
                    Log(name +'-'+ version + ' extension files deleted.')
                    SimpleLog(p.plugin_log,name +'-'+ version + ' extension files deleted.')
                continue
            # state is enabled
            # if the same plugin exists and the version is newer or
            # does not exist then download and unzip the new plugin
            plg_dir=None
            for root, dirs, files in os.walk(LibDir):
                for d in dirs:
                    if name in d:
                        plg_dir=os.path.join(root,d)
                if plg_dir != None:
                    break
            if plg_dir != None :
                previous_version=plg_dir.rsplit('-')[-1]
            # NOTE(review): this is a lexicographic string comparison of
            # version numbers (e.g. '10.0' < '9.0') -- confirm acceptable.
            if plg_dir == None or version > previous_version :
                location=p.getAttribute("location")
                Log("Downloading plugin manifest: " + name + " from " + location)
                SimpleLog(p.plugin_log,"Downloading plugin manifest: " + name + " from " + location)
                self.Util.Endpoint=location.split('/')[2]
                Log("Plugin server is: " + self.Util.Endpoint)
                SimpleLog(p.plugin_log,"Plugin server is: " + self.Util.Endpoint)
                manifest=self.Util.HttpGetWithoutHeaders(location, chkProxy=True)
                if manifest == None:
                    Error("Unable to download plugin manifest" + name + " from primary location. Attempting with failover location.")
                    SimpleLog(p.plugin_log,"Unable to download plugin manifest" + name + " from primary location. Attempting with failover location.")
                    failoverlocation=p.getAttribute("failoverlocation")
                    self.Util.Endpoint=failoverlocation.split('/')[2]
                    Log("Plugin failover server is: " + self.Util.Endpoint)
                    SimpleLog(p.plugin_log,"Plugin failover server is: " + self.Util.Endpoint)
                    manifest=self.Util.HttpGetWithoutHeaders(failoverlocation, chkProxy=True)
                #if failoverlocation also fail what to do then?
                if manifest == None:
                    AddExtensionEvent(name,WALAEventOperation.Download,False,0,version,"Download mainfest fail "+failoverlocation)
                    Log("Plugin manifest " + name + " downloading failed from failover location.")
                    SimpleLog(p.plugin_log,"Plugin manifest " + name + " downloading failed from failover location.")
                filepath=LibDir+"/" + name + '.' + incarnation + '.manifest'
                if os.path.splitext(location)[-1] == '.xml' :
                    #if this is an xml file we may have a BOM
                    if ord(manifest[0]) > 128 and ord(manifest[1]) > 128 and ord(manifest[2]) > 128:
                        manifest=manifest[3:]
                SetFileContents(filepath,manifest)
                #Get the bundle url from the manifest
                p.setAttribute('manifestdata',manifest)
                man_dom = xml.dom.minidom.parseString(manifest)
                bundle_uri = ""
                for mp in man_dom.getElementsByTagName("Plugin"):
                    if GetNodeTextData(mp.getElementsByTagName("Version")[0]) == version:
                        bundle_uri = GetNodeTextData(mp.getElementsByTagName("Uri")[0])
                        break
                # NOTE(review): 'mp' here is the last Plugin element visited
                # by the loop above -- indentation reconstructed; verify
                # against the canonical source.
                if len(mp.getElementsByTagName("DisallowMajorVersionUpgrade")):
                    if GetNodeTextData(mp.getElementsByTagName("DisallowMajorVersionUpgrade")[0]) == 'true' and previous_version !=None and previous_version.split('.')[0] != version.split('.')[0] :
                        Log('DisallowMajorVersionUpgrade is true, this major version is restricted from upgrade.')
                        SimpleLog(p.plugin_log,'DisallowMajorVersionUpgrade is true, this major version is restricted from upgrade.')
                        p.setAttribute('restricted','true')
                        continue
                if len(bundle_uri) < 1 :
                    Error("Unable to fetch Bundle URI from manifest for " + name + " v " + version)
                    SimpleLog(p.plugin_log,"Unable to fetch Bundle URI from manifest for " + name + " v " + version)
                    continue
                Log("Bundle URI = " + bundle_uri)
                SimpleLog(p.plugin_log,"Bundle URI = " + bundle_uri)
                # Download the zipfile archive and save as '.zip'
                bundle=self.Util.HttpGetWithoutHeaders(bundle_uri, chkProxy=True)
                if bundle == None:
                    AddExtensionEvent(name,WALAEventOperation.Download,True,0,version,"Download zip fail "+bundle_uri)
                    Error("Unable to download plugin bundle" + bundle_uri )
                    SimpleLog(p.plugin_log,"Unable to download plugin bundle" + bundle_uri )
                    continue
                AddExtensionEvent(name,WALAEventOperation.Download,True,0,version,"Download Success")
                b=bytearray(bundle)
                filepath=LibDir+"/" + os.path.basename(bundle_uri) + '.zip'
                SetFileContents(filepath,b)
                Log("Plugin bundle" + bundle_uri + "downloaded successfully length = " + str(len(bundle)))
                SimpleLog(p.plugin_log,"Plugin bundle" + bundle_uri + "downloaded successfully length = " + str(len(bundle)))
                # unpack the archive
                z=zipfile.ZipFile(filepath)
                zip_dir=LibDir+"/" + name + '-' + version
                z.extractall(zip_dir)
                Log('Extracted ' + bundle_uri + ' to ' + zip_dir)
                SimpleLog(p.plugin_log,'Extracted ' + bundle_uri + ' to ' + zip_dir)
                # zip no file perms in .zip so set all the scripts to +x
                Run( "find " + zip_dir +" -type f | xargs chmod u+x ")
                #write out the base64 config data so the plugin can process it.
                mfile=None
                for root, dirs, files in os.walk(zip_dir):
                    for f in files:
                        if f in ('HandlerManifest.json'):
                            mfile=os.path.join(root,f)
                    if mfile != None:
                        break
                if mfile == None :
                    Error('HandlerManifest.json not found.')
                    SimpleLog(p.plugin_log,'HandlerManifest.json not found.')
                    continue
                manifest = GetFileContents(mfile)
                p.setAttribute('manifestdata',manifest)
                # create the status and config dirs
                Run('mkdir -p ' + root + '/status')
                Run('mkdir -p ' + root + '/config')
                # write out the configuration data to goalStateIncarnation.settings file in the config path.
                config=''
                seqNo='0'
                if len(dom.getElementsByTagName("PluginSettings")) != 0 :
                    pslist=dom.getElementsByTagName("PluginSettings")[0].getElementsByTagName("Plugin")
                    for ps in pslist:
                        if name == ps.getAttribute("name") and version == ps.getAttribute("version"):
                            Log("Found RuntimeSettings for " + name + " V " + version)
                            SimpleLog(p.plugin_log,"Found RuntimeSettings for " + name + " V " + version)
                            config=GetNodeTextData(ps.getElementsByTagName("RuntimeSettings")[0])
                            seqNo=ps.getElementsByTagName("RuntimeSettings")[0].getAttribute("seqNo")
                            break
                if config == '':
                    Log("No RuntimeSettings for " + name + " V " + version)
                    SimpleLog(p.plugin_log,"No RuntimeSettings for " + name + " V " + version)
                SetFileContents(root +"/config/" + seqNo +".settings", config )
                #create HandlerEnvironment.json
                handler_env='[{ "name": "'+name+'", "seqNo": "'+seqNo+'", "version": 1.0, "handlerEnvironment": { "logFolder": "'+os.path.dirname(p.plugin_log)+'", "configFolder": "' + root + '/config", "statusFolder": "' + root + '/status", "heartbeatFile": "'+ root + '/heartbeat.log"}}]'
                SetFileContents(root+'/HandlerEnvironment.json',handler_env)
                self.SetHandlerState(handler, 'NotInstalled')
                cmd = ''
                getcmd='installCommand'
                if plg_dir != None and previous_version != None and version > previous_version :
                    previous_handler=name+'-'+previous_version
                    if self.GetHandlerState(previous_handler) != 'NotInstalled':
                        getcmd='updateCommand'
                        # disable the old plugin if it exists
                        if self.launchCommand(p.plugin_log,name,previous_version,'disableCommand') == None :
                            self.SetHandlerState(previous_handler, 'Enabled')
                            Error('Unable to disable old plugin '+name+' version ' + previous_version)
                            SimpleLog(p.plugin_log,'Unable to disable old plugin '+name+' version ' + previous_version)
                        else :
                            self.SetHandlerState(previous_handler, 'Disabled')
                            Log(name+' version ' + previous_version + ' is disabled')
                            SimpleLog(p.plugin_log,name+' version ' + previous_version + ' is disabled')
                isupgradeSuccess = True
                if getcmd=='updateCommand':
                    if self.launchCommand(p.plugin_log,name,version,getcmd,previous_version) == None :
                        Error('Update failed for '+name+'-'+version)
                        SimpleLog(p.plugin_log,'Update failed for '+name+'-'+version)
                        isupgradeSuccess=False
                    else :
                        Log('Update complete'+name+'-'+version)
                        SimpleLog(p.plugin_log,'Update complete'+name+'-'+version)
                    # if we updated - call unistall for the old plugin
                    if self.launchCommand(p.plugin_log,name,previous_version,'uninstallCommand') == None :
                        self.SetHandlerState(previous_handler, 'Installed')
                        Error('Uninstall failed for '+name+'-'+previous_version)
                        SimpleLog(p.plugin_log,'Uninstall failed for '+name+'-'+previous_version)
                        isupgradeSuccess=False
                    else :
                        self.SetHandlerState(previous_handler, 'NotInstalled')
                        Log('Uninstall complete'+ previous_handler )
                        SimpleLog(p.plugin_log,'Uninstall complete'+ name +'-' + previous_version)
                    AddExtensionEvent(name,WALAEventOperation.Upgrade,isupgradeSuccess,0,previous_version)
                else :
                    # run install
                    if self.launchCommand(p.plugin_log,name,version,getcmd) == None :
                        self.SetHandlerState(handler, 'NotInstalled')
                        Error('Installation failed for '+name+'-'+version)
                        SimpleLog(p.plugin_log,'Installation failed for '+name+'-'+version)
                    else :
                        self.SetHandlerState(handler, 'Installed')
                        Log('Installation completed for '+name+'-'+version)
                        SimpleLog(p.plugin_log,'Installation completed for '+name+'-'+version)
            #end if plg_dir == none or version > = prev
            # change incarnation of settings file so it knows how to name status...
            zip_dir=LibDir+"/" + name + '-' + version
            mfile=None
            for root, dirs, files in os.walk(zip_dir):
                for f in files:
                    if f in ('HandlerManifest.json'):
                        mfile=os.path.join(root,f)
                if mfile != None:
                    break
            if mfile == None :
                Error('HandlerManifest.json not found.')
                SimpleLog(p.plugin_log,'HandlerManifest.json not found.')
                continue
            manifest = GetFileContents(mfile)
            p.setAttribute('manifestdata',manifest)
            config=''
            seqNo='0'
            if len(dom.getElementsByTagName("PluginSettings")) != 0 :
                try:
                    pslist=dom.getElementsByTagName("PluginSettings")[0].getElementsByTagName("Plugin")
                except:
                    Error('Error parsing ExtensionsConfig.')
                    SimpleLog(p.plugin_log,'Error parsing ExtensionsConfig.')
                    continue
                for ps in pslist:
                    if name == ps.getAttribute("name") and version == ps.getAttribute("version"):
                        Log("Found RuntimeSettings for " + name + " V " + version)
                        SimpleLog(p.plugin_log,"Found RuntimeSettings for " + name + " V " + version)
                        config=GetNodeTextData(ps.getElementsByTagName("RuntimeSettings")[0])
                        seqNo=ps.getElementsByTagName("RuntimeSettings")[0].getAttribute("seqNo")
                        break
            if config == '':
                Error("No RuntimeSettings for " + name + " V " + version)
                SimpleLog(p.plugin_log,"No RuntimeSettings for " + name + " V " + version)
            SetFileContents(root +"/config/" + seqNo +".settings", config )
            # state is still enable
            if (self.GetHandlerState(handler) == 'NotInstalled'):
                # run install first if true
                if self.launchCommand(p.plugin_log,name,version,'installCommand') == None :
                    self.SetHandlerState(handler, 'NotInstalled')
                    Error('Installation failed for '+name+'-'+version)
                    SimpleLog(p.plugin_log,'Installation failed for '+name+'-'+version)
                else :
                    self.SetHandlerState(handler, 'Installed')
                    Log('Installation completed for '+name+'-'+version)
                    SimpleLog(p.plugin_log,'Installation completed for '+name+'-'+version)
            if (self.GetHandlerState(handler) != 'NotInstalled'):
                if self.launchCommand(p.plugin_log,name,version,'enableCommand') == None :
                    self.SetHandlerState(handler, 'Installed')
                    Error('Enable failed for '+name+'-'+version)
                    SimpleLog(p.plugin_log,'Enable failed for '+name+'-'+version)
                else :
                    self.SetHandlerState(handler, 'Enabled')
                    Log('Enable completed for '+name+'-'+version)
                    SimpleLog(p.plugin_log,'Enable completed for '+name+'-'+version)
            # this plugin processing is complete
            Log('Processing completed for '+name+'-'+version)
            SimpleLog(p.plugin_log,'Processing completed for '+name+'-'+version)
        #end plugin processing loop
        Log('Finished processing ExtensionsConfig.xml')
        # 'p' is unbound when there were no plugins; the bare except absorbs that.
        try:
            SimpleLog(p.plugin_log,'Finished processing ExtensionsConfig.xml')
        except:
            pass
        return self

    def launchCommand(self,plugin_log,name,version,command,prev_version=None):
        # Wrapper around __launchCommandWithoutEventLog that also emits a
        # telemetry event with success flag and duration.
        # NOTE: "UnIsntall" is the constant's actual (misspelled) name in
        # WALAEventOperation elsewhere in this file.
        commandToEventOperation={
            "installCommand":WALAEventOperation.Install,
            "uninstallCommand":WALAEventOperation.UnIsntall,
            "updateCommand": WALAEventOperation.Upgrade,
            "enableCommand": WALAEventOperation.Enable,
            "disableCommand": WALAEventOperation.Disable,
        }
        isSuccess=True
        start = datetime.datetime.now()
        r=self.__launchCommandWithoutEventLog(plugin_log,name,version,command,prev_version)
        if r==None:
            isSuccess=False
        Duration = int((datetime.datetime.now() - start).seconds)
        if commandToEventOperation.get(command):
            AddExtensionEvent(name,commandToEventOperation[command],isSuccess,Duration,version)
        return r

    def __launchCommandWithoutEventLog(self,plugin_log,name,version,command,prev_version=None):
        # Locate the handler manifest, read the requested command string,
        # spawn it with a 5-minute timeout, and return 0 on success or
        # None on any failure.
        # get the manifest and read the command
        mfile=None
        zip_dir=LibDir+"/" + name + '-' + version
        for root, dirs, files in os.walk(zip_dir):
            for f in files:
                if f in ('HandlerManifest.json'):
                    mfile=os.path.join(root,f)
            if mfile != None:
                break
        if mfile == None :
            Error('HandlerManifest.json not found.')
            SimpleLog(plugin_log,'HandlerManifest.json not found.')
            return None
        manifest = GetFileContents(mfile)
        try:
            jsn = json.loads(manifest)
        except:
            Error('Error parsing HandlerManifest.json.')
            SimpleLog(plugin_log,'Error parsing HandlerManifest.json.')
            return None
        if type(jsn)==list:
            jsn=jsn[0]
        if jsn.has_key('handlerManifest') :
            cmd = jsn['handlerManifest'][command]
        """
        # Sample aggregate status document (abridged):
        # { "version": "1.0", "timestampUTC": "...",
        #   "aggregateStatus": {
        #     "guestAgentStatus": { "version": "...", "status": "Ready",
        #       "formattedMessage": { "lang": "en-US", "message": "..." } },
        #     "handlerAggregateStatus": [ { "handlerName": "...",
        #       "handlerVersion": "...", "status": "Ready",
        #       "runtimeSettingsStatus": { "sequenceNumber": "...",
        #         "settingsStatus": { ... } } } ] } }
        try:
            incarnation=self.Extensions[0].getAttribute("goalStateIncarnation")
        except:
            Error('Error parsing ExtensionsConfig. Unable to send status reports')
            return None
        status=''
        statuses=''
        for p in self.Plugins:
            # Skip plugins being removed or blocked by DisallowMajorVersionUpgrade.
            if p.getAttribute("state") == 'uninstall' or p.getAttribute("restricted") == 'true' :
                continue
            version=p.getAttribute("version")
            name=p.getAttribute("name")
            if p.getAttribute("isJson") != 'true':
                LogIfVerbose("Plugin " + name+" version: " +version+" is not a JSON Extension. Skipping.")
                continue
            reportHeartbeat = False
            if len(p.getAttribute("manifestdata"))<1:
                Error("Failed to get manifestdata.")
            else:
                reportHeartbeat = json.loads(p.getAttribute("manifestdata"))[0]['handlerManifest']['reportHeartbeat']
            if len(statuses)>0:
                statuses+=','
            statuses+=self.GenerateAggStatus(name, version, reportHeartbeat)
        tstamp=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        #header
        #agent state - derived from the module-level provisioning globals.
        if provisioned == False:
            if provisionError == None :
                agent_state='Provisioning'
                agent_msg='Guest Agent is starting.'
            else:
                agent_state='Provisioning Error.'
                agent_msg=provisionError
        else:
            agent_state='Ready'
            agent_msg='GuestAgent is running and accepting new configurations.'
        status='{"version":"1.0","timestampUTC":"'+tstamp+'","aggregateStatus":{"guestAgentStatus":{"version":"'+GuestAgentVersion+'","status":"'+agent_state+'","formattedMessage":{"lang":"en-US","message":"'+agent_msg+'"}},"handlerAggregateStatus":['+statuses+']}}'
        try:
            # NOTE(review): this replace() appears extraction-garbled; the
            # original intent is presumably un-escaping '&amp;' to '&' in
            # the blob URL -- verify against the canonical source.
            uri=GetNodeTextData(self.Extensions[0].getElementsByTagName("StatusUploadBlob")[0]).replace('&','&')
        except:
            Error('Error parsing ExtensionsConfig. Unable to send status reports')
            return None
        UploadStatusBlob(uri, status.encode("utf-8"))
        LogIfVerbose('Status report '+status+' sent to ' + uri)
        return True

    def GetCurrentSequenceNumber(self, plugin_base_dir):
        """
        Get the settings file with biggest file number in config folder
        """
        config_dir = os.path.join(plugin_base_dir, 'config')
        seq_no = 0
        for subdir, dirs, files in os.walk(config_dir):
            for file in files:
                try:
                    cur_seq_no = int(os.path.basename(file).split('.')[0])
                    if cur_seq_no > seq_no:
                        seq_no = cur_seq_no
                except ValueError:
                    # Non-numeric filenames (e.g. HandlerState) are ignored.
                    continue
        return str(seq_no)

    def GenerateAggStatus(self, name, version, reportHeartbeat = False):
        """
        Generate the status which Azure can understand by the status and
        heartbeat reported by extension.  Returns a JSON string.
        """
        plugin_base_dir = LibDir+'/'+name+'-'+version+'/'
        current_seq_no = self.GetCurrentSequenceNumber(plugin_base_dir)
        status_file=os.path.join(plugin_base_dir, 'status/', current_seq_no +'.status')
        heartbeat_file = os.path.join(plugin_base_dir, 'heartbeat.log')
        handler_state_file = os.path.join(plugin_base_dir, 'config', 'HandlerState')
        agg_state = 'NotReady'
        handler_state = None
        status_obj = None
        status_code = None
        formatted_message = None
        localized_message = None
        # Map the persisted handler state onto an aggregate status.
        if os.path.exists(handler_state_file):
            handler_state = GetFileContents(handler_state_file).lower()
            if HandlerStatusToAggStatus.has_key(handler_state):
                agg_state = HandlerStatusToAggStatus[handler_state]
        if reportHeartbeat:
            if os.path.exists(heartbeat_file):
                d=int(time.time()-os.stat(heartbeat_file).st_mtime)
                if d > 600 :
                    # not updated for more than 10 min
                    agg_state = 'Unresponsive'
                else:
                    try:
                        heartbeat = json.loads(GetFileContents(heartbeat_file))[0]["heartbeat"]
                        agg_state = heartbeat.get("status")
                        status_code = heartbeat.get("code")
                        formatted_message = heartbeat.get("formattedMessage")
                        localized_message = heartbeat.get("message")
                    except:
                        Error("Incorrect heartbeat file. Ignore it. ")
            else:
                agg_state = 'Unresponsive'
        #get status file reported by extension
        if os.path.exists(status_file):
            # raw status generated by extension is an array, get the first item and remove the unnecessary element
            try:
                status_obj = json.loads(GetFileContents(status_file))[0]
                del status_obj["version"]
            except:
                Error("Incorrect status file. Will NOT settingsStatus in settings. ")
        agg_status_obj = {"handlerName": name, "handlerVersion": version, "status": agg_state,
                          "runtimeSettingsStatus" : {"sequenceNumber": current_seq_no}}
        if status_obj:
            agg_status_obj["runtimeSettingsStatus"]["settingsStatus"] = status_obj
        if status_code != None:
            agg_status_obj["code"] = status_code
        if formatted_message:
            agg_status_obj["formattedMessage"] = formatted_message
        if localized_message:
            agg_status_obj["message"] = localized_message
        agg_status_string = json.dumps(agg_status_obj)
        LogIfVerbose("Handler Aggregated Status:" + agg_status_string)
        return agg_status_string

    def SetHandlerState(self, handler, state=''):
        # Persist 'state' to <plugin>/config/HandlerState next to the
        # handler's manifest; returns None when the manifest is missing.
        zip_dir=LibDir+"/" + handler
        mfile=None
        for root, dirs, files in os.walk(zip_dir):
            for f in files:
                if f in ('HandlerManifest.json'):
                    mfile=os.path.join(root,f)
            if mfile != None:
                break
        if mfile == None :
            Error('SetHandlerState(): HandlerManifest.json not found, cannot set HandlerState.')
            return None
        Log("SetHandlerState: "+handler+", "+state)
        return SetFileContents(os.path.dirname(mfile)+'/config/HandlerState', state)

    def GetHandlerState(self, handler):
        # Missing/empty HandlerState file is treated as 'NotInstalled'.
        handlerState = GetFileContents(handler+'/config/HandlerState')
        if (handlerState):
            return handlerState.rstrip('\r\n')
        else:
            return 'NotInstalled'

class HostingEnvironmentConfig(object):
    """
    Parse Hosting environment config and store in HostingEnvironmentConfig.xml
    """
    # Sample HostingEnvironmentConfig XML (elided in source).
    def __init__(self):
        self.reinitialize()

    def reinitialize(self):
        """
        Reset Members.
""" self.StoredCertificates = None self.Deployment = None self.Incarnation = None self.Role = None self.HostingEnvironmentSettings = None self.ApplicationSettings = None self.Certificates = None self.ResourceReferences = None def Parse(self, xmlText): """ Parse and create HostingEnvironmentConfig.xml. """ self.reinitialize() SetFileContents("HostingEnvironmentConfig.xml", xmlText) dom = xml.dom.minidom.parseString(xmlText) for a in [ "HostingEnvironmentConfig", "Deployment", "Service", "ServiceInstance", "Incarnation", "Role", ]: if not dom.getElementsByTagName(a): Error("HostingEnvironmentConfig.Parse: Missing " + a) return None node = dom.childNodes[0] if node.localName != "HostingEnvironmentConfig": Error("HostingEnvironmentConfig.Parse: root not HostingEnvironmentConfig") return None self.ApplicationSettings = dom.getElementsByTagName("Setting") self.Certificates = dom.getElementsByTagName("StoredCertificate") return self def DecryptPassword(self, e): """ Return decrypted password. """ SetFileContents("password.p7m", "MIME-Version: 1.0\n" + "Content-Disposition: attachment; filename=\"password.p7m\"\n" + "Content-Type: application/x-pkcs7-mime; name=\"password.p7m\"\n" + "Content-Transfer-Encoding: base64\n\n" + textwrap.fill(e, 64)) return RunGetOutput(Openssl + " cms -decrypt -in password.p7m -inkey Certificates.pem -recip Certificates.pem")[1] def ActivateResourceDisk(self): return MyDistro.ActivateResourceDisk() def Process(self): """ Execute ActivateResourceDisk in separate thread. Create the user account. Launch ConfigurationConsumer if specified in the config. 
""" no_thread = False if DiskActivated == False: for m in inspect.getmembers(MyDistro): if 'ActivateResourceDiskNoThread' in m: no_thread = True break if no_thread == True : MyDistro.ActivateResourceDiskNoThread() else : diskThread = threading.Thread(target = self.ActivateResourceDisk) diskThread.start() User = None Pass = None Expiration = None Thumbprint = None for b in self.ApplicationSettings: sname = b.getAttribute("name") svalue = b.getAttribute("value") if User != None and Pass != None: if User != "root" and User != "" and Pass != "": CreateAccount(User, Pass, Expiration, Thumbprint) else: Error("Not creating user account: " + User) for c in self.Certificates: csha1 = c.getAttribute("certificateId").split(':')[1].upper() if os.path.isfile(csha1 + ".prv"): Log("Private key with thumbprint: " + csha1 + " was retrieved.") if os.path.isfile(csha1 + ".crt"): Log("Public cert with thumbprint: " + csha1 + " was retrieved.") program = Config.get("Role.ConfigurationConsumer") if program != None: try: Children.append(subprocess.Popen([program, LibDir + "/HostingEnvironmentConfig.xml"])) except OSError, e : ErrorWithPrefix('HostingEnvironmentConfig.Process','Exception: '+ str(e) +' occured launching ' + program ) class GoalState(Util): """ Primary container for all configuration except OvfXml. Encapsulates http communication with endpoint server. 
Initializes and populates: self.HostingEnvironmentConfig self.SharedConfig self.ExtensionsConfig self.Certificates """ # # # 2010-12-15 # 1 # # Started # # 16001 # # # # c6d5526c-5ac2-4200-b6e2-56f2b70c5ab2 # # # MachineRole_IN_0 # Started # # http://10.115.153.40:80/machine/c6d5526c-5ac2-4200-b6e2-56f2b70c5ab2/MachineRole%5FIN%5F0?comp=config&type=hostingEnvironmentConfig&incarnation=1 # http://10.115.153.40:80/machine/c6d5526c-5ac2-4200-b6e2-56f2b70c5ab2/MachineRole%5FIN%5F0?comp=config&type=sharedConfig&incarnation=1 # http://10.115.153.40:80/machine/c6d5526c-5ac2-4200-b6e2-56f2b70c5ab2/MachineRole%5FIN%5F0?comp=certificates&incarnation=1 # http://100.67.238.230:80/machine/9c87aa94-3bda-45e3-b2b7-0eb0fca7baff/1552dd64dc254e6884f8d5b8b68aa18f.eg%2Dplug%2Dvm?comp=config&type=extensionsConfig&incarnation=2 # http://100.67.238.230:80/machine/9c87aa94-3bda-45e3-b2b7-0eb0fca7baff/1552dd64dc254e6884f8d5b8b68aa18f.eg%2Dplug%2Dvm?comp=config&type=fullConfig&incarnation=2 # # # # # # # There is only one Role for VM images. # # Of primary interest is: # LBProbePorts -- an http server needs to run here # We also note Container/ContainerID and RoleInstance/InstanceId to form the health report. 
# And of course, Incarnation # def __init__(self, Agent): self.Agent = Agent self.Endpoint = Agent.Endpoint self.TransportCert = Agent.TransportCert self.reinitialize() def reinitialize(self): self.Incarnation = None # integer self.ExpectedState = None # "Started" self.HostingEnvironmentConfigUrl = None self.HostingEnvironmentConfigXml = None self.HostingEnvironmentConfig = None self.SharedConfigUrl = None self.SharedConfigXml = None self.SharedConfig = None self.CertificatesUrl = None self.CertificatesXml = None self.Certificates = None self.ExtensionsConfigUrl = None self.ExtensionsConfigXml = None self.ExtensionsConfig = None self.RoleInstanceId = None self.ContainerId = None self.LoadBalancerProbePort = None # integer, ?list of integers def Parse(self, xmlText): """ Request configuration data from endpoint server. Parse and populate contained configuration objects. Calls Certificates().Parse() Calls SharedConfig().Parse Calls ExtensionsConfig().Parse Calls HostingEnvironmentConfig().Parse """ self.reinitialize() LogIfVerbose(xmlText) node = xml.dom.minidom.parseString(xmlText).childNodes[0] if node.localName != "GoalState": Error("GoalState.Parse: root not GoalState") return None for a in node.childNodes: if a.nodeType == node.ELEMENT_NODE: if a.localName == "Incarnation": self.Incarnation = GetNodeTextData(a) elif a.localName == "Machine": for b in a.childNodes: if b.nodeType == node.ELEMENT_NODE: if b.localName == "ExpectedState": self.ExpectedState = GetNodeTextData(b) Log("ExpectedState: " + self.ExpectedState) elif b.localName == "LBProbePorts": for c in b.childNodes: if c.nodeType == node.ELEMENT_NODE and c.localName == "Port": self.LoadBalancerProbePort = int(GetNodeTextData(c)) elif a.localName == "Container": for b in a.childNodes: if b.nodeType == node.ELEMENT_NODE: if b.localName == "ContainerId": self.ContainerId = GetNodeTextData(b) Log("ContainerId: " + self.ContainerId) elif b.localName == "RoleInstanceList": for c in b.childNodes: if c.localName 
== "RoleInstance": for d in c.childNodes: if d.nodeType == node.ELEMENT_NODE: if d.localName == "InstanceId": self.RoleInstanceId = GetNodeTextData(d) Log("RoleInstanceId: " + self.RoleInstanceId) elif d.localName == "State": pass elif d.localName == "Configuration": for e in d.childNodes: if e.nodeType == node.ELEMENT_NODE: LogIfVerbose(e.localName) if e.localName == "HostingEnvironmentConfig": self.HostingEnvironmentConfigUrl = GetNodeTextData(e) LogIfVerbose("HostingEnvironmentConfigUrl:" + self.HostingEnvironmentConfigUrl) self.HostingEnvironmentConfigXml = self.HttpGetWithHeaders(self.HostingEnvironmentConfigUrl) self.HostingEnvironmentConfig = HostingEnvironmentConfig().Parse(self.HostingEnvironmentConfigXml) elif e.localName == "SharedConfig": self.SharedConfigUrl = GetNodeTextData(e) LogIfVerbose("SharedConfigUrl:" + self.SharedConfigUrl) self.SharedConfigXml = self.HttpGetWithHeaders(self.SharedConfigUrl) self.SharedConfig = SharedConfig().Parse(self.SharedConfigXml) self.SharedConfig.Save() elif e.localName == "ExtensionsConfig": self.ExtensionsConfigUrl = GetNodeTextData(e) LogIfVerbose("ExtensionsConfigUrl:" + self.ExtensionsConfigUrl) self.ExtensionsConfigXml = self.HttpGetWithHeaders(self.ExtensionsConfigUrl) elif e.localName == "Certificates": self.CertificatesUrl = GetNodeTextData(e) LogIfVerbose("CertificatesUrl:" + self.CertificatesUrl) self.CertificatesXml = self.HttpSecureGetWithHeaders(self.CertificatesUrl, self.TransportCert) self.Certificates = Certificates().Parse(self.CertificatesXml) if self.Incarnation == None: Error("GoalState.Parse: Incarnation missing") return None if self.ExpectedState == None: Error("GoalState.Parse: ExpectedState missing") return None if self.RoleInstanceId == None: Error("GoalState.Parse: RoleInstanceId missing") return None if self.ContainerId == None: Error("GoalState.Parse: ContainerId missing") return None SetFileContents("GoalState." 
+ self.Incarnation + ".xml", xmlText) return self def Process(self): """ Calls HostingEnvironmentConfig.Process() """ LogIfVerbose("Process goalstate") self.HostingEnvironmentConfig.Process() self.SharedConfig.Process() class OvfEnv(object): """ Read, and process provisioning info from provisioning file OvfEnv.xml """ # # # # # 1.0 # # LinuxProvisioningConfiguration # HostName # UserName # UserPassword # false # # # # EB0C0AB4B2D5FC35F2F0658D19F44C8283E2DD62 # $HOME/UserName/.ssh/authorized_keys # # # # # EB0C0AB4B2D5FC35F2F0658D19F44C8283E2DD62 # $HOME/UserName/.ssh/id_rsa # # # # # # # def __init__(self): self.reinitialize() def reinitialize(self): """ Reset members. """ self.WaNs = "http://schemas.microsoft.com/windowsazure" self.OvfNs = "http://schemas.dmtf.org/ovf/environment/1" self.MajorVersion = 1 self.MinorVersion = 0 self.ComputerName = None self.AdminPassword = None self.UserName = None self.UserPassword = None self.CustomData = None self.DisableSshPasswordAuthentication = True self.SshPublicKeys = [] self.SshKeyPairs = [] def Parse(self, xmlText, isDeprovision = False): """ Parse xml tree, retreiving user and ssh key information. Return self. """ self.reinitialize() LogIfVerbose(re.sub(".*?<", "*<", xmlText)) dom = xml.dom.minidom.parseString(xmlText) if len(dom.getElementsByTagNameNS(self.OvfNs, "Environment")) != 1: Error("Unable to parse OVF XML.") section = None newer = False for p in dom.getElementsByTagNameNS(self.WaNs, "ProvisioningSection"): for n in p.childNodes: if n.localName == "Version": verparts = GetNodeTextData(n).split('.') major = int(verparts[0]) minor = int(verparts[1]) if major > self.MajorVersion: newer = True if major != self.MajorVersion: break if minor > self.MinorVersion: newer = True section = p if newer == True: Warn("Newer provisioning configuration detected. 
Please consider updating waagent.") if section == None: Error("Could not find ProvisioningSection with major version=" + str(self.MajorVersion)) return None self.ComputerName = GetNodeTextData(section.getElementsByTagNameNS(self.WaNs, "HostName")[0]) self.UserName = GetNodeTextData(section.getElementsByTagNameNS(self.WaNs, "UserName")[0]) if isDeprovision == True: return self try: self.UserPassword = GetNodeTextData(section.getElementsByTagNameNS(self.WaNs, "UserPassword")[0]) except: pass CDSection=None try: CDSection=section.getElementsByTagNameNS(self.WaNs, "CustomData") if len(CDSection) > 0 : self.CustomData=GetNodeTextData(CDSection[0]) if len(self.CustomData)>0: SetFileContents(LibDir + '/CustomData', MyDistro.translateCustomData(self.CustomData)) Log('Wrote ' + LibDir + '/CustomData') else : Error(' contains no data!') except Exception, e: Error( str(e)+' occured creating ' + LibDir + '/CustomData') disableSshPass = section.getElementsByTagNameNS(self.WaNs, "DisableSshPasswordAuthentication") if len(disableSshPass) != 0: self.DisableSshPasswordAuthentication = (GetNodeTextData(disableSshPass[0]).lower() == "true") for pkey in section.getElementsByTagNameNS(self.WaNs, "PublicKey"): LogIfVerbose(repr(pkey)) fp = None path = None for c in pkey.childNodes: if c.localName == "Fingerprint": fp = GetNodeTextData(c).upper() LogIfVerbose(fp) if c.localName == "Path": path = GetNodeTextData(c) LogIfVerbose(path) self.SshPublicKeys += [[fp, path]] for keyp in section.getElementsByTagNameNS(self.WaNs, "KeyPair"): fp = None path = None LogIfVerbose(repr(keyp)) for c in keyp.childNodes: if c.localName == "Fingerprint": fp = GetNodeTextData(c).upper() LogIfVerbose(fp) if c.localName == "Path": path = GetNodeTextData(c) LogIfVerbose(path) self.SshKeyPairs += [[fp, path]] return self def PrepareDir(self, filepath): """ Create home dir for self.UserName Change owner and return path. 
""" home = MyDistro.GetHome() # Expand HOME variable if present in path path = os.path.normpath(filepath.replace("$HOME", home)) if (path.startswith("/") == False) or (path.endswith("/") == True): return None dir = path.rsplit('/', 1)[0] if dir != "": CreateDir(dir, "root", 0700) if path.startswith(os.path.normpath(home + "/" + self.UserName + "/")): ChangeOwner(dir, self.UserName) return path def NumberToBytes(self, i): """ Pack number into bytes. Retun as string. """ result = [] while i: result.append(chr(i & 0xFF)) i >>= 8 result.reverse() return ''.join(result) def BitsToString(self, a): """ Return string representation of bits in a. """ index=7 s = "" c = 0 for bit in a: c = c | (bit << index) index = index - 1 if index == -1: s = s + struct.pack('>B', c) c = 0 index = 7 return s def OpensslToSsh(self, file): """ Return base-64 encoded key appropriate for ssh. """ from pyasn1.codec.der import decoder as der_decoder try: f = open(file).read().replace('\n','').split("KEY-----")[1].split('-')[0] k=der_decoder.decode(self.BitsToString(der_decoder.decode(base64.b64decode(f))[0][1]))[0] n=k[0] e=k[1] keydata="" keydata += struct.pack('>I',len("ssh-rsa")) keydata += "ssh-rsa" keydata += struct.pack('>I',len(self.NumberToBytes(e))) keydata += self.NumberToBytes(e) keydata += struct.pack('>I',len(self.NumberToBytes(n)) + 1) keydata += "\0" keydata += self.NumberToBytes(n) except Exception, e: print("OpensslToSsh: Exception " + str(e)) return None return "ssh-rsa " + base64.b64encode(keydata) + "\n" def Process(self): """ Process all certificate and key info. DisableSshPasswordAuthentication if configured. CreateAccount(user) Wait for WaAgent.EnvMonitor.IsHostnamePublished(). Restart ssh service. 
""" error = None if self.ComputerName == None : return "Error: Hostname missing" error=WaAgent.EnvMonitor.SetHostName(self.ComputerName) if error: return error if self.DisableSshPasswordAuthentication: filepath = "/etc/ssh/sshd_config" # Disable RFC 4252 and RFC 4256 authentication schemes. ReplaceFileContentsAtomic(filepath, "\n".join(filter(lambda a: not (a.startswith("PasswordAuthentication") or a.startswith("ChallengeResponseAuthentication")), GetFileContents(filepath).split('\n'))) + "\nPasswordAuthentication no\nChallengeResponseAuthentication no\n") Log("Disabled SSH password-based authentication methods.") if self.AdminPassword != None: MyDistro.changePass('root',self.AdminPassword) if self.UserName != None: error = MyDistro.CreateAccount(self.UserName, self.UserPassword, None, None) sel = MyDistro.isSelinuxRunning() if sel : MyDistro.setSelinuxEnforce(0) home = MyDistro.GetHome() for pkey in self.SshPublicKeys: Log("Deploy public key:{0}".format(pkey[0])) if not os.path.isfile(pkey[0] + ".crt"): Error("PublicKey not found: " + pkey[0]) error = "Failed to deploy public key (0x09)." continue path = self.PrepareDir(pkey[1]) if path == None: Error("Invalid path: " + pkey[1] + " for PublicKey: " + pkey[0]) error = "Invalid path for public key (0x03)." continue Run(Openssl + " x509 -in " + pkey[0] + ".crt -noout -pubkey > " + pkey[0] + ".pub") MyDistro.setSelinuxContext(pkey[0] + '.pub','unconfined_u:object_r:ssh_home_t:s0') MyDistro.sshDeployPublicKey(pkey[0] + '.pub',path) MyDistro.setSelinuxContext(path,'unconfined_u:object_r:ssh_home_t:s0') if path.startswith(os.path.normpath(home + "/" + self.UserName + "/")): ChangeOwner(path, self.UserName) for keyp in self.SshKeyPairs: Log("Deploy key pair:{0}".format(keyp[0])) if not os.path.isfile(keyp[0] + ".prv"): Error("KeyPair not found: " + keyp[0]) error = "Failed to deploy key pair (0x0A)." 
continue path = self.PrepareDir(keyp[1]) if path == None: Error("Invalid path: " + keyp[1] + " for KeyPair: " + keyp[0]) error = "Invalid path for key pair (0x05)." continue SetFileContents(path, GetFileContents(keyp[0] + ".prv")) os.chmod(path, 0600) Run("ssh-keygen -y -f " + keyp[0] + ".prv > " + path + ".pub") MyDistro.setSelinuxContext(path,'unconfined_u:object_r:ssh_home_t:s0') MyDistro.setSelinuxContext(path + '.pub','unconfined_u:object_r:ssh_home_t:s0') if path.startswith(os.path.normpath(home + "/" + self.UserName + "/")): ChangeOwner(path, self.UserName) ChangeOwner(path + ".pub", self.UserName) if sel : MyDistro.setSelinuxEnforce(1) while not WaAgent.EnvMonitor.IsHostnamePublished(): time.sleep(1) MyDistro.restartSshService() return error class WALAEvent(object): def __init__(self): self.providerId="" self.eventId=1 self.OpcodeName="" self.KeywordName="" self.TaskName="" self.TenantName="" self.RoleName="" self.RoleInstanceName="" self.ContainerId="" self.ExecutionMode="IAAS" self.OSVersion="" self.GAVersion="" self.RAM=0 self.Processors=0 def ToXml(self): strEventid=u''.format(self.eventId) strProviderid=u''.format(self.providerId) strRecordFormat = u'' strRecordNoQuoteFormat = u'' strMtStr=u'mt:wstr' strMtUInt64=u'mt:uint64' strMtBool=u'mt:bool' strMtFloat=u'mt:float64' strEventsData=u"" for attName in self.__dict__: if attName in ["eventId","filedCount","providerId"]: continue attValue = self.__dict__[attName] if type(attValue) is int: strEventsData+=strRecordFormat.format(attName,attValue,strMtUInt64) continue if type(attValue) is str: attValue = xml.sax.saxutils.quoteattr(attValue) strEventsData+=strRecordNoQuoteFormat.format(attName,attValue,strMtStr) continue if str(type(attValue)).count("'unicode'") >0 : attValue = xml.sax.saxutils.quoteattr(attValue) strEventsData+=strRecordNoQuoteFormat.format(attName,attValue,strMtStr) continue if type(attValue) is bool: strEventsData+=strRecordFormat.format(attName,attValue,strMtBool) continue if 
type(attValue) is float: strEventsData+=strRecordFormat.format(attName,attValue,strMtFloat) continue Log("Warning: property "+attName+":"+str(type(attValue))+":type"+str(type(attValue))+"Can't convert to events data:"+":type not supported") return u"{0}{1}{2}".format(strProviderid,strEventid,strEventsData) def Save(self): eventfolder = LibDir+"/events" if not os.path.exists(eventfolder): os.mkdir(eventfolder) os.chmod(eventfolder,0700) if len(os.listdir(eventfolder)) > 1000: raise Exception("WriteToFolder:Too many file under "+eventfolder+" exit") filename = os.path.join(eventfolder,str(int(time.time()*1000000))) with open(filename+".tmp",'wb+') as hfile: hfile.write(self.ToXml().encode("utf-8")) os.rename(filename+".tmp",filename+".tld") class WALAEventOperation: HeartBeat="HeartBeat" Provision = "Provision" Install = "Install" UnIsntall = "UnInstall" Disable = "Disable" Enable = "Enable" Download = "Download" Upgrade = "Upgrade" Update = "Update" def AddExtensionEvent(name,op,isSuccess,duration=0,version="1.0",message="",type="",isInternal=False): event = ExtensionEvent() event.Name=name event.Version=version event.IsInternal=isInternal event.Operation=op event.OperationSuccess=isSuccess event.Message=message event.Duration=duration event.ExtensionType=type try: event.Save() except: Error("Error "+traceback.format_exc()) class ExtensionEvent(WALAEvent): def __init__(self): WALAEvent.__init__(self) self.eventId=1 self.providerId="69B669B9-4AF8-4C50-BDC4-6006FA76E975" self.Name="" self.Version="" self.IsInternal=False self.Operation="" self.OperationSuccess=True self.ExtensionType="" self.Message="" self.Duration=0 class WALAEventMonitor(WALAEvent): def __init__(self,postMethod): WALAEvent.__init__(self) self.post = postMethod self.sysInfo={} self.eventdir = LibDir+"/events" self.issysteminfoinitilized = False def StartEventsLoop(self): eventThread = threading.Thread(target = self.EventsLoop) eventThread.setDaemon(True) eventThread.start() def EventsLoop(self): 
LastReportHeartBeatTime = datetime.datetime.min try: while(True): if (datetime.datetime.now()-LastReportHeartBeatTime) > datetime.timedelta(hours=12): LastReportHeartBeatTime = datetime.datetime.now() AddExtensionEvent(op=WALAEventOperation.HeartBeat,name="WALA",isSuccess=True) self.postNumbersInOneLoop=0 self.CollectAndSendWALAEvents() time.sleep(60) except: Error("Exception in events loop:"+traceback.format_exc()) def SendEvent(self,providerid,events): dataFormat = u'{1}'\ '' data = dataFormat.format(providerid,events) self.post("/machine/?comp=telemetrydata", data) def CollectAndSendWALAEvents(self): if not os.path.exists(self.eventdir): return #Throtting, can't send more than 3 events in 15 seconds eventSendNumber=0 eventFiles = os.listdir(self.eventdir) events = {} for file in eventFiles: if not file.endswith(".tld"): continue with open(os.path.join(self.eventdir,file),"rb") as hfile: #if fail to open or delete the file, throw exception xmlStr = hfile.read().decode("utf-8",'ignore') os.remove(os.path.join(self.eventdir,file)) params="" eventid="" providerid="" #if exception happen during process an event, catch it and continue try: xmlStr = self.AddSystemInfo(xmlStr) for node in xml.dom.minidom.parseString(xmlStr.encode("utf-8")).childNodes[0].childNodes: if node.tagName == "Param": params+=node.toxml() if node.tagName == "Event": eventid=node.getAttribute("id") if node.tagName == "Provider": providerid = node.getAttribute("id") except: Error(traceback.format_exc()) continue if len(params)==0 or len(eventid)==0 or len(providerid)==0: Error("Empty filed in params:"+params+" event id:"+eventid+" provider id:"+providerid) continue eventstr = u''.format(eventid,params) if not events.get(providerid): events[providerid]="" if len(events[providerid]) >0 and len(events.get(providerid)+eventstr)>= 63*1024: eventSendNumber+=1 self.SendEvent(providerid,events.get(providerid)) if eventSendNumber %3 ==0: time.sleep(15) events[providerid]="" if len(eventstr) >= 63*1024: 
Error("Signle event too large abort "+eventstr[:300]) continue events[providerid]=events.get(providerid)+eventstr for key in events.keys(): if len(events[key]) > 0: eventSendNumber+=1 self.SendEvent(key,events[key]) if eventSendNumber%3 == 0: time.sleep(15) def AddSystemInfo(self,eventData): if not self.issysteminfoinitilized: self.issysteminfoinitilized=True try: self.sysInfo["OSVersion"]=platform.system()+":"+"-".join(DistInfo(1))+":"+platform.release() self.sysInfo["GAVersion"]=GuestAgentVersion self.sysInfo["RAM"]=MyDistro.getTotalMemory() self.sysInfo["Processors"]=MyDistro.getProcessorCores() sharedConfig = xml.dom.minidom.parse("/var/lib/waagent/SharedConfig.xml").childNodes[0] hostEnvConfig= xml.dom.minidom.parse("/var/lib/waagent/HostingEnvironmentConfig.xml").childNodes[0] gfiles = RunGetOutput("ls -t /var/lib/waagent/GoalState.*.xml")[1] goalStateConfi = xml.dom.minidom.parse(gfiles.split("\n")[0]).childNodes[0] self.sysInfo["TenantName"]=hostEnvConfig.getElementsByTagName("Deployment")[0].getAttribute("name") self.sysInfo["RoleName"]=hostEnvConfig.getElementsByTagName("Role")[0].getAttribute("name") self.sysInfo["RoleInstanceName"]=sharedConfig.getElementsByTagName("Instance")[0].getAttribute("id") self.sysInfo["ContainerId"]=goalStateConfi.getElementsByTagName("ContainerId")[0].childNodes[0].nodeValue except: Error(traceback.format_exc()) eventObject = xml.dom.minidom.parseString(eventData.encode("utf-8")).childNodes[0] for node in eventObject.childNodes: if node.tagName == "Param": name = node.getAttribute("Name") if self.sysInfo.get(name): node.setAttribute("Value",xml.sax.saxutils.escape(str(self.sysInfo[name]))) return eventObject.toxml() class Agent(Util): """ Primary object container for the provisioning process. 
""" def __init__(self): self.GoalState = None self.Endpoint = None self.LoadBalancerProbeServer = None self.HealthReportCounter = 0 self.TransportCert = "" self.EnvMonitor = None self.SendData = None self.DhcpResponse = None def CheckVersions(self): """ Query endpoint server for wire protocol version. Fail if our desired protocol version is not seen. """ # # # # 2010-12-15 # # # 2010-12-15 # 2010-28-10 # # global ProtocolVersion protocolVersionSeen = False node = xml.dom.minidom.parseString(self.HttpGetWithoutHeaders("/?comp=versions")).childNodes[0] if node.localName != "Versions": Error("CheckVersions: root not Versions") return False for a in node.childNodes: if a.nodeType == node.ELEMENT_NODE and a.localName == "Supported": for b in a.childNodes: if b.nodeType == node.ELEMENT_NODE and b.localName == "Version": v = GetNodeTextData(b) LogIfVerbose("Fabric supported wire protocol version: " + v) if v == ProtocolVersion: protocolVersionSeen = True if a.nodeType == node.ELEMENT_NODE and a.localName == "Preferred": v = GetNodeTextData(a.getElementsByTagName("Version")[0]) Log("Fabric preferred wire protocol version: " + v) if not protocolVersionSeen: Warn("Agent supported wire protocol version: " + ProtocolVersion + " was not advertised by Fabric.") else: Log("Negotiated wire protocol version: " + ProtocolVersion) return True def Unpack(self, buffer, offset, range): """ Unpack bytes into python values. """ result = 0 for i in range: result = (result << 8) | Ord(buffer[offset + i]) return result def UnpackLittleEndian(self, buffer, offset, length): """ Unpack little endian bytes into python values. """ return self.Unpack(buffer, offset, list(range(length - 1, -1, -1))) def UnpackBigEndian(self, buffer, offset, length): """ Unpack big endian bytes into python values. """ return self.Unpack(buffer, offset, list(range(0, length))) def HexDump3(self, buffer, offset, length): """ Dump range of buffer in formatted hex. 
""" return ''.join(['%02X' % Ord(char) for char in buffer[offset:offset + length]]) def HexDump2(self, buffer): """ Dump buffer in formatted hex. """ return self.HexDump3(buffer, 0, len(buffer)) def BuildDhcpRequest(self): """ Build DHCP request string. """ # # typedef struct _DHCP { # UINT8 Opcode; /* op: BOOTREQUEST or BOOTREPLY */ # UINT8 HardwareAddressType; /* htype: ethernet */ # UINT8 HardwareAddressLength; /* hlen: 6 (48 bit mac address) */ # UINT8 Hops; /* hops: 0 */ # UINT8 TransactionID[4]; /* xid: random */ # UINT8 Seconds[2]; /* secs: 0 */ # UINT8 Flags[2]; /* flags: 0 or 0x8000 for broadcast */ # UINT8 ClientIpAddress[4]; /* ciaddr: 0 */ # UINT8 YourIpAddress[4]; /* yiaddr: 0 */ # UINT8 ServerIpAddress[4]; /* siaddr: 0 */ # UINT8 RelayAgentIpAddress[4]; /* giaddr: 0 */ # UINT8 ClientHardwareAddress[16]; /* chaddr: 6 byte ethernet MAC address */ # UINT8 ServerName[64]; /* sname: 0 */ # UINT8 BootFileName[128]; /* file: 0 */ # UINT8 MagicCookie[4]; /* 99 130 83 99 */ # /* 0x63 0x82 0x53 0x63 */ # /* options -- hard code ours */ # # UINT8 MessageTypeCode; /* 53 */ # UINT8 MessageTypeLength; /* 1 */ # UINT8 MessageType; /* 1 for DISCOVER */ # UINT8 End; /* 255 */ # } DHCP; # # tuple of 244 zeros # (struct.pack_into would be good here, but requires Python 2.5) sendData = [0] * 244 transactionID = os.urandom(4) macAddress = MyDistro.GetMacAddress() # Opcode = 1 # HardwareAddressType = 1 (ethernet/MAC) # HardwareAddressLength = 6 (ethernet/MAC/48 bits) for a in range(0, 3): sendData[a] = [1, 1, 6][a] # fill in transaction id (random number to ensure response matches request) for a in range(0, 4): sendData[4 + a] = Ord(transactionID[a]) LogIfVerbose("BuildDhcpRequest: transactionId:%s,%04X" % (self.HexDump2(transactionID), self.UnpackBigEndian(sendData, 4, 4))) # fill in ClientHardwareAddress for a in range(0, 6): sendData[0x1C + a] = Ord(macAddress[a]) # DHCP Magic Cookie: 99, 130, 83, 99 # MessageTypeCode = 53 DHCP Message Type # MessageTypeLength = 1 # 
    def IntegerToIpAddressV4String(self, a):
        """
        Convert a 32-bit integer into dotted-quad IPv4 string form.
        (The previous docstring, "Build DHCP request string", was a
        copy/paste error from BuildDhcpRequest.)
        """
        return "%u.%u.%u.%u" % ((a >> 24) & 0xFF, (a >> 16) & 0xFF, (a >> 8) & 0xFF, a & 0xFF)

    def RouteAdd(self, net, mask, gateway):
        """
        Add the specified route using /sbin/route add -net.
        net/mask/gateway arrive as 32-bit integers and are converted
        to dotted-quad strings before being passed to the shell.
        """
        net = self.IntegerToIpAddressV4String(net)
        mask = self.IntegerToIpAddressV4String(mask)
        gateway = self.IntegerToIpAddressV4String(gateway)
        Run("/sbin/route add -net " + net + " netmask " + mask + " gw " + gateway,chk_err=False)

    def HandleDhcpResponse(self, sendData, receiveBuffer):
        """
        Parse a raw DHCP response:
          - install default gateway (option 3) and static routes (option 249),
          - extract the wire-protocol endpoint (Azure private option 245).
        Returns the endpoint server IP string, or None on error/mismatch.
        """
        LogIfVerbose("HandleDhcpResponse")
        bytesReceived = len(receiveBuffer)
        # 0xF6 = minimum length for a packet that carries at least one option
        # past the fixed header + magic cookie.
        if bytesReceived < 0xF6:
            Error("HandleDhcpResponse: Too few bytes received " + str(bytesReceived))
            return None
        LogIfVerbose("BytesReceived: " + hex(bytesReceived))
        LogWithPrefixIfVerbose("DHCP response:", HexDump(receiveBuffer, bytesReceived))
        # check transactionId, cookie, MAC address
        # cookie should never mismatch
        # transactionId and MAC address may mismatch if we see a response meant from another machine
        # Offsets: 4..7 = xid, 0x1C..0x21 = chaddr, 0xEC..0xEF = magic cookie.
        for offsets in [list(range(4, 4 + 4)), list(range(0x1C, 0x1C + 6)), list(range(0xEC, 0xEC + 4))]:
            for offset in offsets:
                sentByte = Ord(sendData[offset])
                receivedByte = Ord(receiveBuffer[offset])
                if sentByte != receivedByte:
                    LogIfVerbose("HandleDhcpResponse: sent cookie:" + self.HexDump3(sendData, 0xEC, 4))
                    LogIfVerbose("HandleDhcpResponse: rcvd cookie:" + self.HexDump3(receiveBuffer, 0xEC, 4))
                    LogIfVerbose("HandleDhcpResponse: sent transactionID:" + self.HexDump3(sendData, 4, 4))
                    LogIfVerbose("HandleDhcpResponse: rcvd transactionID:" + self.HexDump3(receiveBuffer, 4, 4))
                    LogIfVerbose("HandleDhcpResponse: sent ClientHardwareAddress:" + self.HexDump3(sendData, 0x1C, 6))
                    LogIfVerbose("HandleDhcpResponse: rcvd ClientHardwareAddress:" + self.HexDump3(receiveBuffer, 0x1C, 6))
                    LogIfVerbose("HandleDhcpResponse: transactionId, cookie, or MAC address mismatch")
                    return None
        endpoint = None
        #
        # Walk all the returned options, parsing out what we need, ignoring the others.
        # We need the custom option 245 to find the the endpoint we talk to,
        # as well as, to handle some Linux DHCP client incompatibilities,
        # options 3 for default gateway and 249 for routes. And 255 is end.
        #
        i = 0xF0 # offset to first option
        while i < bytesReceived:
            option = Ord(receiveBuffer[i])
            length = 0
            if (i + 1) < bytesReceived:
                length = Ord(receiveBuffer[i + 1])
            LogIfVerbose("DHCP option " + hex(option) + " at offset:" + hex(i) + " with length:" + hex(length))
            if option == 255:
                LogIfVerbose("DHCP packet ended at offset " + hex(i))
                break
            elif option == 249:
                # Classless static routes, see
                # http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx
                LogIfVerbose("Routes at offset:" + hex(i) + " with length:" + hex(length))
                if length < 5:
                    Error("Data too small for option " + str(option))
                j = i + 2
                while j < (i + length + 2):
                    # Destination prefix is variable-length: only the bytes
                    # covered by the mask are present on the wire.
                    maskLengthBits = Ord(receiveBuffer[j])
                    maskLengthBytes = (((maskLengthBits + 7) & ~7) >> 3)
                    mask = 0xFFFFFFFF & (0xFFFFFFFF << (32 - maskLengthBits))
                    j += 1
                    net = self.UnpackBigEndian(receiveBuffer, j, maskLengthBytes)
                    net <<= (32 - maskLengthBytes * 8)
                    net &= mask
                    j += maskLengthBytes
                    gateway = self.UnpackBigEndian(receiveBuffer, j, 4)
                    j += 4
                    self.RouteAdd(net, mask, gateway)
                if j != (i + length + 2):
                    Error("HandleDhcpResponse: Unable to parse routes")
            elif option == 3 or option == 245:
                if i + 5 < bytesReceived:
                    if length != 4:
                        Error("HandleDhcpResponse: Endpoint or Default Gateway not 4 bytes")
                        return None
                    gateway = self.UnpackBigEndian(receiveBuffer, i + 2, 4)
                    IpAddress = self.IntegerToIpAddressV4String(gateway)
                    if option == 3:
                        self.RouteAdd(0, 0, gateway)
                        name = "DefaultGateway"
                    else:
                        endpoint = IpAddress
                        name = "Windows Azure wire protocol endpoint"
                    LogIfVerbose(name + ": " + IpAddress + " at " + hex(i))
                else:
                    Error("HandleDhcpResponse: Data too small for option " + str(option))
            else:
                LogIfVerbose("Skipping DHCP option " + hex(option) + " at " + hex(i) + " with length " + hex(length))
            i += length + 2
        return endpoint
    def DoDhcpWork(self):
        """
        Discover the wire server via DHCP option 245.
        And workaround incompatibility with Windows Azure DHCP servers.
        Retries with increasing back-off; returns the endpoint IP string
        or None after the last retry fails.
        """
        ShortSleep = False # Sleep 1 second before retrying DHCP queries.
        ifname=None
        sleepDurations = [0, 10, 30, 60, 60]
        maxRetry = len(sleepDurations)
        lastTry = (maxRetry - 1)
        for retry in range(0, maxRetry):
            try:
                #Open DHCP port if iptables is enabled.
                # Delete-then-insert keeps the rule from stacking up on retries.
                Run("iptables -D INPUT -p udp --dport 68 -j ACCEPT",chk_err=False) # We supress error logging on error.
                Run("iptables -I INPUT -p udp --dport 68 -j ACCEPT",chk_err=False) # We supress error logging on error.
                strRetry = str(retry)
                prefix = "DoDhcpWork: try=" + strRetry
                LogIfVerbose(prefix)
                sendData = self.BuildDhcpRequest()
                LogWithPrefixIfVerbose("DHCP request:", HexDump(sendData, len(sendData)))
                sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                missingDefaultRoute = True
                try:
                    if DistInfo()[0] == 'FreeBSD':
                        missingDefaultRoute = True
                    else:
                        routes = RunGetOutput("route -n")[1]
                        for line in routes.split('\n'):
                            if line.startswith("0.0.0.0 ") or line.startswith("default "):
                                missingDefaultRoute = False
                except:
                    pass
                if missingDefaultRoute:
                    # This is required because sending after binding to 0.0.0.0 fails with
                    # network unreachable when the default gateway is not set up.
                    ifname=MyDistro.GetInterfaceName()
                    Log("DoDhcpWork: Missing default route - adding broadcast route for DHCP.")
                    if DistInfo()[0] == 'FreeBSD':
                        Run("route add -net 255.255.255.255 -iface " + ifname,chk_err=False)
                    else:
                        Run("route add 255.255.255.255 dev " + ifname,chk_err=False)
                if MyDistro.isDHCPEnabled():
                    MyDistro.stopDHCP()
                sock.bind(("0.0.0.0", 68))
                # NOTE(review): upstream WALinuxAgent sends to ("<broadcast>", 67);
                # the "<broadcast>" literal appears to have been stripped when this
                # file was extracted (angle-bracket text removed). Restore from the
                # original source before relying on this line.
                sock.sendto(sendData, ("", 67))
                sock.settimeout(10)
                Log("DoDhcpWork: Setting socket.timeout=10, entering recv")
                receiveBuffer = sock.recv(1024)
                endpoint = self.HandleDhcpResponse(sendData, receiveBuffer)
                if endpoint == None:
                    LogIfVerbose("DoDhcpWork: No endpoint found")
                if endpoint != None or retry == lastTry:
                    if endpoint != None:
                        # Keep a copy so RestoreRoutes() can replay the response
                        # after an interface restart.
                        self.SendData = sendData
                        self.DhcpResponse = receiveBuffer
                    if retry == lastTry:
                        LogIfVerbose("DoDhcpWork: try=" + strRetry)
                    return endpoint
                sleepDuration = [sleepDurations[retry % len(sleepDurations)], 1][ShortSleep]
                LogIfVerbose("DoDhcpWork: sleep=" + str(sleepDuration))
                time.sleep(sleepDuration)
            except Exception, e:
                ErrorWithPrefix(prefix, str(e))
                ErrorWithPrefix(prefix, traceback.format_exc())
            finally:
                sock.close()
                if missingDefaultRoute:
                    #We added this route - delete it
                    Log("DoDhcpWork: Removing broadcast route for DHCP.")
                    if DistInfo()[0] == 'FreeBSD':
                        Run("route del -net 255.255.255.255 -iface " + ifname,chk_err=False)
                    else:
                        Run("route del 255.255.255.255 dev " + ifname,chk_err=False) # We supress error logging on error.
                if MyDistro.isDHCPEnabled():
                    MyDistro.startDHCP()
        return None

    def UpdateAndPublishHostName(self, name):
        """ Set hostname locally and publish to iDNS """
        Log("Setting host name: " + name)
        MyDistro.publishHostname(name)
        ethernetInterface = MyDistro.GetInterfaceName()
        MyDistro.RestartInterface(ethernetInterface)
        # Interface restart may have dropped DHCP-installed routes; replay them.
        self.RestoreRoutes()
""" if self.SendData != None and self.DhcpResponse != None: self.HandleDhcpResponse(self.SendData, self.DhcpResponse) def UpdateGoalState(self): """ Retreive goal state information from endpoint server. Parse xml and initialize Agent.GoalState object. Return object or None on error. """ goalStateXml = None maxRetry = 9 log = NoLog for retry in range(1, maxRetry + 1): strRetry = str(retry) log("retry UpdateGoalState,retry=" + strRetry) goalStateXml = self.HttpGetWithHeaders("/machine/?comp=goalstate") if goalStateXml != None: break log = Log time.sleep(retry) if not goalStateXml: Error("UpdateGoalState failed.") return Log("Retrieved GoalState from Windows Azure Fabric.") self.GoalState = GoalState(self).Parse(goalStateXml) return self.GoalState def ReportReady(self): """ Send health report 'Ready' to server. This signals the fabric that our provosion is completed, and the host is ready for operation. """ counter = (self.HealthReportCounter + 1) % 1000000 self.HealthReportCounter = counter healthReport = ("" + self.GoalState.Incarnation + "" + self.GoalState.ContainerId + "" + self.GoalState.RoleInstanceId + "Ready") a = self.HttpPostWithHeaders("/machine?comp=health", healthReport) if a != None: return a.getheader("x-ms-latest-goal-state-incarnation-number") return None def ReportNotReady(self, status, desc): """ Send health report 'Provisioning' to server. This signals the fabric that our provosion is starting. """ healthReport = ("" + self.GoalState.Incarnation + "" + self.GoalState.ContainerId + "" + self.GoalState.RoleInstanceId + "NotReady" + "
    def ReportNotReady(self, status, desc):
        """
        Send health report 'NotReady' (substatus *status*, detail *desc*) to
        the server. This signals the fabric that provisioning is starting or
        has failed. Returns the incarnation-number header or None.
        """
        # NOTE(review): the XML element tags inside this literal were lost when
        # this file was extracted (angle-bracket text stripped); restore the
        # <Health>...<SubStatus>/<Description> document from upstream source.
        healthReport = ("" + self.GoalState.Incarnation + "" + self.GoalState.ContainerId + "" + self.GoalState.RoleInstanceId + "NotReady" + "" + status + "" + desc + "" + "")
        a = self.HttpPostWithHeaders("/machine?comp=health", healthReport)
        if a != None:
            return a.getheader("x-ms-latest-goal-state-incarnation-number")
        return None

    def ReportRoleProperties(self, thumbprint):
        """
        Send roleProperties (SSH host-key certificate thumbprint) to the server.
        Returns the HTTP response object from HttpPostWithHeaders.
        """
        # NOTE(review): XML tags stripped by extraction, as in ReportNotReady.
        roleProperties = ("" + "" + self.GoalState.ContainerId + "" + "" + "" + self.GoalState.RoleInstanceId + "" + "" + "")
        a = self.HttpPostWithHeaders("/machine?comp=roleProperties", roleProperties)
        Log("Posted Role Properties. CertificateThumbprint=" + thumbprint)
        return a

    def LoadBalancerProbeServer_Shutdown(self):
        """
        Shutdown the LoadBalancerProbeServer, if one is running, and clear
        the reference so a new one can be created on port change.
        """
        if self.LoadBalancerProbeServer != None:
            self.LoadBalancerProbeServer.shutdown()
            self.LoadBalancerProbeServer = None

    def GenerateTransportCert(self):
        """
        Create ssl certificate for https communication with endpoint server.
        Returns the base64 certificate body with PEM header/footer lines and
        newlines removed.
        """
        Run(Openssl + " req -x509 -nodes -subj /CN=LinuxTransport -days 32768 -newkey rsa:2048 -keyout TransportPrivate.pem -out TransportCert.pem")
        cert = ""
        for line in GetFileContents("TransportCert.pem").split('\n'):
            if not "CERTIFICATE" in line:
                cert += line.rstrip()
        return cert

    def DoVmmStartup(self):
        """
        Spawn the VMM startup script and exit the agent process.
        """
        Log("Starting Microsoft System Center VMM Initialization Process")
        pid = subprocess.Popen(["/bin/bash","/mnt/cdrom/secure/"+VMM_STARTUP_SCRIPT_NAME,"-p /mnt/cdrom/secure/ "]).pid
        time.sleep(5)
        sys.exit(0)

    def TryUnloadAtapiix(self):
        """
        If global modloaded is True, then we loaded the ata_piix kernel
        module ourselves; unload it again.
        """
        if modloaded:
            Run("rmmod ata_piix.ko",chk_err=False)
            Log("Unloaded ata_piix.ko driver for ATAPI CD-ROM")
    def TryLoadAtapiix(self):
        """
        Load the ata_piix kernel module if it exists.
        If successful, set global modloaded to True.
        If unable to load module leave modloaded False.

        Returns 0 when already loaded, 1 on missing/failed load, the insmod
        return code on insmod failure, a string on uname failure, and falls
        off the end (None) on success.
        NOTE(review): the mixed int/str/None return types look accidental;
        callers ignore the value, so this is documented rather than changed.
        """
        global modloaded
        modloaded=False
        retcode,krn=RunGetOutput('uname -r')
        krn_pth='/lib/modules/'+krn.strip('\n')+'/kernel/drivers/ata/ata_piix.ko'
        if Run("lsmod | grep ata_piix",chk_err=False) == 0 :
            Log("Module " + krn_pth + " driver for ATAPI CD-ROM is already present.")
            return 0
        if retcode:
            Error("Unable to provision: Failed to call uname -r")
            return "Unable to provision: Failed to call uname"
        if os.path.isfile(krn_pth):
            retcode,output=RunGetOutput("insmod " + krn_pth,chk_err=False)
        else:
            Log("Module " + krn_pth + " driver for ATAPI CD-ROM does not exist.")
            return 1
        if retcode != 0:
            Error('Error calling insmod for '+ krn_pth + ' driver for ATAPI CD-ROM')
            return retcode
        time.sleep(1)
        # check 3 times if the mod is loaded
        for i in range(3):
            if Run('lsmod | grep ata_piix'):
                continue
            else :
                modloaded=True
                break
        if not modloaded:
            Error('Unable to load '+ krn_pth + ' driver for ATAPI CD-ROM')
            return 1
        Log("Loaded " + krn_pth + " driver for ATAPI CD-ROM")
        # we have succeeded loading the ata_piix mod if it can be done.

    def SearchForVMMStartup(self):
        """
        Search for a DVD/CDROM containing VMM's VMM_CONFIG_FILE_NAME.
        Call TryLoadAtapiix in case we must load the ata_piix module first.
        If VMM_CONFIG_FILE_NAME is found, call DoVmmStartup (never returns).
        Else, return to the Azure provisioning process.
        """
        self.TryLoadAtapiix()
        if os.path.exists('/mnt/cdrom/secure') == False:
            CreateDir("/mnt/cdrom/secure", "root", 0700)
        mounted=False
        # Candidate optical devices: sr*, hdc..hdz, cdrom*, cd*.
        for dvds in [re.match(r'(sr[0-9]|hd[c-z]|cdrom[0-9]|cd[0-9]?)',x) for x in os.listdir('/dev/')]:
            if dvds == None:
                continue
            dvd = '/dev/'+dvds.group(0)
            if Run("LC_ALL=C fdisk -l " + dvd + " | grep Disk",chk_err=False):
                continue # Not mountable
            else:
                for retry in range(1,6):
                    retcode,output=RunGetOutput("mount -v " + dvd + " /mnt/cdrom/secure")
                    Log(output[:-1])
                    if retcode == 0:
                        Log("mount succeeded on attempt #" + str(retry) )
                        mounted=True
                        break
                    if 'is already mounted on /mnt/cdrom/secure' in output:
                        Log("Device " + dvd + " is already mounted on /mnt/cdrom/secure." + str(retry) )
                        mounted=True
                        break
                    Log("mount failed on attempt #" + str(retry) )
                    Log("mount loop sleeping 5...")
                    time.sleep(5)
                if not mounted: # unable to mount
                    continue
            if not os.path.isfile("/mnt/cdrom/secure/"+VMM_CONFIG_FILE_NAME): #nope - mount the next drive
                if mounted:
                    Run("umount "+dvd,chk_err=False)
                    mounted=False
                continue
            else : # it is the vmm startup
                self.DoVmmStartup()
        Log("VMM Init script not found. Provisioning for Azure")
        return
    def Provision(self):
        """
        Responsible for:
          - regenerating ssh host keys (if configured),
          - mounting and parsing ovf-env.xml from the provisioning DVD,
          - processing the ovf settings and reporting role properties,
          - deleting the root password if configured.
        Return None on success, error string on error.
        """
        enabled = Config.get("Provisioning.Enabled")
        if enabled != None and enabled.lower().startswith("n"):
            return
        Log("Provisioning image started.")
        type = Config.get("Provisioning.SshHostKeyPairType")
        if type == None:
            type = "rsa"
        regenerateKeys = Config.get("Provisioning.RegenerateSshHostKeyPair")
        if regenerateKeys == None or regenerateKeys.lower().startswith("y"):
            Run("rm -f /etc/ssh/ssh_host_*key*")
            Run("ssh-keygen -N '' -t " + type + " -f /etc/ssh/ssh_host_" + type + "_key")
            MyDistro.restartSshService()
        #SetFileContents(LibDir + "/provisioned", "")
        dvd = None
        # Last matching optical device wins.
        for dvds in [re.match(r'(sr[0-9]|hd[c-z]|cdrom[0-9]|cd[0-9]?)',x) for x in os.listdir('/dev/')]:
            if dvds == None :
                continue
            dvd = '/dev/'+dvds.group(0)
        if dvd == None:
            # No DVD device detected
            Error("No DVD device detected, unable to provision.")
            return "No DVD device detected, unable to provision."
        if MyDistro.mediaHasFilesystem(dvd) is False :
            out=MyDistro.load_ata_piix()
            if out:
                return out
            for i in range(10): # we may have to wait
                if os.path.exists(dvd):
                    break
                Log("Waiting for DVD - sleeping 1 - "+str(i+1)+" try...")
                time.sleep(1)
        if os.path.exists('/mnt/cdrom/secure') == False:
            CreateDir("/mnt/cdrom/secure", "root", 0700)
        #begin mount loop - 5 tries - 5 sec wait between
        for retry in range(1,6):
            location='/mnt/cdrom/secure'
            retcode,output=MyDistro.mountDVD(dvd,location)
            Log(output[:-1])
            if retcode == 0:
                Log("mount succeeded on attempt #" + str(retry) )
                break
            if 'is already mounted on /mnt/cdrom/secure' in output:
                Log("Device " + dvd + " is already mounted on /mnt/cdrom/secure." + str(retry) )
                break
            Log("mount failed on attempt #" + str(retry) )
            Log("mount loop sleeping 5...")
            time.sleep(5)
        if not os.path.isfile("/mnt/cdrom/secure/ovf-env.xml"):
            Error("Unable to provision: Missing ovf-env.xml on DVD.")
            return "Failed to retrieve provisioning data (0x02)."
        ovfxml = (GetFileContents(u"/mnt/cdrom/secure/ovf-env.xml",asbin=False)) # use unicode here to ensure correct codec gets used.
        if ord(ovfxml[0]) > 128 and ord(ovfxml[1]) > 128 and ord(ovfxml[2]) > 128 :
            ovfxml = ovfxml[3:] # BOM is not stripped. First three bytes are > 128 and not unicode chars so we ignore them.
        ovfxml=ovfxml.strip(chr(0x00)) # we may have NULLs.
        # NOTE(review): the following statement is garbled -- the extraction
        # that produced this file stripped angle-bracket text. Upstream trims
        # the string to the '<?xml' prologue and collapses inter-tag
        # whitespace with re.sub. Restore from the original WALinuxAgent
        # source; this line is not valid Python as it stands.
        ovfxml=ovfxml[ovfxml.find('.*?<", "*<", ovfxml))
        Run("umount " + dvd,chk_err=False)
        MyDistro.unload_ata_piix()
        error = None
        if ovfxml != None:
            Log("Provisioning image using OVF settings in the DVD.")
            ovfobj = OvfEnv().Parse(ovfxml)
            if ovfobj != None:
                error = ovfobj.Process()
                if error :
                    Error ("Provisioning image FAILED " + error)
                    return ("Provisioning image FAILED " + error)
                Log("Ovf XML process finished")
        # This is done here because regenerated SSH host key pairs may be potentially overwritten when processing the ovfxml
        fingerprint = RunGetOutput("ssh-keygen -lf /etc/ssh/ssh_host_" + type + "_key.pub")[1].rstrip().split()[1].replace(':','')
        self.ReportRoleProperties(fingerprint)
        delRootPass = Config.get("Provisioning.DeleteRootPassword")
        if delRootPass != None and delRootPass.lower().startswith("y"):
            MyDistro.deleteRootPassword()
        Log("Provisioning image completed.")
        return error
    def Run(self):
        """
        Called by 'waagent -daemon'. Main loop to process the goal state.
        State is posted every 25 seconds once provisioning has completed.
        Sequence: VMM detection, wait for network, DHCP endpoint discovery,
        version check, transport cert generation, then the goal-state loop
        (provision, LB probe server, extensions, health reports).
        NOTE(review): the exact statement nesting below was reconstructed
        from a whitespace-collapsed copy of this file -- verify against the
        original WALinuxAgent 2.0.14 source.
        """
        SetFileContents("/var/run/waagent.pid", str(os.getpid()) + "\n")
        # Determine if we are in VMM. Spawn VMM_STARTUP_SCRIPT_NAME if found.
        self.SearchForVMMStartup()
        ipv4=''
        while ipv4 == '' or ipv4 == '0.0.0.0' :
            ipv4=MyDistro.GetIpv4Address()
            if ipv4 == '' or ipv4 == '0.0.0.0' :
                Log("Waiting for network.")
                time.sleep(10)
        Log("IPv4 address: " + ipv4)
        mac=''
        mac=MyDistro.GetMacAddress()
        if len(mac)>0 :
            Log("MAC address: " + ":".join(["%02X" % Ord(a) for a in mac]))
        # Consume Entropy in ACPI table provided by Hyper-V
        try:
            SetFileContents("/dev/random", GetFileContents("/sys/firmware/acpi/tables/OEM0"))
        except:
            pass
        Log("Probing for Windows Azure environment.")
        self.Endpoint = self.DoDhcpWork()
        if self.Endpoint == None:
            Log("Windows Azure environment not detected.")
            # Not on Azure: idle forever so the service stays "running".
            while True:
                time.sleep(60)
        Log("Discovered Windows Azure endpoint: " + self.Endpoint)
        if not self.CheckVersions():
            Error("Agent.CheckVersions failed")
            sys.exit(1)
        self.EnvMonitor = EnvMonitor()
        # Set SCSI timeout on SCSI disks
        MyDistro.initScsiDiskTimeout()
        global provisioned
        global provisionError
        global Openssl
        Openssl = Config.get("OS.OpensslPath")
        if Openssl == None:
            Openssl = "openssl"
        self.TransportCert = self.GenerateTransportCert()
        eventMonitor = None
        incarnation = None # goalStateIncarnationFromHealthReport
        currentPort = None # loadBalancerProbePort
        goalState = None # self.GoalState, instance of GoalState
        provisioned = os.path.exists(LibDir + "/provisioned")
        program = Config.get("Role.StateConsumer")
        provisionError = None
        lbProbeResponder = True
        setting = Config.get("LBProbeResponder")
        if setting != None and setting.lower().startswith("n"):
            lbProbeResponder = False
        while True:
            # Re-process only when the goal-state incarnation changed.
            if (goalState == None) or (incarnation == None) or (goalState.Incarnation != incarnation):
                try:
                    goalState = self.UpdateGoalState()
                except HttpResourceGoneError as e:
                    Warn("Incarnation is out of date:{0}".format(e))
                    incarnation = None
                    continue
                if goalState == None :
                    Warn("Failed to fetch goalstate")
                    continue
                if provisioned == False:
                    self.ReportNotReady("Provisioning", "Starting")
                goalState.Process()
                if provisioned == False:
                    provisionError = self.Provision()
                    if provisionError == None :
                        provisioned = True
                        SetFileContents(LibDir + "/provisioned", "")
                        lastCtime = "NOTFIND"
                        try:
                            walaConfigFile = MyDistro.getConfigurationPath()
                            lastCtime = time.ctime(os.path.getctime(walaConfigFile))
                        except:
                            pass
                        #Get Ctime of wala config, can help identify the base image of this VM
                        AddExtensionEvent(name="WALA",op=WALAEventOperation.Provision,isSuccess=True, message="WALA Config Ctime:"+lastCtime)
                executeCustomData = Config.get("Provisioning.ExecuteCustomData")
                if executeCustomData != None and executeCustomData.lower().startswith("y"):
                    if os.path.exists(LibDir + '/CustomData'):
                        Run('chmod +x ' + LibDir + '/CustomData')
                        Run(LibDir + '/CustomData')
                    else:
                        Error(LibDir + '/CustomData does not exist.')
                #
                # only one port supported
                # restart server if new port is different than old port
                # stop server if no longer a port
                #
                goalPort = goalState.LoadBalancerProbePort
                if currentPort != goalPort:
                    try:
                        self.LoadBalancerProbeServer_Shutdown()
                        currentPort = goalPort
                        if currentPort != None and lbProbeResponder == True:
                            self.LoadBalancerProbeServer = LoadBalancerProbeServer(currentPort)
                            if self.LoadBalancerProbeServer == None :
                                lbProbeResponder = False
                                Log("Unable to create LBProbeResponder.")
                    except Exception, e:
                        Error("Failed to launch LBProbeResponder: {0}".format(e))
                        currentPort = None
                # Report SSH key fingerprint
                type = Config.get("Provisioning.SshHostKeyPairType")
                if type == None:
                    type = "rsa"
                host_key_path = "/etc/ssh/ssh_host_" + type + "_key.pub"
                if(MyDistro.waitForSshHostKey(host_key_path)):
                    fingerprint = RunGetOutput("ssh-keygen -lf /etc/ssh/ssh_host_" + type + "_key.pub")[1].rstrip().split()[1].replace(':','')
                    self.ReportRoleProperties(fingerprint)
                if program != None and DiskActivated == True:
                    try:
                        Children.append(subprocess.Popen([program, "Ready"]))
                    except OSError, e :
                        ErrorWithPrefix('SharedConfig.Parse','Exception: '+ str(e) +' occured launching ' + program )
                    program = None
            sleepToReduceAccessDenied = 3
            time.sleep(sleepToReduceAccessDenied)
            if provisionError != None:
                incarnation = self.ReportNotReady("ProvisioningFailed", provisionError)
            else:
                incarnation = self.ReportReady()
            # Process our extensions.
            if goalState.ExtensionsConfig == None and goalState.ExtensionsConfigXml != None :
                goalState.ExtensionsConfig = ExtensionsConfig().Parse(goalState.ExtensionsConfigXml)
            # report the status/heartbeat results of extension processing
            if goalState.ExtensionsConfig != None :
                goalState.ExtensionsConfig.ReportHandlerStatus()
            if not eventMonitor:
                eventMonitor = WALAEventMonitor(self.HttpPostWithHeaders)
                eventMonitor.StartEventsLoop()
            time.sleep(25 - sleepToReduceAccessDenied)
WaagentLogrotate = """\
/var/log/waagent.log {
    monthly
    rotate 6
    notifempty
    missingok
}
"""

def GetMountPoint(mountlist, device):
    """
    Return the mount point of *device* from the output of 'mount', or None.

    Example of mountlist:
        /dev/sda1 on / type ext4 (rw)
        proc on /proc type proc (rw)
        /dev/sdb1 on /mnt/resource type ext4 (rw)
    NOTE(review): *device* is used as a regex, so substrings such as
    '/dev/sda' also match '/dev/sda1' -- callers rely on passing the full
    device path.
    """
    if (mountlist and device):
        for entry in mountlist.split('\n'):
            if(re.search(device, entry)):
                tokens = entry.split()
                # Return the 3rd column ("on <mountpoint> type ...") of this line
                return tokens[2] if len(tokens) > 2 else None
    return None

def FindInLinuxKernelCmdline(option):
    """
    Return a match object if *option* is present on the kernel boot options
    line of the grub configuration, else None.
    """
    m = None
    matchs = r'^.*?' + MyDistro.grubKernelBootOptionsLine + r'.*?' + option + r'.*$'
    try:
        m = FindStringInFile(MyDistro.grubKernelBootOptionsFile, matchs)
    except IOError as e:
        Error('FindInLinuxKernelCmdline: Exception opening ' + MyDistro.grubKernelBootOptionsFile
              + 'Exception:' + str(e))
    return m

def AppendToLinuxKernelCmdline(option):
    """
    Add *option* to the kernel boot options of the grub configuration and
    run update-grub. Returns 0 on success, 1 on I/O failure.
    """
    if not FindInLinuxKernelCmdline(option):
        src = r'^(.*?' + MyDistro.grubKernelBootOptionsLine + r')(.*?)("?)$'
        rep = r'\1\2 ' + option + r'\3'
        try:
            ReplaceStringInFile(MyDistro.grubKernelBootOptionsFile, src, rep)
        except IOError as e:
            Error('AppendToLinuxKernelCmdline: Exception opening ' + MyDistro.grubKernelBootOptionsFile
                  + 'Exception:' + str(e))
            return 1
    Run("update-grub", chk_err=False)
    return 0

def RemoveFromLinuxKernelCmdline(option):
    """
    Remove *option* from the kernel boot options of the grub configuration
    and run update-grub. Returns 0 on success, 1 on I/O failure.
    """
    if FindInLinuxKernelCmdline(option):
        src = r'^(.*?' + MyDistro.grubKernelBootOptionsLine + r'.*?)(' + option + r')(.*?)("?)$'
        rep = r'\1\3\4'
        try:
            ReplaceStringInFile(MyDistro.grubKernelBootOptionsFile, src, rep)
        except IOError as e:
            Error('RemoveFromLinuxKernelCmdline: Exception opening ' + MyDistro.grubKernelBootOptionsFile
                  + 'Exception:' + str(e))
            return 1
    Run("update-grub", chk_err=False)
    return 0

def FindStringInFile(fname, matchs):
    """
    Return a match object for regex *matchs* against the first matching line
    of *fname*, or None. Propagates I/O errors to the caller.
    """
    ms = re.compile(matchs)
    # FIX: use a context manager so the file handle is always closed
    # (the original opened the file without ever closing it).
    with open(fname, 'r') as f:
        for l in f.readlines():
            m = re.search(ms, l)
            if m:
                return m
    return None

def ReplaceStringInFile(fname, src, repl):
    """
    Replace regex *src* with *repl* on every line of *fname*, rewriting the
    file atomically. No-op when *src* does not occur. Propagates I/O errors.
    """
    sr = re.compile(src)
    if FindStringInFile(fname, src):
        # FIX: close the file handle deterministically (was left to the GC).
        with open(fname, 'r') as f:
            lines = f.readlines()
        updated = ''.join(re.sub(sr, repl, l) for l in lines)
        ReplaceFileContentsAtomic(fname, updated)
    return

def ApplyVNUMAWorkaround():
    """
    If the running kernel is version <= 2.6.37 (a range with a NUMA-related
    bug on Hyper-V), add 'numa=off' to the kernel boot options.
    """
    VersionParts = platform.release().replace('-', '.').split('.')
    if int(VersionParts[0]) > 2:
        return
    if int(VersionParts[1]) > 6:
        return
    if int(VersionParts[2]) > 37:
        return
    if AppendToLinuxKernelCmdline("numa=off") == 0:
        Log("Your kernel version " + platform.release() + " has a NUMA-related bug: NUMA has been disabled.")
    else:
        # BUG FIX: this message was a bare string expression (a no-op);
        # route it through the logger like the success branch does.
        Log("Error adding 'numa=off'. NUMA has not been disabled.")
""" if RemoveFromLinuxKernelCmdline("numa=off") == 0 : Log('NUMA has been re-enabled') else : Log('NUMA has not been re-enabled') def Install(): """ Install the agent service. Check dependencies. Create /etc/waagent.conf and move old version to /etc/waagent.conf.old Copy RulesFiles to /var/lib/waagent Create /etc/logrotate.d/waagent Set /etc/ssh/sshd_config ClientAliveInterval to 180 Call ApplyVNUMAWorkaround() """ if MyDistro.checkDependencies(): return 1 os.chmod(sys.argv[0], 0755) SwitchCwd() for a in RulesFiles: if os.path.isfile(a): if os.path.isfile(GetLastPathElement(a)): os.remove(GetLastPathElement(a)) shutil.move(a, ".") Warn("Moved " + a + " -> " + LibDir + "/" + GetLastPathElement(a) ) MyDistro.registerAgentService() if os.path.isfile("/etc/waagent.conf"): try: os.remove("/etc/waagent.conf.old") except: pass try: os.rename("/etc/waagent.conf", "/etc/waagent.conf.old") Warn("Existing /etc/waagent.conf has been renamed to /etc/waagent.conf.old") except: pass SetFileContents("/etc/waagent.conf", MyDistro.waagent_conf_file) SetFileContents("/etc/logrotate.d/waagent", WaagentLogrotate) filepath = "/etc/ssh/sshd_config" ReplaceFileContentsAtomic(filepath, "\n".join(filter(lambda a: not a.startswith("ClientAliveInterval"), GetFileContents(filepath).split('\n'))) + "\nClientAliveInterval 180\n") Log("Configured SSH client probing to keep connections alive.") ApplyVNUMAWorkaround() return 0 def GetMyDistro(dist_class_name=''): """ Return MyDistro object. NOTE: Logging is not initialized at this point. """ if dist_class_name == '': if 'Linux' in platform.system(): Distro=DistInfo()[0] else : # I know this is not Linux! if 'FreeBSD' in platform.system(): Distro=platform.system() Distro=Distro.strip('"') Distro=Distro.strip(' ') dist_class_name=Distro+'Distro' else: Distro=dist_class_name if not globals().has_key(dist_class_name): print Distro+' is not a supported distribution.' return None return globals()[dist_class_name]() # the distro class inside this module. 
def DistInfo(fullname=0):
    """
    Return [name, release, ...] describing the running distribution.
    Falls back to the deprecated platform.dist() on very old Pythons.
    """
    if 'FreeBSD' in platform.system():
        release = re.sub('\-.*\Z', '', str(platform.release()))
        distinfo = ['FreeBSD', release]
        return distinfo
    if 'linux_distribution' in dir(platform):
        distinfo = list(platform.linux_distribution(full_distribution_name=fullname))
        distinfo[0] = distinfo[0].strip() # remove trailing whitespace in distro name
        return distinfo
    else:
        return platform.dist()

def PackagedInstall(buildroot):
    """
    Called from setup.py for use by RPM.
    Generic implementation Creates directories and files
    /etc/waagent.conf, /etc/init.d/waagent, /usr/sbin/waagent,
    /etc/logrotate.d/waagent, /etc/sudoers.d/waagent under buildroot.
    Copies generated files waagent.conf, into place and exits.
    """
    MyDistro=GetMyDistro()
    if MyDistro == None :
        sys.exit(1)
    MyDistro.packagedInstall(buildroot)

def LibraryInstall(buildroot):
    # Intentionally a no-op placeholder for packaging hooks.
    pass

def Uninstall():
    """
    Uninstall the agent service.
    Copy RulesFiles back to original locations.
    Delete agent-related files.
    Call RevertVNUMAWorkaround().
    Always returns 0.
    """
    SwitchCwd()
    for a in RulesFiles:
        if os.path.isfile(GetLastPathElement(a)):
            try:
                shutil.move(GetLastPathElement(a), a)
                Warn("Moved " + LibDir + "/" + GetLastPathElement(a) + " -> " + a )
            except:
                pass
    MyDistro.unregisterAgentService()
    MyDistro.uninstallDeleteFiles()
    RevertVNUMAWorkaround()
    return 0

def Deprovision(force, deluser):
    """
    Remove user accounts created by provisioning.
    Disables root password if Provisioning.DeleteRootPassword = 'y'
    Stop agent service.
    Remove SSH host keys if they were generated by the provision.
    Set hostname to 'localhost.localdomain'.
    Delete cached system configuration files in /var/lib and /var/lib/waagent.
    Returns 1 when the user declines the confirmation prompt, 0 otherwise.
    """
    #Append blank line at the end of file, so the ctime of this file is changed every time
    Run("echo ''>>"+ MyDistro.getConfigurationPath())
    SwitchCwd()
    ovfxml = GetFileContents(LibDir+"/ovf-env.xml")
    ovfobj = None
    if ovfxml != None:
        ovfobj = OvfEnv().Parse(ovfxml, True)
    print("WARNING! The waagent service will be stopped.")
    print("WARNING! All SSH host key pairs will be deleted.")
    print("WARNING! Cached DHCP leases will be deleted.")
    MyDistro.deprovisionWarnUser()
    delRootPass = Config.get("Provisioning.DeleteRootPassword")
    if delRootPass != None and delRootPass.lower().startswith("y"):
        print("WARNING! root password will be disabled. You will not be able to login as root.")
    if ovfobj != None and deluser == True:
        print("WARNING! " + ovfobj.UserName + " account and entire home directory will be deleted.")
    if force == False and not raw_input('Do you want to proceed (y/n)? ').startswith('y'):
        return 1
    MyDistro.stopAgentService()
    # Remove SSH host keys
    regenerateKeys = Config.get("Provisioning.RegenerateSshHostKeyPair")
    if regenerateKeys == None or regenerateKeys.lower().startswith("y"):
        Run("rm -f /etc/ssh/ssh_host_*key*")
    # Remove root password
    if delRootPass != None and delRootPass.lower().startswith("y"):
        MyDistro.deleteRootPassword()
    # Remove distribution specific networking configuration
    MyDistro.publishHostname('localhost.localdomain')
    MyDistro.deprovisionDeleteFiles()
    if deluser == True:
        MyDistro.DeleteAccount(ovfobj.UserName)
    return 0

def SwitchCwd():
    """
    Switch cwd to /var/lib/waagent. Create it (root-only, 0700) if absent.
    """
    CreateDir(LibDir, "root", 0700)
    os.chdir(LibDir)

def Usage():
    """
    Print the command-line usage of waagent. Always returns 0.
    """
    print("usage: " + sys.argv[0] + " [-verbose] [-force] [-help|-install|-uninstall|-deprovision[+user]|-version|-serialconsole|-daemon]")
    return 0
def main():
    """
    Instantiate MyDistro, exit if distro class is not defined.
    Parse command-line arguments, exit with usage() on error.
    Instantiate ConfigurationProvider.
    Call appropriate non-daemon methods and exit.
    If daemon mode, enter Agent.Run() loop.
    """
    if GuestAgentVersion == "":
        print("WARNING! This is a non-standard agent that does not include a valid version string.")
    if len(sys.argv) == 1:
        sys.exit(Usage())
    LoggerInit('/var/log/waagent.log','/dev/console')
    global LinuxDistro
    LinuxDistro=DistInfo()[0]
    #The platform.py lib has issue with detecting oracle linux distribution.
    #Merge the following patch provided by oracle as a temparory fix.
    if os.path.exists("/etc/oracle-release"):
        LinuxDistro="Oracle Linux"
    global MyDistro
    MyDistro=GetMyDistro()
    if MyDistro == None :
        sys.exit(1)
    args = []
    conf_file = None
    global force
    force = False
    # First pass: global flags and install/uninstall; anything else queues up
    # for the second pass below (after Config is loaded).
    for a in sys.argv[1:]:
        if re.match("^([-/]*)(help|usage|\?)", a):
            sys.exit(Usage())
        elif re.match("^([-/]*)version", a):
            print(GuestAgentVersion + " running on " + LinuxDistro)
            sys.exit(0)
        elif re.match("^([-/]*)verbose", a):
            myLogger.verbose = True
        elif re.match("^([-/]*)force", a):
            force = True
        elif re.match("^(?:[-/]*)conf=.+", a):
            conf_file = re.match("^(?:[-/]*)conf=(.+)", a).groups()[0]
        elif re.match("^([-/]*)(setup|install)", a):
            sys.exit(MyDistro.Install())
        elif re.match("^([-/]*)(uninstall)", a):
            sys.exit(Uninstall())
        else:
            args.append(a)
    global Config
    Config = ConfigurationProvider(conf_file)
    logfile = Config.get("Logs.File")
    if logfile is not None:
        myLogger.file_path = logfile
    logconsole = Config.get("Logs.Console")
    if logconsole is not None and logconsole.lower().startswith("n"):
        myLogger.con_path = None
    verbose = Config.get("Logs.Verbose")
    if verbose != None and verbose.lower().startswith("y"):
        myLogger.verbose=True
    global daemon
    daemon = False
    # Second pass: actions that need configuration to be loaded first.
    for a in args:
        if re.match("^([-/]*)deprovision\+user", a):
            sys.exit(Deprovision(force, True))
        elif re.match("^([-/]*)deprovision", a):
            sys.exit(Deprovision(force, False))
        elif re.match("^([-/]*)daemon", a):
            daemon = True
        elif re.match("^([-/]*)serialconsole", a):
            AppendToLinuxKernelCmdline("console=ttyS0 earlyprintk=ttyS0")
            Log("Configured kernel to use ttyS0 as the boot console.")
            sys.exit(0)
        else:
            print("Invalid command line parameter:" + a)
            sys.exit(1)
    if daemon == False:
        sys.exit(Usage())
    global modloaded
    modloaded = False
    try:
        SwitchCwd()
        Log(GuestAgentLongName + " Version: " + GuestAgentVersion)
        if IsLinux():
            Log("Linux Distribution Detected : " + LinuxDistro)
        global WaAgent
        WaAgent = Agent()
        WaAgent.Run()
    except Exception, e:
        Error(traceback.format_exc())
        Error("Exception: " + str(e))
        sys.exit(1)

if __name__ == '__main__' :
    main()
def import_loose_version():
    """
    Select the version-comparison class to use.

    Returns the local LooseVersionComparator on Python 3.12+, where
    distutils (and distutils.version.LooseVersion) was removed from the
    standard library; otherwise returns the stdlib LooseVersion.
    """
    if sys.version_info < (3, 12):
        from distutils.version import LooseVersion
        return LooseVersion
    return LooseVersionComparator
module. class CalledProcessError(Exception): def __init__(self, returncode, cmd, output=None): self.returncode = returncode self.cmd = cmd self.output = output def __str__(self): return "Command '%s' returned non-zero exit status %d" % (self.cmd, self.returncode) subprocess.check_output = check_output subprocess.CalledProcessError = CalledProcessError GuestAgentName = "WALinuxAgent" GuestAgentLongName = "Azure Linux Agent" GuestAgentVersion = "WALinuxAgent-2.0.16" ProtocolVersion = "2012-11-30" # WARNING this value is used to confirm the correct fabric protocol. Config = None WaAgent = None DiskActivated = False Openssl = "openssl" Children = [] ExtensionChildren = [] VMM_STARTUP_SCRIPT_NAME = 'install' VMM_CONFIG_FILE_NAME = 'linuxosconfiguration.xml' global RulesFiles RulesFiles = ["/lib/udev/rules.d/75-persistent-net-generator.rules", "/etc/udev/rules.d/70-persistent-net.rules"] VarLibDhcpDirectories = ["/var/lib/dhclient", "/var/lib/dhcpcd", "/var/lib/dhcp"] EtcDhcpClientConfFiles = ["/etc/dhcp/dhclient.conf", "/etc/dhcp3/dhclient.conf"] global LibDir LibDir = "/var/lib/waagent" global provisioned provisioned = False global provisionError provisionError = None HandlerStatusToAggStatus = {"installed": "Installing", "enabled": "Ready", "unintalled": "NotReady", "disabled": "NotReady"} WaagentConf = """\ # # Azure Linux Agent Configuration # Role.StateConsumer=None # Specified program is invoked with the argument "Ready" when we report ready status # to the endpoint server. Role.ConfigurationConsumer=None # Specified program is invoked with XML file argument specifying role configuration. Role.TopologyConsumer=None # Specified program is invoked with XML file argument specifying role topology. Provisioning.Enabled=y # Provisioning.DeleteRootPassword=y # Password authentication for root account will be unavailable. Provisioning.RegenerateSshHostKeyPair=y # Generate fresh host key pair. 
class LooseVersionComparator:
    """
    Parse and compare version strings with an optional pre-release tag,
    e.g. "1.2.3" or "1.2.3-rc1".

    Based on LooseVersion from distutils.version as that was removed in
    python 3.12.  Python's standard library does not include a direct
    replacement for `distutils.version.LooseVersion` or `StrictVersion`
    for arbitrary version string comparison outside of the context of
    installed packages.  This is needed to avoid installing pip.

    Ordering: the dotted numeric component is compared numerically; when
    equal, a version WITHOUT a pre-release tag sorts after the same
    version WITH one, and two tags compare lexically.
    """

    def __init__(self, version):
        # self.version: tuple of ints, e.g. (1, 2, 3)
        # self.prerelease: tag string or None
        self.version, self.prerelease = self.parse_version(version)

    def parse_version(self, version):
        """
        Split 'X.Y[.Z...][-tag]' into ((X, Y, ...), tag-or-None).
        Raises ValueError when 'version' does not match that shape.
        """
        # Regular expression to parse versions with pre-release tags
        match = re.match(r'^(\d+(?:\.\d+)*)(?:-([\da-zA-Z-]+))?$', version)
        if not match:
            raise ValueError("Invalid version format: {0}".format(version))
        main_version = tuple(map(int, match.group(1).split('.')))
        prerelease = match.group(2)
        return main_version, prerelease

    def __lt__(self, other):
        if self.version == other.version:
            return self._compare_prerelease(self.prerelease, other.prerelease) < 0
        return self.version < other.version

    def __gt__(self, other):
        if self.version == other.version:
            return self._compare_prerelease(self.prerelease, other.prerelease) > 0
        return self.version > other.version

    def __eq__(self, other):
        return self.version == other.version and self.prerelease == other.prerelease

    # Fix: the original defined only <, > and ==, so '<=' and '>=' raised
    # TypeError, and defining __eq__ without __hash__ made instances
    # unhashable on Python 3.  distutils' LooseVersion supported the full
    # set of comparisons.
    def __le__(self, other):
        return not self.__gt__(other)

    def __ge__(self, other):
        return not self.__lt__(other)

    def __hash__(self):
        return hash((self.version, self.prerelease))

    def __str__(self):
        return ".".join(map(str, self.version)) + \
               ("-{0}".format(self.prerelease) if self.prerelease else "")

    @staticmethod
    def _compare_prerelease(pr1, pr2):
        """Return -1/0/1; absence of a tag sorts AFTER any tag (semver-style)."""
        if pr1 is None and pr2 is None:
            return 0
        if pr1 is None:
            return 1
        if pr2 is None:
            return -1
        return (pr1 > pr2) - (pr1 < pr2)
These are based on 'majority rules'. This __init__() may be called or overriden by the child. """ self.agent_service_name = os.path.basename(sys.argv[0]) self.selinux = None self.service_cmd = '/usr/sbin/service' self.ssh_service_restart_option = 'restart' self.ssh_service_name = 'ssh' self.ssh_config_file = '/etc/ssh/sshd_config' self.hostname_file_path = '/etc/hostname' self.dhcp_client_name = 'dhclient' self.requiredDeps = ['route', 'shutdown', 'ssh-keygen', 'useradd', 'usermod', 'openssl', 'sfdisk', 'fdisk', 'mkfs', 'sed', 'grep', 'sudo', 'parted'] self.init_script_file = '/etc/init.d/waagent' self.agent_package_name = 'WALinuxAgent' self.fileBlackList = ["/root/.bash_history", "/var/log/waagent.log", '/etc/resolv.conf'] self.agent_files_to_uninstall = ["/etc/waagent.conf", "/etc/logrotate.d/waagent"] self.grubKernelBootOptionsFile = '/etc/default/grub' self.grubKernelBootOptionsLine = 'GRUB_CMDLINE_LINUX_DEFAULT=' self.getpidcmd = 'pidof' self.mount_dvd_cmd = 'mount' self.sudoers_dir_base = '/etc' self.waagent_conf_file = WaagentConf self.shadow_file_mode = 0o600 self.shadow_file_path = "/etc/shadow" self.dhcp_enabled = False def isSelinuxSystem(self): """ Checks and sets self.selinux = True if SELinux is available on system. """ if self.selinux == None: if Run("which getenforce", chk_err=False): self.selinux = False else: self.selinux = True return self.selinux def isSelinuxRunning(self): """ Calls shell command 'getenforce' and returns True if 'Enforcing'. """ if self.isSelinuxSystem(): return RunGetOutput("getenforce")[1].startswith("Enforcing") else: return False def setSelinuxEnforce(self, state): """ Calls shell command 'setenforce' with 'state' and returns resulting exit code. """ if self.isSelinuxSystem(): if state: s = '1' else: s = '0' return Run("setenforce " + s) def setSelinuxContext(self, path, cn): """ Calls shell 'chcon' with 'path' and 'cn' context. Returns exit result. 
""" if self.isSelinuxSystem(): return Run('chcon ' + cn + ' ' + path) def setHostname(self, name): """ Shell call to hostname. Returns resulting exit code. """ return Run('hostname ' + name) def publishHostname(self, name): """ Set the contents of the hostname file to 'name'. Return 1 on failure. """ try: r = SetFileContents(self.hostname_file_path, name) for f in EtcDhcpClientConfFiles: if os.path.exists(f) and FindStringInFile(f, r'^[^#]*?send\s*host-name.*?(|gethostname[(,)])') == None: r = ReplaceFileContentsAtomic('/etc/dhcp/dhclient.conf', "send host-name \"" + name + "\";\n" + "\n".join(filter(lambda a: not a.startswith("send host-name"), GetFileContents('/etc/dhcp/dhclient.conf').split( '\n')))) except: return 1 return r def installAgentServiceScriptFiles(self): """ Create the waagent support files for service installation. Called by registerAgentService() Abstract Virtual Function. Over-ridden in concrete Distro classes. """ pass def registerAgentService(self): """ Calls installAgentService to create service files. Shell exec service registration commands. (e.g. chkconfig --add waagent) Abstract Virtual Function. Over-ridden in concrete Distro classes. """ pass def uninstallAgentService(self): """ Call service subsystem to remove waagent script. Abstract Virtual Function. Over-ridden in concrete Distro classes. 
""" pass def unregisterAgentService(self): """ Calls self.stopAgentService and call self.uninstallAgentService() """ self.stopAgentService() self.uninstallAgentService() def startAgentService(self): """ Service call to start the Agent service """ return Run(self.service_cmd + ' ' + self.agent_service_name + ' start') def stopAgentService(self): """ Service call to stop the Agent service """ return Run(self.service_cmd + ' ' + self.agent_service_name + ' stop', False) def restartSshService(self): """ Service call to re(start) the SSH service """ sshRestartCmd = self.service_cmd + " " + self.ssh_service_name + " " + self.ssh_service_restart_option retcode = Run(sshRestartCmd) if retcode > 0: Error("Failed to restart SSH service with return code:" + str(retcode)) return retcode def sshDeployPublicKey(self, fprint, path): """ Generic sshDeployPublicKey - over-ridden in some concrete Distro classes due to minor differences in openssl packages deployed """ error = 0 SshPubKey = OvfEnv().OpensslToSsh(fprint) if SshPubKey != None: AppendFileContents(path, SshPubKey) else: Error("Failed: " + fprint + ".crt -> " + path) error = 1 return error def checkPackageInstalled(self, p): """ Query package database for prescence of an installed package. Abstract Virtual Function. Over-ridden in concrete Distro classes. """ pass def checkPackageUpdateable(self, p): """ Online check if updated package of walinuxagent is available. Abstract Virtual Function. Over-ridden in concrete Distro classes. """ pass def deleteRootPassword(self): """ Generic root password removal. 
""" filepath = "/etc/shadow" ReplaceFileContentsAtomic(filepath, "root:*LOCK*:14600::::::\n" + "\n".join( filter(lambda a: not a.startswith("root:"), GetFileContents(filepath).split('\n')))) os.chmod(filepath, self.shadow_file_mode) if self.isSelinuxSystem(): self.setSelinuxContext(filepath, 'system_u:object_r:shadow_t:s0') Log("Root password deleted.") return 0 def changePass(self, user, password): Log("Change user password") crypt_id = Config.get("Provisioning.PasswordCryptId") if crypt_id is None: crypt_id = "6" salt_len = Config.get("Provisioning.PasswordCryptSaltLength") try: salt_len = int(salt_len) if salt_len < 0 or salt_len > 10: salt_len = 10 except (ValueError, TypeError): salt_len = 10 return self.chpasswd(user, password, crypt_id=crypt_id, salt_len=salt_len) def chpasswd(self, username, password, crypt_id=6, salt_len=10): passwd_hash = self.gen_password_hash(password, crypt_id, salt_len) cmd = "usermod -p '{0}' {1}".format(passwd_hash, username) ret, output = RunGetOutput(cmd, log_cmd=False) if ret != 0: return "Failed to set password for {0}: {1}".format(username, output) def gen_password_hash(self, password, crypt_id, salt_len): collection = string.ascii_letters + string.digits salt = ''.join(random.choice(collection) for _ in range(salt_len)) salt = "${0}${1}".format(crypt_id, salt) return crypt.crypt(password, salt) def load_ata_piix(self): return WaAgent.TryLoadAtapiix() def unload_ata_piix(self): """ Generic function to remove ata_piix.ko. """ return WaAgent.TryUnloadAtapiix() def deprovisionWarnUser(self): """ Generic user warnings used at deprovision. """ print("WARNING! 
Nameserver configuration in /etc/resolv.conf will be deleted.") def deprovisionDeleteFiles(self): """ Files to delete when VM is deprovisioned """ for a in VarLibDhcpDirectories: Run("rm -f " + a + "/*") # Clear LibDir, remove nameserver and root bash history for f in os.listdir(LibDir) + self.fileBlackList: try: os.remove(f) except: pass return 0 def uninstallDeleteFiles(self): """ Files to delete when agent is uninstalled. """ for f in self.agent_files_to_uninstall: try: os.remove(f) except: pass return 0 def checkDependencies(self): """ Generic dependency check. Return 1 unless all dependencies are satisfied. """ if self.checkPackageInstalled('NetworkManager'): Error(GuestAgentLongName + " is not compatible with network-manager.") return 1 try: m = __import__('pyasn1') except ImportError: Error(GuestAgentLongName + " requires python-pyasn1 for your Linux distribution.") return 1 for a in self.requiredDeps: if Run("which " + a + " > /dev/null 2>&1", chk_err=False): Error("Missing required dependency: " + a) return 1 return 0 def packagedInstall(self, buildroot): """ Called from setup.py for use by RPM. Copies generated files waagent.conf, under the buildroot. """ if not os.path.exists(buildroot + '/etc'): os.mkdir(buildroot + '/etc') SetFileContents(buildroot + '/etc/waagent.conf', MyDistro.waagent_conf_file) if not os.path.exists(buildroot + '/etc/logrotate.d'): os.mkdir(buildroot + '/etc/logrotate.d') SetFileContents(buildroot + '/etc/logrotate.d/waagent', WaagentLogrotate) self.init_script_file = buildroot + self.init_script_file # this allows us to call installAgentServiceScriptFiles() if not os.path.exists(os.path.dirname(self.init_script_file)): os.mkdir(os.path.dirname(self.init_script_file)) self.installAgentServiceScriptFiles() def GetIpv4Address(self): """ Return the ip of the first active non-loopback interface. 
""" addr = '' iface, addr = GetFirstActiveNetworkInterfaceNonLoopback() return addr def GetMacAddress(self): return GetMacAddress() def GetInterfaceName(self): return GetFirstActiveNetworkInterfaceNonLoopback()[0] def RestartInterface(self, iface, max_retry=3): for retry in range(1, max_retry + 1): ret = Run("ifdown " + iface + " && ifup " + iface) if ret == 0: return Log("Failed to restart interface: {0}, ret={1}".format(iface, ret)) if retry < max_retry: Log("Retry restart interface in 5 seconds") time.sleep(5) def CreateAccount(self, user, password, expiration, thumbprint): return CreateAccount(user, password, expiration, thumbprint) def DeleteAccount(self, user): return DeleteAccount(user) def ActivateResourceDisk(self): """ Format, mount, and if specified in the configuration set resource disk as swap. """ global DiskActivated format = Config.get("ResourceDisk.Format") if format == None or format.lower().startswith("n"): DiskActivated = True return device = DeviceForIdePort(1) if device == None: Error("ActivateResourceDisk: Unable to detect disk topology.") return device = "/dev/" + device mountlist = RunGetOutput("mount")[1] mountpoint = GetMountPoint(mountlist, device) if (mountpoint): Log("ActivateResourceDisk: " + device + "1 is already mounted.") else: mountpoint = Config.get("ResourceDisk.MountPoint") if mountpoint == None: mountpoint = "/mnt/resource" CreateDir(mountpoint, "root", 0o755) fs = Config.get("ResourceDisk.Filesystem") if fs == None: fs = "ext3" partition = device + "1" # Check partition type Log("Detect GPT...") ret = RunGetOutput("parted {0} print".format(device)) if ret[0] == 0 and "gpt" in ret[1]: Log("GPT detected.") # GPT(Guid Partition Table) is used. # Get partitions. parts = filter(lambda x: re.match(r"^\s*[0-9]+", x), ret[1].split("\n")) # If there are more than 1 partitions, remove all partitions # and create a new one using the entire disk space. 
if len(parts) > 1: for i in range(1, len(parts) + 1): Run("parted {0} rm {1}".format(device, i)) Run("parted {0} mkpart primary 0% 100%".format(device)) Run("mkfs." + fs + " " + partition + " -F") else: existingFS = RunGetOutput("sfdisk -q -c " + device + " 1", chk_err=False)[1].rstrip() if existingFS == "7" and fs != "ntfs": Run("sfdisk -c " + device + " 1 83") Run("mkfs." + fs + " " + partition) if Run("mount " + partition + " " + mountpoint, chk_err=False): # If mount failed, try to format the partition and mount again Warn("Failed to mount resource disk. Retry mounting.") Run("mkfs." + fs + " " + partition + " -F") if Run("mount " + partition + " " + mountpoint): Error("ActivateResourceDisk: Failed to mount resource disk (" + partition + ").") return Log("Resource disk (" + partition + ") is mounted at " + mountpoint + " with fstype " + fs) # Create README file under the root of resource disk SetFileContents(os.path.join(mountpoint, README_FILENAME), README_FILECONTENT) DiskActivated = True # Create swap space swap = Config.get("ResourceDisk.EnableSwap") if swap == None or swap.lower().startswith("n"): return sizeKB = int(Config.get("ResourceDisk.SwapSizeMB")) * 1024 if os.path.isfile(mountpoint + "/swapfile") and os.path.getsize(mountpoint + "/swapfile") != (sizeKB * 1024): os.remove(mountpoint + "/swapfile") if not os.path.isfile(mountpoint + "/swapfile"): Run("dd if=/dev/zero of=" + mountpoint + "/swapfile bs=1024 count=" + str(sizeKB)) Run("mkswap " + mountpoint + "/swapfile") Run("chmod 600 " + mountpoint + "/swapfile") if not Run("swapon " + mountpoint + "/swapfile"): Log("Enabled " + str(sizeKB) + " KB of swap at " + mountpoint + "/swapfile") else: Error("ActivateResourceDisk: Failed to activate swap at " + mountpoint + "/swapfile") def Install(self): return Install() def mediaHasFilesystem(self, dsk): if len(dsk) == 0: return False if Run("LC_ALL=C fdisk -l " + dsk + " | grep Disk"): return False return True def mountDVD(self, dvd, location): return 
RunGetOutput(self.mount_dvd_cmd + ' ' + dvd + ' ' + location) def GetHome(self): return GetHome() def getDhcpClientName(self): return self.dhcp_client_name def initScsiDiskTimeout(self): """ Set the SCSI disk timeout when the agent starts running """ self.setScsiDiskTimeout() def setScsiDiskTimeout(self): """ Iterate all SCSI disks(include hot-add) and set their timeout if their value are different from the OS.RootDeviceScsiTimeout """ try: scsiTimeout = Config.get("OS.RootDeviceScsiTimeout") for diskName in [disk for disk in os.listdir("/sys/block") if disk.startswith("sd")]: self.setBlockDeviceTimeout(diskName, scsiTimeout) except: pass def setBlockDeviceTimeout(self, device, timeout): """ Set SCSI disk timeout by set /sys/block/sd*/device/timeout """ if timeout != None and device: filePath = "/sys/block/" + device + "/device/timeout" if (GetFileContents(filePath).splitlines()[0].rstrip() != timeout): SetFileContents(filePath, timeout) Log("SetBlockDeviceTimeout: Update the device " + device + " with timeout " + timeout) def waitForSshHostKey(self, path): """ Provide a dummy waiting, since by default, ssh host key is created by waagent and the key should already been created. """ if (os.path.isfile(path)): return True else: Error("Can't find host key: {0}".format(path)) return False def isDHCPEnabled(self): return self.dhcp_enabled def stopDHCP(self): """ Stop the system DHCP client so that the agent can bind on its port. If the distro has set dhcp_enabled to True, it will need to provide an implementation of this method. """ raise NotImplementedError('stopDHCP method missing') def startDHCP(self): """ Start the system DHCP client. If the distro has set dhcp_enabled to True, it will need to provide an implementation of this method. """ raise NotImplementedError('startDHCP method missing') def translateCustomData(self, data): """ Translate the custom data from a Base64 encoding. Default to no-op. 
""" decodeCustomData = Config.get("Provisioning.DecodeCustomData") if decodeCustomData != None and decodeCustomData.lower().startswith("y"): return base64.b64decode(data) return data def getConfigurationPath(self): return "/etc/waagent.conf" def getProcessorCores(self): return int(RunGetOutput("grep 'processor.*:' /proc/cpuinfo |wc -l")[1]) def getTotalMemory(self): return int(RunGetOutput("grep MemTotal /proc/meminfo |awk '{print $2}'")[1]) / 1024 def getInterfaceNameByMac(self, mac): ret, output = RunGetOutput("ifconfig -a") if ret != 0: raise Exception("Failed to get network interface info") output = output.replace('\n', '') match = re.search(r"(eth\d).*(HWaddr|ether) {0}".format(mac), output, re.IGNORECASE) if match is None: raise Exception("Failed to get ifname with mac: {0}".format(mac)) output = match.group(0) eths = re.findall(r"eth\d", output) if eths is None or len(eths) == 0: raise Exception("Failed to get ifname with mac: {0}".format(mac)) return eths[-1] def configIpV4(self, ifName, addr, netmask=24): ret, output = RunGetOutput("ifconfig {0} up".format(ifName)) if ret != 0: raise Exception("Failed to bring up {0}: {1}".format(ifName, output)) ret, output = RunGetOutput("ifconfig {0} {1}/{2}".format(ifName, addr, netmask)) if ret != 0: raise Exception("Failed to config ipv4 for {0}: {1}".format(ifName, output)) def setDefaultGateway(self, gateway): Run("/sbin/route add default gw" + gateway, chk_err=False) def routeAdd(self, net, mask, gateway): Run("/sbin/route add -net " + net + " netmask " + mask + " gw " + gateway, chk_err=False) def getNdDriverVersion(self): """ if error happens, raise a RdmaError """ try: with open("/var/lib/hyperv/.kvp_pool_0", "r") as f: lines = f.read() r = re.search(r"NdDriverVersion\0+(\d\d\d\.\d)", lines) if r is not None: NdDriverVersion = r.groups()[0] return NdDriverVersion # e.g. 
class gentooDistro(AbstractDistro):
    """
    Gentoo distro concrete class: overrides service management (OpenRC
    rc-update), hostname publishing and package queries (eix).
    """

    def __init__(self):
        # Fix: the super().__init__() call was commented out, so the generic
        # attributes set by AbstractDistro.__init__ (agent_service_name,
        # init_script_file, requiredDeps, ...) were never initialized and
        # methods such as installAgentServiceScriptFiles()/registerAgentService()
        # would fail with AttributeError.
        super(gentooDistro, self).__init__()
        self.service_cmd = '/sbin/service'
        self.ssh_service_name = 'sshd'
        self.hostname_file_path = '/etc/conf.d/hostname'
        self.dhcp_client_name = 'dhcpcd'
        self.shadow_file_mode = 0o640
        self.init_file = gentoo_init_file

    def publishHostname(self, name):
        """Persist 'name' into /etc/conf.d/hostname.  Returns 1 on failure."""
        # Fix: 'r' was unbound when the hostname file did not exist, so the
        # final 'return r' raised NameError; default to 0 (success/no-op).
        r = 0
        try:
            if (os.path.isfile(self.hostname_file_path)):
                r = ReplaceFileContentsAtomic(self.hostname_file_path,
                                              "hostname=\"" + name + "\"\n" +
                                              "\n".join(filter(lambda a: not a.startswith("hostname="),
                                                               GetFileContents(self.hostname_file_path).split("\n"))))
        except:
            return 1
        return r

    def installAgentServiceScriptFiles(self):
        """Write the OpenRC init script and make it executable."""
        SetFileContents(self.init_script_file, self.init_file)
        os.chmod(self.init_script_file, 0o755)

    def registerAgentService(self):
        """Install the init script and add it to the default runlevel."""
        self.installAgentServiceScriptFiles()
        return Run('rc-update add ' + self.agent_service_name + ' default')

    def uninstallAgentService(self):
        return Run('rc-update del ' + self.agent_service_name + ' default')

    def unregisterAgentService(self):
        self.stopAgentService()
        return self.uninstallAgentService()

    def checkPackageInstalled(self, p):
        # eix exit code 0 (falsy) means the package was found -> return 1.
        if Run('eix -I ^' + p + '$', chk_err=False):
            return 0
        else:
            return 1

    def checkPackageUpdateable(self, p):
        if Run('eix -u ^' + p + '$', chk_err=False):
            return 0
        else:
            return 1

    def RestartInterface(self, iface):
        Run("/etc/init.d/net." + iface + " restart")
# # /etc/init.d/waagent # # and symbolic link # # /usr/sbin/rcwaagent # # System startup script for the waagent # ### BEGIN INIT INFO # Provides: AzureLinuxAgent # Required-Start: $network sshd # Required-Stop: $network sshd # Default-Start: 3 5 # Default-Stop: 0 1 2 6 # Description: Start the AzureLinuxAgent ### END INIT INFO PYTHON=/usr/bin/python WAZD_BIN=/usr/sbin/waagent WAZD_CONF=/etc/waagent.conf WAZD_PIDFILE=/var/run/waagent.pid test -x "$WAZD_BIN" || { echo "$WAZD_BIN not installed"; exit 5; } test -e "$WAZD_CONF" || { echo "$WAZD_CONF not found"; exit 6; } . /etc/rc.status # First reset status of this service rc_reset # Return values acc. to LSB for all commands but status: # 0 - success # 1 - misc error # 2 - invalid or excess args # 3 - unimplemented feature (e.g. reload) # 4 - insufficient privilege # 5 - program not installed # 6 - program not configured # # Note that starting an already running service, stopping # or restarting a not-running service as well as the restart # with force-reload (in case signalling is not supported) are # considered a success. case "$1" in start) echo -n "Starting AzureLinuxAgent" ## Start daemon with startproc(8). If this fails ## the echo return value is set appropriate. startproc -f ${PYTHON} ${WAZD_BIN} -daemon rc_status -v ;; stop) echo -n "Shutting down AzureLinuxAgent" ## Stop daemon with killproc(8) and if this fails ## set echo the echo return value. killproc -p ${WAZD_PIDFILE} ${PYTHON} ${WAZD_BIN} rc_status -v ;; try-restart) ## Stop the service and if this succeeds (i.e. the ## service was running before), start it again. $0 status >/dev/null && $0 restart rc_status ;; restart) ## Stop the service and regardless of whether it was ## running or not, start it again. $0 stop sleep 1 $0 start rc_status ;; force-reload|reload) rc_status ;; status) echo -n "Checking for service AzureLinuxAgent " ## Check status with checkproc(8), if process is running ## checkproc will return with exit status 0. 
        checkproc -p ${WAZD_PIDFILE} ${PYTHON} ${WAZD_BIN}
        rc_status -v
        ;;
    probe)
        ;;
    *)
        echo "Usage: $0 {start|stop|status|try-restart|restart|force-reload|reload}"
        exit 1
        ;;
esac
rc_exit
"""


class SuSEDistro(AbstractDistro):
    """
    SuSE Distro concrete class
    Put SuSE specific behavior here...
    """

    def __init__(self):
        # Distro identity is probed twice: short form (version compares) and
        # full name (to distinguish SLES from openSUSE).
        super(SuSEDistro, self).__init__()
        dist_info = DistInfo()
        dist_info_fullname = DistInfo(fullname=1)
        # SLES 12+ and openSUSE 13.2+ replaced dhcpcd with wicked.
        # NOTE(review): these are lexicographic string comparisons — e.g.
        # '9' >= '12' is True — presumably acceptable for the versions seen here.
        self.dhcp_client_name = 'dhcpcd'
        if ((dist_info_fullname[0] == 'SUSE Linux Enterprise Server' and dist_info[1] >= '12') or \
                (dist_info_fullname[0] == 'openSUSE' and dist_info[1] >= '13.2')):
            self.dhcp_client_name = 'wickedd-dhcp4'
        self.dhcp_enabled = True
        self.grubKernelBootOptionsFile = '/boot/grub/menu.lst'
        self.grubKernelBootOptionsLine = 'kernel'
        self.getpidcmd = 'pidof '
        self.hostname_file_path = '/etc/HOSTNAME'
        self.init_file = suse_init_file
        self.kernel_boot_options_file = '/boot/grub/menu.lst'
        self.modprobe_path = '/usr/bin/modprobe'
        self.requiredDeps += ["/sbin/insserv"]
        self.reboot_path = '/sbin/reboot'
        self.rpm_path = '/bin/rpm'
        self.service_cmd = '/sbin/service'
        self.ssh_service_name = 'sshd'
        # SLES 11 ships ps in /bin; later releases use /usr/bin.
        if (dist_info[1] == "11"):
            self.ps_path = '/bin/ps'
        else:
            self.ps_path = '/usr/bin/ps'
        self.zypper_path = '/usr/bin/zypper'

    def checkPackageInstalled(self, p):
        """
        Return 1 if package p is installed, 0 otherwise.
        rpm -q exits 0 when the package is present, so a nonzero Run()
        result means "not installed".
        """
        if Run("rpm -q " + p, chk_err=False):
            return 0
        else:
            return 1

    def checkPackageUpdateable(self, p):
        """
        Return 0 if an update for package p is available, 1 otherwise.
        grep exits 0 (Run() returns falsy) when p appears in the
        pending-updates list.
        """
        if Run("zypper list-updates | grep " + p, chk_err=False):
            return 1
        else:
            return 0

    def installAgentServiceScriptFiles(self):
        """
        Write the SysV init script and make it executable.
        Failures are deliberately swallowed (best-effort install).
        """
        try:
            SetFileContents(self.init_script_file, self.init_file)
            os.chmod(self.init_script_file, 0o744)
        except:
            pass

    def registerAgentService(self):
        """Install the init script and enable it via insserv."""
        self.installAgentServiceScriptFiles()
        return Run('insserv ' + self.agent_service_name)

    def uninstallAgentService(self):
        """Disable the agent service via insserv -r."""
        return Run('insserv -r ' + self.agent_service_name)

    def unregisterAgentService(self):
        """Stop the agent, then remove it from the init system."""
        self.stopAgentService()
        return self.uninstallAgentService()

    def startDHCP(self):
        Run("service " + self.dhcp_client_name + " start", chk_err=False)

    def stopDHCP(self):
        Run("service " + self.dhcp_client_name + " stop", chk_err=False)

    def getRdmaPackageVersion(self):
        """
        Query zypper for the installed RDMA driver package and return its
        version string (e.g. "20150707.140.0_k3.12.28_4-3.1"), or None if
        zypper fails or no "Version:" field is found in the output.
        """
        error, output = RunGetOutput(self.zypper_path + " info " + RdmaConfig.rmda_package_name)
        if (error == RdmaConfig.process_success):
            r = re.search(r"Version: (\S+)", output)
            if r is not None:
                package_version = r.groups()[0]
                # e.g. package_version is "20150707.140.0_k3.12.28_4-3.1."
                return package_version
            else:
                return None
        else:
            return None

    def checkInstallHyperV(self):
        """
        Ensure the Hyper-V KVP daemon is running; install the hyper-v
        package (and reboot) if it is missing entirely.

        Returns one of the RdmaConfig status codes:
          - process_success: daemon running, or package freshly installed
            (machine will reboot in that case).
          - hv_kvp_daemon_not_started: package present but daemon down.
          - common_failed: process listing or package install failed.
        """
        error, output = RunGetOutput(self.ps_path + " -ef")
        if (error != RdmaConfig.process_success):
            return RdmaConfig.common_failed
        else:
            hv_kvp_daemon_service_process_name = "hv_kvp_daemon"
            hv_kvp_daemon_service_name = "hv_kvp_daemon"
            r = re.search(hv_kvp_daemon_service_process_name, output)
            if r is None:
                # Daemon process not found in the ps listing.
                Log("hv kvp daemon is not running.")
                error, output = RunGetOutput(self.rpm_path + " -q hyper-v", chk_err=False, log_cmd=False)
                if (error == RdmaConfig.process_success):
                    Log("the hyper-v package is installed, but hv_kvp_daemon not started")
                    return RdmaConfig.hv_kvp_daemon_not_started
                else:
                    # Package missing: install it, then reboot so the daemon starts.
                    error, output = RunGetOutput(self.zypper_path + " -n install --force hyper-v")
                    Log("install hyper-v return code: " + str(error) + " output:" + str(output))
                    if (error != RdmaConfig.process_success):
                        return RdmaConfig.common_failed
                    else:
                        self.rebootMachine()
                        return RdmaConfig.process_success
            else:
                Log("KVP daemon is running")
                return RdmaConfig.process_success

    def rdmaUpdate(self, updateRdmaRepository=None):
        """
        Bring the installed RDMA driver package in line with the driver
        version the host advertises (via KVP).

        Raises RdmaError when the host driver version cannot be read, the
        version check is inconclusive, or the Hyper-V utilities check fails.
        """
        # give some time for the hv_hvp_daemon to start up.
        time.sleep(10)
        check_install_result = self.checkInstallHyperV()
        if (check_install_result == RdmaConfig.process_success):
            # wait for sometime the RDMA Driver not passed in by KVP in time.
            time.sleep(10)
            nd_driver_version = self.getNdDriverVersion()
            if (nd_driver_version is None):
                raise RdmaError(RdmaConfig.driver_version_not_found)
            else:
                check_result = self.checkRDMA(nd_driver_version=nd_driver_version)
                Log("RDMA version check result is " + str(check_result))
                if (check_result == RdmaConfig.UpToDate):
                    return
                elif (check_result == RdmaConfig.OutOfDate):
                    # NOTE(review): the result below is assigned but never used;
                    # rdmaUpdatePackage raises on failure, so this is benign.
                    update_rdma_driver_result = self.rdmaUpdatePackage(host_version=nd_driver_version,
                                                                      updateRdmaRepository=updateRdmaRepository)
                elif (check_result == RdmaConfig.DriverVersionNotFound):
                    raise RdmaError(RdmaConfig.driver_version_not_found)
                elif (check_result == RdmaConfig.Unknown):
                    raise RdmaError(RdmaConfig.unknown_error)
        else:
            raise RdmaError(RdmaConfig.check_install_hv_utils_failed)

    def rdmaUpdatePackage(self, host_version, updateRdmaRepository=None):
        """
        Install the RDMA driver RPM matching host_version.

        Optionally registers the msft-rdma-pack zypper repository, installs
        the wrapper package (which drops driver RPMs under
        /opt/microsoft/rdma), then installs the RPM whose name embeds
        host_version and reboots. Raises RdmaError if the install fails or
        no matching RPM is found.
        """
        # check the repository first
        if (updateRdmaRepository is not None):
            error, output = RunGetOutput(self.zypper_path + " lr -u")
            rdma_pack_repository_name = "msft-rdma-pack"
            rdma_pack_result = re.search(rdma_pack_repository_name, output)
            if rdma_pack_result is None:
                Log("rdma_pack_result is None")
                error, output = RunGetOutput(
                    self.zypper_path + " ar " + str(updateRdmaRepository) + " " + rdma_pack_repository_name)
                # wait for the cache build.
                time.sleep(20)
                Log("error result is " + str(error) + " output is : " + str(output))
            else:
                Log("output is: " + str(output))
                Log("msft-rdma-pack found")
        returnCode, message = RunGetOutput(self.zypper_path + " --no-gpg-checks refresh")
        Log("refresh repo return code is " + str(returnCode) + " output is: " + str(message))
        # install the wrapper package, that will put the driver RPM packages under /opt/microsoft/rdma
        returnCode, message = RunGetOutput(self.zypper_path + " -n remove " + RdmaConfig.wrapper_package_name)
        Log("remove wrapper package return code is " + str(returnCode) + " output is: " + str(message))
        returnCode, message = RunGetOutput(
            self.zypper_path + " --non-interactive install --force " + RdmaConfig.wrapper_package_name)
        Log("install wrapper package return code is " + str(returnCode) + " output is: " + str(message))
        # NOTE(review): os.listdir never returns None (it raises on error),
        # so the else branch below is effectively unreachable — confirm intent.
        r = os.listdir("/opt/microsoft/rdma")
        if r is not None:
            for filename in r:
                # Match e.g. "<pkg>-20150707.<host_version>...".
                if re.match(RdmaConfig.rmda_package_name + r"-\d{8}\.(%s).+" % host_version, filename):
                    error, output = RunGetOutput(
                        self.zypper_path + " --non-interactive remove " + RdmaConfig.rmda_package_name)
                    Log("remove rdma package result is " + str(error) + " output is: " + str(output))
                    Log("Installing RPM /opt/microsoft/rdma/" + filename)
                    error, output = RunGetOutput(
                        self.zypper_path + " --non-interactive install --force /opt/microsoft/rdma/%s" % filename)
                    Log("Install rdma package result is " + str(error) + " output is: " + str(output))
                    if (error == RdmaConfig.process_success):
                        # New driver only takes effect after reboot.
                        self.rebootMachine()
                    else:
                        raise RdmaError(RdmaConfig.package_install_failed)
        else:
            Log("RDMA drivers not found in /opt/microsoft/rdma")
            raise RdmaError(RdmaConfig.package_not_found)

    def checkRDMA(self, nd_driver_version=None):
        """
        Compare the host-advertised ND driver version against the installed
        RDMA package version.

        Returns RdmaConfig.UpToDate, OutOfDate, DriverVersionNotFound, or
        Unknown. The package is up to date when its version string contains
        nd_driver_version right after the leading date component.
        """
        if (nd_driver_version is None):
            nd_driver_version = self.getNdDriverVersion()
        if (nd_driver_version is None or nd_driver_version == ""):
            return RdmaConfig.DriverVersionNotFound
        package_version = self.getRdmaPackageVersion()
        if (package_version is None or package_version == ""):
            # No package installed at all: treat as out of date so it gets installed.
            return RdmaConfig.OutOfDate
        else:
            # package_version would be like this :20150707_k3.12.28_4-3.1 20150707.140.0_k3.12.28_4-1.1
            # nd_driver_version 140.0
            Log("nd_driver_version is " + str(nd_driver_version) + " package_version is " + str(package_version))
            if (nd_driver_version is not None):
                r = re.match("^[0-9]+[.](%s).+" % nd_driver_version, package_version)
                # NdDriverVersion should be at the end of package version
                if not r:
                    # Package version does not embed the host ND version: update needed.
                    return RdmaConfig.OutOfDate
                else:
                    return RdmaConfig.UpToDate
        return RdmaConfig.Unknown

    def rebootMachine(self):
        """Reboot immediately via the distro's reboot binary."""
        Log("rebooting the machine")
        RunGetOutput(self.reboot_path)


############################################################
#	redhatDistro
############################################################

redhat_init_file = """\
#!/bin/bash
#
# Init file for AzureLinuxAgent.
#
# chkconfig: 2345 60 80
# description: AzureLinuxAgent
#

# source function library
. /etc/rc.d/init.d/functions

RETVAL=0
FriendlyName="AzureLinuxAgent"
WAZD_BIN=/usr/sbin/waagent

start()
{
    echo -n $"Starting $FriendlyName: "
    $WAZD_BIN -daemon &
}

stop()
{
    echo -n $"Stopping $FriendlyName: "
    killproc -p /var/run/waagent.pid $WAZD_BIN
    RETVAL=$?
    echo
    return $RETVAL
}

case "$1" in
    start)
        start
        ;;
    stop)
        stop
        ;;
    restart)
        stop
        start
        ;;
    reload)
        ;;
    report)
        ;;
    status)
        status $WAZD_BIN
        RETVAL=$?
        ;;
    *)
        echo $"Usage: $0 {start|stop|restart|status}"
        RETVAL=1
esac
exit $RETVAL
"""


class redhatDistro(AbstractDistro):
    """
    Redhat Distro concrete class
    Put Redhat specific behavior here...
    """

    def __init__(self):
        super(redhatDistro, self).__init__()
        self.service_cmd = '/sbin/service'
        self.ssh_service_restart_option = 'condrestart'
        self.ssh_service_name = 'sshd'
        # RHEL 7+ switched to /etc/hostname; earlier releases keep the
        # hostname in /etc/sysconfig/network (handled in publishHostname).
        self.hostname_file_path = None if DistInfo()[1] < '7.0' else '/etc/hostname'
        self.init_file = redhat_init_file
        self.grubKernelBootOptionsFile = '/boot/grub/menu.lst'
        self.grubKernelBootOptionsLine = 'kernel'

    def publishHostname(self, name):
        """
        Persist hostname `name`: base-class handling, plus (pre-7.0) the
        HOSTNAME= line in /etc/sysconfig/network, plus DHCP_HOSTNAME= in the
        active interface's ifcfg file so DHCP re-registers the name.
        Returns 0.
        """
        super(redhatDistro, self).publishHostname(name)
        if DistInfo()[1] < '7.0':
            filepath = "/etc/sysconfig/network"
            if os.path.isfile(filepath):
                # Rewrite the file with our HOSTNAME= first, dropping any old one.
                ReplaceFileContentsAtomic(filepath, "HOSTNAME=" + name + "\n"
                                          + "\n".join(filter(lambda a: not a.startswith("HOSTNAME"),
                                                             GetFileContents(filepath).split('\n'))))
        ethernetInterface = MyDistro.GetInterfaceName()
        filepath = "/etc/sysconfig/network-scripts/ifcfg-" + ethernetInterface
        if os.path.isfile(filepath):
            ReplaceFileContentsAtomic(filepath, "DHCP_HOSTNAME=" + name + "\n"
                                      + "\n".join(filter(lambda a: not a.startswith("DHCP_HOSTNAME"),
                                                         GetFileContents(filepath).split('\n'))))
        return 0

    def installAgentServiceScriptFiles(self):
        """Write the SysV init script and mark it executable. Returns 0."""
        SetFileContents(self.init_script_file, self.init_file)
        os.chmod(self.init_script_file, 0o744)
        return 0

    def registerAgentService(self):
        """Install the init script and register it with chkconfig."""
        self.installAgentServiceScriptFiles()
        return Run('chkconfig --add waagent')

    def uninstallAgentService(self):
        return Run('chkconfig --del ' + self.agent_service_name)

    def unregisterAgentService(self):
        self.stopAgentService()
        return self.uninstallAgentService()

    def checkPackageInstalled(self, p):
        """Return 1 if yum reports package p installed, else 0."""
        if Run("yum list installed " + p, chk_err=False):
            return 0
        else:
            return 1

    def checkPackageUpdateable(self, p):
        """Return 0 if yum check-update lists package p, else 1."""
        if Run("yum check-update | grep " + p, chk_err=False):
            return 1
        else:
            return 0

    def checkDependencies(self):
        """
        Generic dependency check.
        Return 1 unless all dependencies are satisfied.
        """
        # NetworkManager conflicts with the agent's own DHCP handling pre-7.0.
        if DistInfo()[1] < '7.0' and self.checkPackageInstalled('NetworkManager'):
            Error(GuestAgentLongName + " is not compatible with network-manager.")
            return 1
        try:
            m = __import__('pyasn1')
        except ImportError:
            Error(GuestAgentLongName + " requires python-pyasn1 for your Linux distribution.")
            return 1
        for a in self.requiredDeps:
            if Run("which " + a + " > /dev/null 2>&1", chk_err=False):
                Error("Missing required dependency: " + a)
                return 1
        return 0


############################################################
#	centosDistro
############################################################

class centosDistro(redhatDistro):
    """
    CentOS Distro concrete class
    Put CentOS specific behavior here...
    """

    def __init__(self):
        super(centosDistro, self).__init__()

    def rdmaUpdate(self, updateRdmaRepository=None):
        # RDMA driver updates are not supported on CentOS.
        pass

    def checkRDMA(self):
        pass


############################################################
#	oracleDistro
############################################################

class oracleDistro(redhatDistro):
    """
    Oracle Distro concrete class
    Put Oracle specific behavior here...
    """

    def __init__(self):
        super(oracleDistro, self).__init__()


############################################################
#	asianuxDistro
############################################################

class asianuxDistro(redhatDistro):
    """
    Asianux Distro concrete class
    Put Asianux specific behavior here...
    """

    def __init__(self):
        super(asianuxDistro, self).__init__()


############################################################
#	CoreOSDistro
############################################################

class CoreOSDistro(AbstractDistro):
    """
    CoreOS Distro concrete class
    Put CoreOS specific behavior here...
    """
    # UID of the built-in 'core' user; exempt from the system-user check
    # in CreateAccount below.
    CORE_UID = 500

    def __init__(self):
        super(CoreOSDistro, self).__init__()
        self.requiredDeps += ["/usr/bin/systemctl"]
        self.agent_service_name = 'waagent'
        self.init_script_file = '/etc/systemd/system/waagent.service'
        # /etc/machine-id must be regenerated per instance, so remove it
        # at deprovision time.
        self.fileBlackList.append("/etc/machine-id")
        self.dhcp_client_name = 'systemd-networkd'
        self.getpidcmd = 'pidof '
        self.shadow_file_mode = 0o640
        self.waagent_path = '/usr/share/oem/bin'
        self.python_path = '/usr/share/oem/python/bin'
        self.dhcp_enabled = True
        # CoreOS ships python under /usr/share/oem; make sure both the
        # interpreter and the agent module are reachable.
        if 'PATH' in os.environ:
            os.environ['PATH'] = "{0}:{1}".format(os.environ['PATH'], self.python_path)
        else:
            os.environ['PATH'] = self.python_path
        if 'PYTHONPATH' in os.environ:
            os.environ['PYTHONPATH'] = "{0}:{1}".format(os.environ['PYTHONPATH'], self.waagent_path)
        else:
            os.environ['PYTHONPATH'] = self.waagent_path

    def checkPackageInstalled(self, p):
        """
        There is no package manager in CoreOS. Return 1 since it must be preinstalled.
        """
        return 1

    def checkDependencies(self):
        """Return 1 if any binary in requiredDeps is missing, else 0."""
        for a in self.requiredDeps:
            if Run("which " + a + " > /dev/null 2>&1", chk_err=False):
                Error("Missing required dependency: " + a)
                return 1
        return 0

    def checkPackageUpdateable(self, p):
        """
        There is no package manager in CoreOS. Return 0 since it can't be updated via package.
        """
        return 0

    def startAgentService(self):
        return Run('systemctl start ' + self.agent_service_name)

    def stopAgentService(self):
        return Run('systemctl stop ' + self.agent_service_name)

    def restartSshService(self):
        """
        SSH is socket activated on CoreOS. No need to restart it.
        """
        return 0

    def sshDeployPublicKey(self, fprint, path):
        """
        We support PKCS8. Returns 1 on failure, 0 on success.
        """
        if Run("ssh-keygen -i -m PKCS8 -f " + fprint + " >> " + path):
            return 1
        else:
            return 0

    def RestartInterface(self, iface):
        # systemd-networkd owns all interfaces; `iface` is ignored.
        Run("systemctl restart systemd-networkd")

    def CreateAccount(self, user, password, expiration, thumbprint):
        """
        Create a user account, with 'user', 'password', 'expiration', ssh keys
        and sudo permissions.
        Returns None if successful, error string on failure.
        """
        userentry = None
        try:
            userentry = pwd.getpwnam(user)
        except:
            pass
        # Determine the system/regular user UID boundary; default to 100.
        uidmin = None
        try:
            uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1])
        except:
            pass
        if uidmin == None:
            uidmin = 100
        # Refuse to touch system accounts, except CoreOS's own 'core' UID.
        if userentry != None and userentry[2] < uidmin and userentry[2] != self.CORE_UID:
            Error("CreateAccount: " + user + " is a system user. Will not set password.")
            return "Failed to set password for system user: " + user + " (0x06)."
        if userentry == None:
            command = "useradd --create-home --password '*' " + user
            if expiration != None:
                # Drop any fractional-seconds suffix from the expiration date.
                command += " --expiredate " + expiration.split('.')[0]
            if Run(command):
                Error("Failed to create user account: " + user)
                return "Failed to create user account: " + user + " (0x07)."
        else:
            Log("CreateAccount: " + user + " already exists. Will update password.")
        if password != None:
            self.changePass(user, password)
        try:
            # Passwordless sudo only when no password was provisioned.
            if password == None:
                SetFileContents("/etc/sudoers.d/waagent", user + " ALL = (ALL) NOPASSWD: ALL\n")
            else:
                SetFileContents("/etc/sudoers.d/waagent", user + " ALL = (ALL) ALL\n")
            os.chmod("/etc/sudoers.d/waagent", 0o440)
        except:
            Error("CreateAccount: Failed to configure sudo access for user.")
            return "Failed to configure sudo privileges (0x08)."
        home = MyDistro.GetHome()
        if thumbprint != None:
            # Derive the ssh keypair from the certificate identified by thumbprint.
            dir = home + "/" + user + "/.ssh"
            CreateDir(dir, user, 0o700)
            pub = dir + "/id_rsa.pub"
            prv = dir + "/id_rsa"
            Run("ssh-keygen -y -f " + thumbprint + ".prv > " + pub)
            SetFileContents(prv, GetFileContents(thumbprint + ".prv"))
            for f in [pub, prv]:
                os.chmod(f, 0o600)
                ChangeOwner(f, user)
            SetFileContents(dir + "/authorized_keys", GetFileContents(pub))
            ChangeOwner(dir + "/authorized_keys", user)
        Log("Created user account: " + user)
        return None

    def startDHCP(self):
        Run("systemctl start " + self.dhcp_client_name, chk_err=False)

    def stopDHCP(self):
        Run("systemctl stop " + self.dhcp_client_name, chk_err=False)

    def translateCustomData(self, data):
        # CoreOS custom data arrives base64-encoded.
        return base64.b64decode(data)

    def getConfigurationPath(self):
        return "/usr/share/oem/waagent.conf"


############################################################
#	debianDistro
############################################################

debian_init_file = """\
#!/bin/sh
### BEGIN INIT INFO
# Provides:          AzureLinuxAgent
# Required-Start:    $network $syslog
# Required-Stop:     $network $syslog
# Should-Start:      $network $syslog
# Should-Stop:       $network $syslog
# Default-Start:     2 3 4 5
# Default-Stop:      0 1 6
# Short-Description: AzureLinuxAgent
# Description:       AzureLinuxAgent
### END INIT INFO

. /lib/lsb/init-functions

OPTIONS="-daemon"
WAZD_BIN=/usr/sbin/waagent
WAZD_PID=/var/run/waagent.pid

case "$1" in
    start)
        log_begin_msg "Starting AzureLinuxAgent..."
        pid=$( pidofproc $WAZD_BIN )
        if [ -n "$pid" ] ; then
              log_begin_msg "Already running."
              log_end_msg 0
              exit 0
        fi
        start-stop-daemon --start --quiet --oknodo --background --exec $WAZD_BIN -- $OPTIONS
        log_end_msg $?
        ;;

    stop)
        log_begin_msg "Stopping AzureLinuxAgent..."
        start-stop-daemon --stop --quiet --oknodo --pidfile $WAZD_PID
        ret=$?
        rm -f $WAZD_PID
        log_end_msg $ret
        ;;
    force-reload)
        $0 restart
        ;;
    restart)
        $0 stop
        $0 start
        ;;
    status)
        status_of_proc $WAZD_BIN && exit 0 || exit $?
        ;;
    *)
        log_success_msg "Usage: /etc/init.d/waagent {start|stop|force-reload|restart|status}"
        exit 1
        ;;
esac
exit 0
"""


class debianDistro(AbstractDistro):
    """
    debian Distro concrete class
    Put debian specific behavior here...
    """

    def __init__(self):
        super(debianDistro, self).__init__()
        self.requiredDeps += ["/usr/sbin/update-rc.d"]
        self.init_file = debian_init_file
        self.agent_package_name = 'walinuxagent'
        self.dhcp_client_name = 'dhclient'
        self.getpidcmd = 'pidof '
        self.shadow_file_mode = 0o640

    def checkPackageInstalled(self, p):
        """
        Check that the package is installed.
        Return 1 if installed, 0 if not installed.
        This method of using dpkg-query
        allows wildcards to be present in the package name.
        """
        if not Run("dpkg-query -W -f='${Status}\n' '" + p + "' | grep ' installed' 2>&1", chk_err=False):
            return 1
        else:
            return 0

    def checkDependencies(self):
        """
        Debian dependency check.
        python-pyasn1 is NOT needed.
        Return 1 unless all dependencies are satisfied.
        NOTE: using network*manager will catch either package name in
        Ubuntu or debian.
        """
        if self.checkPackageInstalled('network*manager'):
            Error(GuestAgentLongName + " is not compatible with network-manager.")
            return 1
        for a in self.requiredDeps:
            if Run("which " + a + " > /dev/null 2>&1", chk_err=False):
                Error("Missing required dependency: " + a)
                return 1
        return 0

    def checkPackageUpdateable(self, p):
        """Return 1 if apt reports an upgrade available for p, else 0."""
        if Run("apt-get update ; apt-get upgrade -us | grep " + p, chk_err=False):
            return 1
        else:
            return 0

    def installAgentServiceScriptFiles(self):
        """
        If we are packaged - the service name is walinuxagent, do nothing.
        Otherwise write the SysV init script; returns 0 on success, 1 on error.
        """
        if self.agent_service_name == 'walinuxagent':
            return 0
        try:
            SetFileContents(self.init_script_file, self.init_file)
            os.chmod(self.init_script_file, 0o744)
        except OSError as e:
            ErrorWithPrefix('installAgentServiceScriptFiles',
                            'Exception: ' + str(e) + ' occured creating ' + self.init_script_file)
            return 1
        return 0

    def registerAgentService(self):
        """Install the init script and register it via update-rc.d."""
        if self.installAgentServiceScriptFiles() == 0:
            return Run('update-rc.d waagent defaults')
        else:
            return 1

    def uninstallAgentService(self):
        return Run('update-rc.d -f ' + self.agent_service_name + ' remove')

    def unregisterAgentService(self):
        self.stopAgentService()
        return self.uninstallAgentService()

    def sshDeployPublicKey(self, fprint, path):
        """
        We support PKCS8. Returns 1 on failure, 0 on success.
        """
        if Run("ssh-keygen -i -m PKCS8 -f " + fprint + " >> " + path):
            return 1
        else:
            return 0


############################################################
#	KaliDistro - WIP
#	Functioning on Kali 1.1.0a so far
############################################################

class KaliDistro(debianDistro):
    """
    Kali Distro concrete class
    Put Kali specific behavior here...
    """

    def __init__(self):
        super(KaliDistro, self).__init__()


############################################################
#	UbuntuDistro
############################################################

ubuntu_upstart_file = """\
#walinuxagent - start Azure agent

description "walinuxagent"
author "Ben Howard "

start on (filesystem and started rsyslog)

pre-start script

    WALINUXAGENT_ENABLED=1
    [ -r /etc/default/walinuxagent ] && . /etc/default/walinuxagent

    if [ "$WALINUXAGENT_ENABLED" != "1" ]; then
        exit 1
    fi

    if [ ! -x /usr/sbin/waagent ]; then
        exit 1
    fi

    #Load the udf module
    modprobe -b udf
end script

exec /usr/sbin/waagent -daemon
"""


class UbuntuDistro(debianDistro):
    """
    Ubuntu Distro concrete class
    Put Ubuntu specific behavior here...
    """

    def __init__(self):
        super(UbuntuDistro, self).__init__()
        # Ubuntu uses upstart rather than a SysV init script.
        self.init_script_file = '/etc/init/waagent.conf'
        self.init_file = ubuntu_upstart_file
        # NOTE(review): this replaces (not extends) the base-class blacklist,
        # dropping anything debianDistro/AbstractDistro added — confirm intent.
        self.fileBlackList = ["/root/.bash_history", "/var/log/waagent.log"]
        # Resolved lazily in getDhcpClientName (release-dependent).
        self.dhcp_client_name = None
        self.getpidcmd = 'pidof '

    def registerAgentService(self):
        # Upstart picks the job up automatically; just install the conf file.
        return self.installAgentServiceScriptFiles()

    def uninstallAgentService(self):
        """
        If we are packaged - the service name is walinuxagent, do nothing.
        """
        if self.agent_service_name == 'walinuxagent':
            return 0
        # NOTE(review): falls through returning None here (siblings return an
        # exit code) — callers appear not to depend on it; confirm.
        os.remove('/etc/init/' + self.agent_service_name + '.conf')

    def unregisterAgentService(self):
        """
        If we are packaged - the service name is walinuxagent, do nothing.
        """
        if self.agent_service_name == 'walinuxagent':
            return
        self.stopAgentService()
        return self.uninstallAgentService()

    def deprovisionWarnUser(self):
        """
        Ubuntu specific warning string from Deprovision.
        """
        print("WARNING! Nameserver configuration in /etc/resolvconf/resolv.conf.d/{tail,original} will be deleted.")

    def deprovisionDeleteFiles(self):
        """
        Ubuntu uses resolv.conf by default, so removing /etc/resolv.conf will
        break resolvconf. Therefore, we check to see if resolvconf is in use,
        and if so, we remove the resolvconf artifacts.
        """
        if os.path.realpath('/etc/resolv.conf') != '/run/resolvconf/resolv.conf':
            Log("resolvconf is not configured. Removing /etc/resolv.conf")
            self.fileBlackList.append('/etc/resolv.conf')
        else:
            Log("resolvconf is enabled; leaving /etc/resolv.conf intact")
            resolvConfD = '/etc/resolvconf/resolv.conf.d/'
            self.fileBlackList.extend([resolvConfD + 'tail', resolvConfD + 'original'])
        # Best-effort removal of agent state plus blacklisted files.
        for f in os.listdir(LibDir) + self.fileBlackList:
            try:
                os.remove(f)
            except:
                pass
        return 0

    def getDhcpClientName(self):
        """Return the DHCP client binary name, caching the release-specific choice."""
        if self.dhcp_client_name != None:
            return self.dhcp_client_name
        if DistInfo()[1] == '12.04':
            self.dhcp_client_name = 'dhclient3'
        else:
            self.dhcp_client_name = 'dhclient'
        return self.dhcp_client_name

    def waitForSshHostKey(self, path):
        """
        Wait until the ssh host key is generated by cloud init.
        Polls once a second for up to 10 seconds; returns True if the key
        file appears, False otherwise.
        """
        for retry in range(0, 10):
            if (os.path.isfile(path)):
                return True
            time.sleep(1)
        Error("Can't find host key: {0}".format(path))
        return False


############################################################
#	LinuxMintDistro
############################################################

class LinuxMintDistro(UbuntuDistro):
    """
    LinuxMint Distro concrete class
    Put LinuxMint specific behavior here...
    """

    def __init__(self):
        super(LinuxMintDistro, self).__init__()


############################################################
#	fedoraDistro
############################################################

fedora_systemd_service = """\
[Unit]
Description=Azure Linux Agent
After=network.target
After=sshd.service
ConditionFileIsExecutable=/usr/sbin/waagent
ConditionPathExists=/etc/waagent.conf

[Service]
Type=simple
ExecStart=/usr/sbin/waagent -daemon

[Install]
WantedBy=multi-user.target
"""


class fedoraDistro(redhatDistro):
    """
    FedoraDistro concrete class
    Put Fedora specific behavior here...
    """

    def __init__(self):
        super(fedoraDistro, self).__init__()
        # Fedora is systemd-based: override the redhat SysV settings.
        self.service_cmd = '/usr/bin/systemctl'
        self.hostname_file_path = '/etc/hostname'
        self.init_script_file = '/usr/lib/systemd/system/' + self.agent_service_name + '.service'
        self.init_file = fedora_systemd_service
        self.grubKernelBootOptionsFile = '/etc/default/grub'
        self.grubKernelBootOptionsLine = 'GRUB_CMDLINE_LINUX='

    def publishHostname(self, name):
        """
        Persist hostname `name` to /etc/hostname and set DHCP_HOSTNAME= in
        the active interface's ifcfg file. Returns 0.
        """
        SetFileContents(self.hostname_file_path, name + '\n')
        ethernetInterface = MyDistro.GetInterfaceName()
        filepath = "/etc/sysconfig/network-scripts/ifcfg-" + ethernetInterface
        if os.path.isfile(filepath):
            ReplaceFileContentsAtomic(filepath, "DHCP_HOSTNAME=" + name + "\n"
                                      + "\n".join(filter(lambda a: not a.startswith("DHCP_HOSTNAME"),
                                                         GetFileContents(filepath).split('\n'))))
        return 0

    def installAgentServiceScriptFiles(self):
        """Write the systemd unit file and reload systemd's unit cache."""
        SetFileContents(self.init_script_file, self.init_file)
        os.chmod(self.init_script_file, 0o644)
        return Run(self.service_cmd + ' daemon-reload')

    def registerAgentService(self):
        """Install the unit file and enable the agent service."""
        self.installAgentServiceScriptFiles()
        return Run(self.service_cmd + ' enable ' + self.agent_service_name)

    def uninstallAgentService(self):
        """
        Call service subsystem to remove waagent script.
        """
        return Run(self.service_cmd + ' disable ' + self.agent_service_name)

    def unregisterAgentService(self):
        """
        Calls self.stopAgentService and call self.uninstallAgentService()
        """
        self.stopAgentService()
        self.uninstallAgentService()

    def startAgentService(self):
        """
        Service call to start the Agent service
        """
        return Run(self.service_cmd + ' start ' + self.agent_service_name)

    def stopAgentService(self):
        """
        Service call to stop the Agent service
        """
        return Run(self.service_cmd + ' stop ' + self.agent_service_name, False)

    def restartSshService(self):
        """
        Service call to re(start) the SSH service
        """
        sshRestartCmd = self.service_cmd + " " + self.ssh_service_restart_option + " " + self.ssh_service_name
        retcode = Run(sshRestartCmd)
        if retcode > 0:
            Error("Failed to restart SSH service with return code:" + str(retcode))
        return retcode

    def checkPackageInstalled(self, p):
        """
        Query package database for prescence of an installed package.
        Uses the rpm Python bindings rather than shelling out.
        """
        import rpm
        ts = rpm.TransactionSet()
        rpms = ts.dbMatch(rpm.RPMTAG_PROVIDES, p)
        return bool(len(rpms) > 0)

    def deleteRootPassword(self):
        # '!!' locks the root password field.
        return Run("/sbin/usermod root -p '!!'")

    def packagedInstall(self, buildroot):
        """
        Called from setup.py for use by RPM.
        Copies generated files waagent.conf, under the buildroot.
        """
        if not os.path.exists(buildroot + '/etc'):
            os.mkdir(buildroot + '/etc')
        SetFileContents(buildroot + '/etc/waagent.conf', MyDistro.waagent_conf_file)
        if not os.path.exists(buildroot + '/etc/logrotate.d'):
            os.mkdir(buildroot + '/etc/logrotate.d')
        SetFileContents(buildroot + '/etc/logrotate.d/WALinuxAgent', WaagentLogrotate)
        self.init_script_file = buildroot + self.init_script_file
        # this allows us to call installAgentServiceScriptFiles()
        if not os.path.exists(os.path.dirname(self.init_script_file)):
            os.mkdir(os.path.dirname(self.init_script_file))
        self.installAgentServiceScriptFiles()

    def CreateAccount(self, user, password, expiration, thumbprint):
        """Create the account via the base class, then add it to wheel."""
        super(fedoraDistro, self).CreateAccount(user, password, expiration, thumbprint)
        Run('/sbin/usermod ' + user + ' -G wheel')

    def DeleteAccount(self, user):
        """Strip supplementary groups, then delete via the base class."""
        Run('/sbin/usermod ' + user + ' -G ""')
        super(fedoraDistro, self).DeleteAccount(user)


############################################################
# FreeBSD
############################################################

FreeBSDWaagentConf = """\
#
# Azure Linux Agent Configuration
#

Role.StateConsumer=None
# Specified program is invoked with the argument "Ready" when we report ready status
# to the endpoint server.
Role.ConfigurationConsumer=None
# Specified program is invoked with XML file argument specifying role configuration.
Role.TopologyConsumer=None
# Specified program is invoked with XML file argument specifying role topology.

Provisioning.Enabled=y
#
Provisioning.DeleteRootPassword=y
# Password authentication for root account will be unavailable.
Provisioning.RegenerateSshHostKeyPair=y
# Generate fresh host key pair.
Provisioning.SshHostKeyPairType=rsa
# Supported values are "rsa", "dsa" and "ecdsa".
Provisioning.MonitorHostName=y
# Monitor host name changes and publish changes via DHCP requests.

ResourceDisk.Format=y
# Format if unformatted. If 'n', resource disk will not be mounted.
ResourceDisk.Filesystem=ufs2 # ResourceDisk.MountPoint=/mnt/resource # ResourceDisk.EnableSwap=n # Create and use swapfile on resource disk. ResourceDisk.SwapSizeMB=0 # Size of the swapfile. LBProbeResponder=y # Respond to load balancer probes if requested by Azure. Logs.Verbose=n # Enable verbose logs OS.RootDeviceScsiTimeout=300 # Root device timeout in seconds. OS.OpensslPath=None # If "None", the system default version is used. """ bsd_init_file = """\ #! /bin/sh # PROVIDE: waagent # REQUIRE: DAEMON cleanvar sshd # BEFORE: LOGIN # KEYWORD: nojail . /etc/rc.subr export PATH=$PATH:/usr/local/bin name="waagent" rcvar="waagent_enable" command="/usr/sbin/${name}" command_interpreter="/usr/local/bin/python" waagent_flags=" daemon &" pidfile="/var/run/waagent.pid" load_rc_config $name run_rc_command "$1" """ bsd_activate_resource_disk_txt = """\ #!/usr/bin/env python import os import sys import imp # waagent has no '.py' therefore create waagent module import manually. __name__='setupmain' #prevent waagent.__main__ from executing waagent=imp.load_source('waagent','/tmp/waagent') waagent.LoggerInit('/var/log/waagent.log','/dev/console') from waagent import RunGetOutput,Run Config=waagent.ConfigurationProvider(None) format = Config.get("ResourceDisk.Format") if format == None or format.lower().startswith("n"): sys.exit(0) device_base = 'da1' device = "/dev/" + device_base for entry in RunGetOutput("mount")[1].split(): if entry.startswith(device + "s1"): waagent.Log("ActivateResourceDisk: " + device + "s1 is already mounted.") sys.exit(0) mountpoint = Config.get("ResourceDisk.MountPoint") if mountpoint == None: mountpoint = "/mnt/resource" waagent.CreateDir(mountpoint, "root", 0o755) fs = Config.get("ResourceDisk.Filesystem") if waagent.FreeBSDDistro().mediaHasFilesystem(device) == False : Run("newfs " + device + "s1") if Run("mount " + device + "s1 " + mountpoint): waagent.Error("ActivateResourceDisk: Failed to mount resource disk (" + device + "s1).") sys.exit(0) 
waagent.Log("Resource disk (" + device + "s1) is mounted at " + mountpoint + " with fstype " + fs) waagent.SetFileContents(os.path.join(mountpoint,waagent.README_FILENAME), waagent.README_FILECONTENT) swap = Config.get("ResourceDisk.EnableSwap") if swap == None or swap.lower().startswith("n"): sys.exit(0) sizeKB = int(Config.get("ResourceDisk.SwapSizeMB")) * 1024 if os.path.isfile(mountpoint + "/swapfile") and os.path.getsize(mountpoint + "/swapfile") != (sizeKB * 1024): os.remove(mountpoint + "/swapfile") if not os.path.isfile(mountpoint + "/swapfile"): Run("dd if=/dev/zero of=" + mountpoint + "/swapfile bs=1024 count=" + str(sizeKB)) if Run("mdconfig -a -t vnode -f " + mountpoint + "/swapfile -u 0"): waagent.Error("ActivateResourceDisk: Configuring swap - Failed to create md0") if not Run("swapon /dev/md0"): waagent.Log("Enabled " + str(sizeKB) + " KB of swap at " + mountpoint + "/swapfile") else: waagent.Error("ActivateResourceDisk: Failed to activate swap at " + mountpoint + "/swapfile") """ class FreeBSDDistro(AbstractDistro): """ """ def __init__(self): """ Generic Attributes go here. These are based on 'majority rules'. This __init__() may be called or overriden by the child. 
""" super(FreeBSDDistro, self).__init__() self.agent_service_name = os.path.basename(sys.argv[0]) self.selinux = False self.ssh_service_name = 'sshd' self.ssh_config_file = '/etc/ssh/sshd_config' self.hostname_file_path = '/etc/hostname' self.dhcp_client_name = 'dhclient' self.requiredDeps = ['route', 'shutdown', 'ssh-keygen', 'pw' , 'openssl', 'fdisk', 'sed', 'grep', 'sudo'] self.init_script_file = '/etc/rc.d/waagent' self.init_file = bsd_init_file self.agent_package_name = 'WALinuxAgent' self.fileBlackList = ["/root/.bash_history", "/var/log/waagent.log", '/etc/resolv.conf'] self.agent_files_to_uninstall = ["/etc/waagent.conf"] self.grubKernelBootOptionsFile = '/boot/loader.conf' self.grubKernelBootOptionsLine = '' self.getpidcmd = 'pgrep -n' self.mount_dvd_cmd = 'dd bs=2048 count=33 skip=295 if=' # custom data max len is 64k self.sudoers_dir_base = '/usr/local/etc' self.waagent_conf_file = FreeBSDWaagentConf def installAgentServiceScriptFiles(self): SetFileContents(self.init_script_file, self.init_file) os.chmod(self.init_script_file, 0o777) AppendFileContents("/etc/rc.conf", "waagent_enable='YES'\n") return 0 def registerAgentService(self): self.installAgentServiceScriptFiles() return Run("services_mkdb " + self.init_script_file) def sshDeployPublicKey(self, fprint, path): """ We support PKCS8. """ if Run("ssh-keygen -i -m PKCS8 -f " + fprint + " >> " + path): return 1 else: return 0 def deleteRootPassword(self): """ BSD root password removal. 
""" filepath = "/etc/master.passwd" ReplaceStringInFile(filepath, r'root:.*?:', 'root::') # ReplaceFileContentsAtomic(filepath,"root:*LOCK*:14600::::::\n" # + "\n".join(filter(lambda a: not a.startswith("root:"),GetFileContents(filepath).split('\n')))) os.chmod(filepath, self.shadow_file_mode) if self.isSelinuxSystem(): self.setSelinuxContext(filepath, 'system_u:object_r:shadow_t:s0') RunGetOutput("pwd_mkdb -u root /etc/master.passwd") Log("Root password deleted.") return 0 def changePass(self, user, password): return RunSendStdin("pw usermod " + user + " -h 0 ", password, log_cmd=False, use_shell=False) def load_ata_piix(self): return 0 def unload_ata_piix(self): return 0 def checkDependencies(self): """ FreeBSD dependency check. Return 1 unless all dependencies are satisfied. """ for a in self.requiredDeps: if Run("which " + a + " > /dev/null 2>&1", chk_err=False): Error("Missing required dependency: " + a) return 1 return 0 def packagedInstall(self, buildroot): pass def GetInterfaceName(self): """ Return the ip of the active ethernet interface. """ iface, inet, mac = self.GetFreeBSDEthernetInfo() return iface def RestartInterface(self, iface): Run("service netif restart") def GetIpv4Address(self): """ Return the ip of the active ethernet interface. """ iface, inet, mac = self.GetFreeBSDEthernetInfo() return inet def GetMacAddress(self): """ Return the ip of the active ethernet interface. """ iface, inet, mac = self.GetFreeBSDEthernetInfo() l = mac.split(':') r = [] for i in l: r.append(string.atoi(i, 16)) return r def GetFreeBSDEthernetInfo(self): """ There is no SIOCGIFCONF on freeBSD - just parse ifconfig. Returns strings: iface, inet4_addr, and mac or 'None,None,None' if unable to parse. We will sleep and retry as the network must be up. 
""" code, output = RunGetOutput("ifconfig", chk_err=False) Log(output) retries = 10 cmd = 'ifconfig | grep -A2 -B2 ether | grep -B3 inet | grep -A4 UP ' code = 1 while code > 0: if code > 0 and retries == 0: Error("GetFreeBSDEthernetInfo - Failed to detect ethernet interface") return None, None, None code, output = RunGetOutput(cmd, chk_err=False) retries -= 1 if code > 0 and retries > 0: Log("GetFreeBSDEthernetInfo - Error: retry ethernet detection " + str(retries)) if retries == 9: c, o = RunGetOutput("ifconfig | grep -A1 -B2 ether", chk_err=False) if c == 0: t = o.replace('\n', ' ') t = t.split() i = t[0][:-1] Log(RunGetOutput('id')[1]) Run('dhclient ' + i) time.sleep(10) j = output.replace('\n', ' ') j = j.split() iface = j[0][:-1] for i in range(len(j)): if j[i] == 'inet': inet = j[i + 1] elif j[i] == 'ether': mac = j[i + 1] return iface, inet, mac def CreateAccount(self, user, password, expiration, thumbprint): """ Create a user account, with 'user', 'password', 'expiration', ssh keys and sudo permissions. Returns None if successful, error string on failure. """ userentry = None try: userentry = pwd.getpwnam(user) except: pass uidmin = None try: if os.path.isfile("/etc/login.defs"): uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry != None and userentry[2] < uidmin: Error("CreateAccount: " + user + " is a system user. Will not set password.") return "Failed to set password for system user: " + user + " (0x06)." if userentry == None: command = "pw useradd " + user + " -m" if expiration != None: command += " -e " + expiration.split('.')[0] if Run(command): Error("Failed to create user account: " + user) return "Failed to create user account: " + user + " (0x07)." else: Log("CreateAccount: " + user + " already exists. 
Will update password.") if password != None: self.changePass(user, password) try: # for older distros create sudoers.d if not os.path.isdir(MyDistro.sudoers_dir_base + '/sudoers.d/'): # create the /etc/sudoers.d/ directory os.mkdir(MyDistro.sudoers_dir_base + '/sudoers.d') # add the include of sudoers.d to the /etc/sudoers SetFileContents(MyDistro.sudoers_dir_base + '/sudoers', GetFileContents( MyDistro.sudoers_dir_base + '/sudoers') + '\n#includedir ' + MyDistro.sudoers_dir_base + '/sudoers.d\n') if password == None: SetFileContents(MyDistro.sudoers_dir_base + "/sudoers.d/waagent", user + " ALL = (ALL) NOPASSWD: ALL\n") else: SetFileContents(MyDistro.sudoers_dir_base + "/sudoers.d/waagent", user + " ALL = (ALL) ALL\n") os.chmod(MyDistro.sudoers_dir_base + "/sudoers.d/waagent", 0o440) except: Error("CreateAccount: Failed to configure sudo access for user.") return "Failed to configure sudo privileges (0x08)." home = MyDistro.GetHome() if thumbprint != None: dir = home + "/" + user + "/.ssh" CreateDir(dir, user, 0o700) pub = dir + "/id_rsa.pub" prv = dir + "/id_rsa" Run("ssh-keygen -y -f " + thumbprint + ".prv > " + pub) SetFileContents(prv, GetFileContents(thumbprint + ".prv")) for f in [pub, prv]: os.chmod(f, 0o600) ChangeOwner(f, user) SetFileContents(dir + "/authorized_keys", GetFileContents(pub)) ChangeOwner(dir + "/authorized_keys", user) Log("Created user account: " + user) return None def DeleteAccount(self, user): """ Delete the 'user'. Clear utmp first, to avoid error. Removes the /etc/sudoers.d/waagent file. """ userentry = None try: userentry = pwd.getpwnam(user) except: pass if userentry == None: Error("DeleteAccount: " + user + " not found.") return uidmin = None try: if os.path.isfile("/etc/login.defs"): uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry[2] < uidmin: Error("DeleteAccount: " + user + " is a system user. 
Will not delete account.") return Run("> /var/run/utmp") # Delete utmp to prevent error if we are the 'user' deleted pid = subprocess.Popen(['rmuser', '-y', user], stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE).pid try: os.remove(MyDistro.sudoers_dir_base + "/sudoers.d/waagent") except: pass return def ActivateResourceDiskNoThread(self): """ Format, mount, and if specified in the configuration set resource disk as swap. """ global DiskActivated Run('cp /usr/sbin/waagent /tmp/') SetFileContents('/tmp/bsd_activate_resource_disk.py', bsd_activate_resource_disk_txt) Run('chmod +x /tmp/bsd_activate_resource_disk.py') pid = subprocess.Popen(["/tmp/bsd_activate_resource_disk.py", ""]).pid Log("Spawning bsd_activate_resource_disk.py") DiskActivated = True return def Install(self): """ Install the agent service. Check dependencies. Create /etc/waagent.conf and move old version to /etc/waagent.conf.old Copy RulesFiles to /var/lib/waagent Create /etc/logrotate.d/waagent Set /etc/ssh/sshd_config ClientAliveInterval to 180 Call ApplyVNUMAWorkaround() """ if MyDistro.checkDependencies(): return 1 os.chmod(sys.argv[0], 0o755) SwitchCwd() for a in RulesFiles: if os.path.isfile(a): if os.path.isfile(GetLastPathElement(a)): os.remove(GetLastPathElement(a)) shutil.move(a, ".") Warn("Moved " + a + " -> " + LibDir + "/" + GetLastPathElement(a)) MyDistro.registerAgentService() if os.path.isfile("/etc/waagent.conf"): try: os.remove("/etc/waagent.conf.old") except: pass try: os.rename("/etc/waagent.conf", "/etc/waagent.conf.old") Warn("Existing /etc/waagent.conf has been renamed to /etc/waagent.conf.old") except: pass SetFileContents("/etc/waagent.conf", self.waagent_conf_file) if os.path.exists('/usr/local/etc/logrotate.d/'): SetFileContents("/usr/local/etc/logrotate.d/waagent", WaagentLogrotate) filepath = "/etc/ssh/sshd_config" ReplaceFileContentsAtomic(filepath, "\n".join(filter(lambda a: not a.startswith("ClientAliveInterval"), GetFileContents(filepath).split( 
'\n'))) + "\nClientAliveInterval 180\n") Log("Configured SSH client probing to keep connections alive.") # ApplyVNUMAWorkaround() return 0 def mediaHasFilesystem(self, dsk): if Run('LC_ALL=C fdisk -p ' + dsk + ' | grep "invalid fdisk partition table found" ', False): return False return True def mountDVD(self, dvd, location): # At this point we cannot read a joliet option udf DVD in freebsd10 - so we 'dd' it into our location retcode, out = RunGetOutput(self.mount_dvd_cmd + dvd + ' of=' + location + '/ovf-env.xml') if retcode != 0: return retcode, out ovfxml = (GetFileContents(location + "/ovf-env.xml", asbin=False)) if ord(ovfxml[0]) > 128 and ord(ovfxml[1]) > 128 and ord(ovfxml[2]) > 128: ovfxml = ovfxml[ 3:] # BOM is not stripped. First three bytes are > 128 and not unicode chars so we ignore them. ovfxml = ovfxml.strip(chr(0x00)) ovfxml = "".join(filter(lambda x: ord(x) < 128, ovfxml)) ovfxml = re.sub(r'.*\Z', '', ovfxml, 0, re.DOTALL) ovfxml += '' SetFileContents(location + "/ovf-env.xml", ovfxml) return retcode, out def GetHome(self): return '/home' def initScsiDiskTimeout(self): """ Set the SCSI disk timeout by updating the kernal config """ timeout = Config.get("OS.RootDeviceScsiTimeout") if timeout: Run("sysctl kern.cam.da.default_timeout=" + timeout) def setScsiDiskTimeout(self): return def setBlockDeviceTimeout(self, device, timeout): return def getProcessorCores(self): return int(RunGetOutput("sysctl hw.ncpu | awk '{print $2}'")[1]) def getTotalMemory(self): return int(RunGetOutput("sysctl hw.realmem | awk '{print $2}'")[1]) / 1024 def setDefaultGateway(self, gateway): Run("/sbin/route add default " + gateway, chk_err=False) def routeAdd(self, net, mask, gateway): Run("/sbin/route add -net " + net + " " + mask + " " + gateway, chk_err=False) ############################################################ # END DISTRO CLASS DEFS ############################################################ # This lets us index into a string or an array of integers 
def Ord(a):
    """
    Allows indexing into a string or an array of integers transparently.
    Generic utility function.
    """
    if type(a) == type("a"):
        a = ord(a)
    return a


def IsLinux():
    """
    Returns True if platform is Linux.
    Generic utility function.
    """
    return (platform.uname()[0] == "Linux")


def GetLastPathElement(path):
    """
    Similar to basename.
    Generic utility function.
    """
    return path.rsplit('/', 1)[1]


def GetFileContents(filepath, asbin=False):
    """
    Read and return contents of 'filepath'.
    Returns None (after logging) on IOError.
    """
    mode = 'r'
    if asbin:
        mode += 'b'
    c = None
    try:
        with open(filepath, mode) as F:
            c = F.read()
    except IOError as e:
        ErrorWithPrefix('GetFileContents', 'Reading from file ' + filepath + ' Exception is ' + str(e))
        return None
    return c


def SetFileContents(filepath, contents):
    """
    Write 'contents' to 'filepath'.
    Returns 0 on success, None on IOError.
    """
    if type(contents) == str:
        contents = contents.encode('latin-1', 'ignore')
    try:
        with open(filepath, "wb+") as F:
            F.write(contents)
    except IOError as e:
        ErrorWithPrefix('SetFileContents', 'Writing to file ' + filepath + ' Exception is ' + str(e))
        return None
    return 0


def AppendFileContents(filepath, contents):
    """
    Append 'contents' to 'filepath'.
    Returns 0 on success, None on IOError.
    """
    if type(contents) == str:
        # encoding works different for between interpreter version, we are
        # keeping separate implementation to ensure backward compatibility
        if sys.version_info[0] == 3:
            contents = contents.encode('latin-1').decode('latin-1')
        elif sys.version_info[0] == 2:
            contents = contents.encode('latin-1')
    try:
        with open(filepath, "a+") as F:
            F.write(contents)
    except IOError as e:
        ErrorWithPrefix('AppendFileContents', 'Appending to file ' + filepath + ' Exception is ' + str(e))
        return None
    return 0


def ReplaceFileContentsAtomic(filepath, contents):
    """
    Write 'contents' to 'filepath' by creating a temp file, and replacing original.
    Returns None on success, 1 on failure (note: inverted vs SetFileContents).
    """
    handle, temp = tempfile.mkstemp(dir=os.path.dirname(filepath))
    if type(contents) == str:
        contents = contents.encode('latin-1')
    try:
        os.write(handle, contents)
    except IOError as e:
        ErrorWithPrefix('ReplaceFileContentsAtomic', 'Writing to file ' + filepath + ' Exception is ' + str(e))
        return None
    finally:
        os.close(handle)
    try:
        os.rename(temp, filepath)
        return None
    except IOError as e:
        ErrorWithPrefix('ReplaceFileContentsAtomic', 'Renaming ' + temp + ' to ' + filepath + ' Exception is ' + str(e))
    # Rename failed; try removing the destination then renaming again.
    try:
        os.remove(filepath)
    except IOError as e:
        ErrorWithPrefix('ReplaceFileContentsAtomic', 'Removing ' + filepath + ' Exception is ' + str(e))
    try:
        os.rename(temp, filepath)
    except IOError as e:
        ErrorWithPrefix('ReplaceFileContentsAtomic', 'Removing ' + filepath + ' Exception is ' + str(e))
        return 1
    return 0


def GetLineStartingWith(prefix, filepath):
    """
    Return the first line of 'filepath' that startswith 'prefix', else None.
    """
    for line in GetFileContents(filepath).split('\n'):
        if line.startswith(prefix):
            return line
    return None


def Run(cmd, chk_err=True):
    """
    Calls RunGetOutput on 'cmd', returning only the return code.
    If chk_err=True then errors will be reported in the log.
    If chk_err=False then errors will be suppressed from the log.
    """
    retcode, out = RunGetOutput(cmd, chk_err)
    return retcode


def RunGetOutput(cmd, chk_err=True, log_cmd=True):
    """
    Wrapper for subprocess.check_output.
    Execute 'cmd'.  Returns return code and STDOUT, trapping expected exceptions.
    Reports exceptions to Error if chk_err parameter is True
    """
    if log_cmd:
        LogIfVerbose(cmd)
    try:
        output = subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
    except subprocess.CalledProcessError as e:
        if chk_err and log_cmd:
            Error('CalledProcessError. Error Code is ' + str(e.returncode))
            Error('CalledProcessError. Command string was ' + e.cmd)
            Error('CalledProcessError. Command result was ' + (e.output[:-1]).decode('latin-1'))
        return e.returncode, e.output.decode('latin-1')
    return 0, output.decode('latin-1')


def RunSendStdin(cmd, input, chk_err=True, log_cmd=True, use_shell=True):
    """
    Wrapper for subprocess.Popen.
    Execute 'cmd', sending 'input' to STDIN of 'cmd'.
    Returns return code and STDOUT, trapping expected exceptions.
    Reports exceptions to Error if chk_err parameter is True
    """
    if log_cmd:
        LogIfVerbose(str(cmd) + str(input))
    try:
        me = subprocess.Popen([cmd], shell=use_shell, stdin=subprocess.PIPE,
                              stderr=subprocess.STDOUT, stdout=subprocess.PIPE)
        output = me.communicate(input)
    except OSError as e:
        if chk_err and log_cmd:
            Error('CalledProcessError. Error Code is ' + str(me.returncode))
            Error('CalledProcessError. Command string was ' + cmd)
            Error('CalledProcessError. Command result was ' + output[0].decode('latin-1'))
            return 1, output[0].decode('latin-1')
    if me.returncode != 0 and chk_err is True and log_cmd:
        Error('CalledProcessError. Error Code is ' + str(me.returncode))
        Error('CalledProcessError. Command string was ' + cmd)
        Error('CalledProcessError. Command result was ' + output[0].decode('latin-1'))
    return me.returncode, output[0].decode('latin-1')


def GetNodeTextData(a):
    """
    Filter non-text nodes from DOM tree
    """
    for b in a.childNodes:
        if b.nodeType == b.TEXT_NODE:
            return b.data


def GetHome():
    """
    Attempt to guess the $HOME location.
    Return the path string.
    """
    home = None
    try:
        home = GetLineStartingWith("HOME", "/etc/default/useradd").split('=')[1].strip()
    except:
        pass
    if (home == None) or (home.startswith("/") == False):
        home = "/home"
    return home


def ChangeOwner(filepath, user):
    """
    Lookup user.  Attempt chown 'filepath' to 'user'.
    Silently does nothing if the user does not exist.
    """
    p = None
    try:
        p = pwd.getpwnam(user)
    except:
        pass
    if p != None:
        os.chown(filepath, p[2], p[3])


def CreateDir(dirpath, user, mode):
    """
    Attempt os.makedirs, catch all exceptions.
    Call ChangeOwner afterwards.
    """
    try:
        os.makedirs(dirpath, mode)
    except:
        pass
    ChangeOwner(dirpath, user)
""" try: os.makedirs(dirpath, mode) except: pass ChangeOwner(dirpath, user) def CreateAccount(user, password, expiration, thumbprint): """ Create a user account, with 'user', 'password', 'expiration', ssh keys and sudo permissions. Returns None if successful, error string on failure. """ userentry = None try: userentry = pwd.getpwnam(user) except: pass uidmin = None try: uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry != None and userentry[2] < uidmin: Error("CreateAccount: " + user + " is a system user. Will not set password.") return "Failed to set password for system user: " + user + " (0x06)." if userentry == None: command = "useradd -m " + user if expiration != None: command += " -e " + expiration.split('.')[0] if Run(command): Error("Failed to create user account: " + user) return "Failed to create user account: " + user + " (0x07)." else: Log("CreateAccount: " + user + " already exists. Will update password.") if password != None: MyDistro.changePass(user, password) try: # for older distros create sudoers.d if not os.path.isdir('/etc/sudoers.d/'): # create the /etc/sudoers.d/ directory os.mkdir('/etc/sudoers.d/') # add the include of sudoers.d to the /etc/sudoers SetFileContents('/etc/sudoers', GetFileContents('/etc/sudoers') + '\n#includedir /etc/sudoers.d\n') if password == None: SetFileContents("/etc/sudoers.d/waagent", user + " ALL = (ALL) NOPASSWD: ALL\n") else: SetFileContents("/etc/sudoers.d/waagent", user + " ALL = (ALL) ALL\n") os.chmod("/etc/sudoers.d/waagent", 0o440) except: Error("CreateAccount: Failed to configure sudo access for user.") return "Failed to configure sudo privileges (0x08)." 
home = MyDistro.GetHome() if thumbprint != None: dir = home + "/" + user + "/.ssh" CreateDir(dir, user, 0o700) pub = dir + "/id_rsa.pub" prv = dir + "/id_rsa" Run("ssh-keygen -y -f " + thumbprint + ".prv > " + pub) SetFileContents(prv, GetFileContents(thumbprint + ".prv")) for f in [pub, prv]: os.chmod(f, 0o600) ChangeOwner(f, user) SetFileContents(dir + "/authorized_keys", GetFileContents(pub)) ChangeOwner(dir + "/authorized_keys", user) Log("Created user account: " + user) return None def DeleteAccount(user): """ Delete the 'user'. Clear utmp first, to avoid error. Removes the /etc/sudoers.d/waagent file. """ userentry = None try: userentry = pwd.getpwnam(user) except: pass if userentry == None: Error("DeleteAccount: " + user + " not found.") return uidmin = None try: uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry[2] < uidmin: Error("DeleteAccount: " + user + " is a system user. Will not delete account.") return Run("> /var/run/utmp") # Delete utmp to prevent error if we are the 'user' deleted Run("userdel -f -r " + user) try: os.remove("/etc/sudoers.d/waagent") except: pass return def IsInRangeInclusive(a, low, high): """ Return True if 'a' in 'low' <= a >= 'high' """ return (a >= low and a <= high) def IsPrintable(ch): """ Return True if character is displayable. """ return IsInRangeInclusive(ch, Ord('A'), Ord('Z')) or IsInRangeInclusive(ch, Ord('a'), Ord('z')) or IsInRangeInclusive(ch, Ord('0'), Ord('9')) def HexDump(buffer, size): """ Return Hex formated dump of a 'buffer' of 'size'. 
""" if size < 0: size = len(buffer) result = "" for i in range(0, size): if (i % 16) == 0: result += "%06X: " % i byte = buffer[i] if type(byte) == str: byte = ord(byte.decode('latin1')) result += "%02X " % byte if (i & 15) == 7: result += " " if ((i + 1) % 16) == 0 or (i + 1) == size: j = i while ((j + 1) % 16) != 0: result += " " if (j & 7) == 7: result += " " j += 1 result += " " for j in range(i - (i % 16), i + 1): byte = buffer[j] if type(byte) == str: byte = ord(byte.decode('latin1')) k = '.' if IsPrintable(byte): k = chr(byte) result += k if (i + 1) != size: result += "\n" return result def SimpleLog(file_path, message): if not file_path or len(message) < 1: return t = time.localtime() t = "%04u/%02u/%02u %02u:%02u:%02u " % (t.tm_year, t.tm_mon, t.tm_mday, t.tm_hour, t.tm_min, t.tm_sec) lines = re.sub(re.compile(r'^(.)', re.MULTILINE), t + r'\1', message) with open(file_path, "a") as F: lines = filter(lambda x: x in string.printable, lines) F.write(lines.encode('ascii', 'ignore') + "\n") class Logger(object): """ The Agent's logging assumptions are: For Log, and LogWithPrefix all messages are logged to the self.file_path and to the self.con_path. Setting either path parameter to None skips that log. If Verbose is enabled, messages calling the LogIfVerbose method will be logged to file_path yet not to con_path. Error and Warn messages are normal log messages with the 'ERROR:' or 'WARNING:' prefix added. """ def __init__(self, filepath, conpath, verbose=False): """ Construct an instance of Logger. """ self.file_path = filepath self.con_path = conpath self.verbose = verbose def ThrottleLog(self, counter): """ Log everything up to 10, every 10 up to 100, then every 100. """ return (counter < 10) or ((counter < 100) and ((counter % 10) == 0)) or ((counter % 100) == 0) def WriteToFile(self, message): """ Write 'message' to logfile. 
""" if self.file_path: try: with open(self.file_path, "a") as F: message = filter(lambda x: x in string.printable, message) # encoding works different for between interpreter version, we are keeping separate implementation # to ensure backward compatibility if sys.version_info[0] == 3: message = ''.join(list(message)).encode('ascii', 'ignore').decode("ascii", "ignore") elif sys.version_info[0] == 2: message = message.encode('ascii', 'ignore') F.write(message + "\n") except IOError as e: pass def WriteToConsole(self, message): """ Write 'message' to /dev/console. This supports serial port logging if the /dev/console is redirected to ttys0 in kernel boot options. """ if self.con_path: try: with open(self.con_path, "w") as C: message = filter(lambda x: x in string.printable, message) # encoding works different for between interpreter version, we are keeping separate implementation # to ensure backward compatibility if sys.version_info[0] == 3: message = ''.join(list(message)).encode('ascii', 'ignore').decode("ascii", "ignore") elif sys.version_info[0] == 2: message = message.encode('ascii', 'ignore') C.write(message + "\n") except IOError as e: pass def Log(self, message): """ Standard Log function. Logs to self.file_path, and con_path """ self.LogWithPrefix("", message) def LogToConsole(self, message): """ Logs message to console by pre-pending each line of 'message' with current time. """ log_prefix = self._get_log_prefix("") for line in message.split('\n'): line = log_prefix + line self.WriteToConsole(line) def LogToFile(self, message): """ Logs message to file by pre-pending each line of 'message' with current time. """ log_prefix = self._get_log_prefix("") for line in message.split('\n'): line = log_prefix + line self.WriteToFile(line) def NoLog(self, message): """ Don't Log. """ pass def LogIfVerbose(self, message): """ Only log 'message' if global Verbose is True. 
""" self.LogWithPrefixIfVerbose('', message) def LogWithPrefix(self, prefix, message): """ Prefix each line of 'message' with current time+'prefix'. """ log_prefix = self._get_log_prefix(prefix) for line in message.split('\n'): line = log_prefix + line self.WriteToFile(line) self.WriteToConsole(line) def LogWithPrefixIfVerbose(self, prefix, message): """ Only log 'message' if global Verbose is True. Prefix each line of 'message' with current time+'prefix'. """ if self.verbose == True: log_prefix = self._get_log_prefix(prefix) for line in message.split('\n'): line = log_prefix + line self.WriteToFile(line) self.WriteToConsole(line) def Warn(self, message): """ Prepend the text "WARNING:" for each line in 'message'. """ self.LogWithPrefix("WARNING:", message) def Error(self, message): """ Call ErrorWithPrefix(message). """ ErrorWithPrefix("", message) def ErrorWithPrefix(self, prefix, message): """ Prepend the text "ERROR:" to the prefix for each line in 'message'. Errors written to logfile, and /dev/console """ self.LogWithPrefix("ERROR:", message) def _get_log_prefix(self, prefix): """ Generates the log prefix with timestamp+'prefix'. """ t = time.localtime() t = "%04u/%02u/%02u %02u:%02u:%02u " % (t.tm_year, t.tm_mon, t.tm_mday, t.tm_hour, t.tm_min, t.tm_sec) return t + prefix def LoggerInit(log_file_path, log_con_path, verbose=False): """ Create log object and export its methods to global scope. 
""" global Log, LogToConsole, LogToFile, LogWithPrefix, LogIfVerbose, LogWithPrefixIfVerbose, Error, ErrorWithPrefix, Warn, NoLog, ThrottleLog, myLogger l = Logger(log_file_path, log_con_path, verbose) Log, LogToConsole, LogToFile, LogWithPrefix, LogIfVerbose, LogWithPrefixIfVerbose, Error, ErrorWithPrefix, Warn, NoLog, ThrottleLog, myLogger = l.Log, l.LogToConsole, l.LogToFile, l.LogWithPrefix, l.LogIfVerbose, l.LogWithPrefixIfVerbose, l.Error, l.ErrorWithPrefix, l.Warn, l.NoLog, l.ThrottleLog, l def Linux_ioctl_GetInterfaceMac(ifname): """ Return the mac-address bound to the socket. """ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) info = fcntl.ioctl(s.fileno(), 0x8927, struct.pack('256s', (ifname[:15] + ('\0' * 241)).encode('latin-1'))) return ''.join(['%02X' % Ord(char) for char in info[18:24]]) def GetFirstActiveNetworkInterfaceNonLoopback(): """ Return the interface name, and ip addr of the first active non-loopback interface. """ iface = '' expected = 16 # how many devices should I expect... is_64bits = sys.maxsize > 2 ** 32 struct_size = 40 if is_64bits else 32 # for 64bit the size is 40 bytes, for 32bits it is 32 bytes. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) buff = array.array('B', b'\0' * (expected * struct_size)) retsize = (struct.unpack('iL', fcntl.ioctl(s.fileno(), 0x8912, struct.pack('iL', expected * struct_size, buff.buffer_info()[0]))))[0] if retsize == (expected * struct_size): Warn('SIOCGIFCONF returned more than ' + str(expected) + ' up network interfaces.') s = buff.tostring() preferred_nic = Config.get("Network.Interface") for i in range(0, struct_size * expected, struct_size): iface = s[i:i + 16].split(b'\0', 1)[0] if iface == b'lo': continue elif preferred_nic is None: break elif iface == preferred_nic: break return iface.decode('latin-1'), socket.inet_ntoa(s[i + 20:i + 24]) def GetIpv4Address(): """ Return the ip of the first active non-loopback interface. 
""" iface, addr = GetFirstActiveNetworkInterfaceNonLoopback() return addr def HexStringToByteArray(a): """ Return hex string packed into a binary struct. """ b = b"" for c in range(0, len(a) // 2): b += struct.pack("B", int(a[c * 2:c * 2 + 2], 16)) return b def GetMacAddress(): """ Convienience function, returns mac addr bound to first non-loobback interface. """ ifname = '' while len(ifname) < 2: ifname = GetFirstActiveNetworkInterfaceNonLoopback()[0] a = Linux_ioctl_GetInterfaceMac(ifname) return HexStringToByteArray(a) def DeviceForIdePort(n): """ Return device name attached to ide port 'n'. """ if n > 3: return None g0 = "00000000" if n > 1: g0 = "00000001" n = n - 2 device = None path = "/sys/bus/vmbus/devices/" for vmbus in os.listdir(path): guid = GetFileContents(path + vmbus + "/device_id").lstrip('{').split('-') if guid[0] == g0 and guid[1] == "000" + str(n): for root, dirs, files in os.walk(path + vmbus): if root.endswith("/block"): device = dirs[0] break else: # older distros for d in dirs: if ':' in d and "block" == d.split(':')[0]: device = d.split(':')[1] break break return device class HttpResourceGoneError(Exception): pass def DoInstallRHUIRPM(): """ Install RHUI RPM according to VM region """ rhuiRPMinstalled = os.path.exists(LibDir + "/rhuirpminstalled") if rhuiRPMinstalled: return else: SetFileContents(LibDir + "/rhuirpminstalled", "") Log("Begin to install RHUI RPM") cmd = r"grep '' /var/lib/waagent/ExtensionsConfig* --no-filename | sed 's///g' | sed 's/<\/Location>//g' | sed 's/ //g' | tr 'A-Z' 'a-z' | uniq" retcode, out = RunGetOutput(cmd, True) region = out.rstrip("\n") # try a few times at most to get the region info retry = 0 for i in range(0, 8): if (region != ""): break Log("region info is empty, now wait 15 seconds...") time.sleep(15) retcode, out = RunGetOutput(cmd, True) region = out.rstrip("\n") if region == "": Log("could not detect region info, now use the default region: eastus2") region = "eastus2" scriptFilePath = 
"/tmp/install-rhui-rpm.sh" if not os.path.exists(scriptFilePath): Error(scriptFilePath + " does not exist, now quit RHUI RPM installation."); return # chmod a+x script file os.chmod(scriptFilePath, 0o100) Log("begin to run " + scriptFilePath) # execute the downloaded script file retcode, out = RunGetOutput(scriptFilePath, True) if retcode != 0: Error("execute script " + scriptFilePath + " failed, return code: " + str( retcode) + ", now exit RHUI RPM installation."); return Log("install RHUI RPM completed") class Util(object): """ Http communication class. Base of GoalState, and Agent classes. """ RetryWaitingInterval = 10 def __init__(self): self.Endpoint = None def _ParseUrl(self, url): secure = False host = self.Endpoint path = url port = None # "http[s]://hostname[:port][/]" if url.startswith("http://"): url = url[7:] if "/" in url: host = url[0: url.index("/")] path = url[url.index("/"):] else: host = url path = "/" elif url.startswith("https://"): secure = True url = url[8:] if "/" in url: host = url[0: url.index("/")] path = url[url.index("/"):] else: host = url path = "/" if host is None: raise ValueError("Host is invalid:{0}".format(url)) if (":" in host): pos = host.rfind(":") port = int(host[pos + 1:]) host = host[0:pos] return host, port, secure, path def GetHttpProxy(self, secure): """ Get http_proxy and https_proxy from environment variables. Username and password is not supported now. """ host = Config.get("HttpProxy.Host") port = Config.get("HttpProxy.Port") return (host, port) def _HttpRequest(self, method, host, path, port=None, data=None, secure=False, headers=None, proxyHost=None, proxyPort=None): resp = None conn = None try: if secure: port = 443 if port is None else port if proxyHost is not None and proxyPort is not None: conn = httpclient.HTTPSConnection(proxyHost, proxyPort, timeout=10) conn.set_tunnel(host, port) # If proxy is used, full url is needed. 
path = "https://{0}:{1}{2}".format(host, port, path) else: conn = httpclient.HTTPSConnection(host, port, timeout=10) else: port = 80 if port is None else port if proxyHost is not None and proxyPort is not None: conn = httpclient.HTTPConnection(proxyHost, proxyPort, timeout=10) # If proxy is used, full url is needed. path = "http://{0}:{1}{2}".format(host, port, path) else: conn = httpclient.HTTPConnection(host, port, timeout=10) if headers == None: conn.request(method, path, data) else: conn.request(method, path, data, headers) resp = conn.getresponse() except httpclient.HTTPException as e: Error('HTTPException {0}, args:{1}'.format(e, repr(e.args))) except IOError as e: Error('Socket IOError {0}, args:{1}'.format(e, repr(e.args))) return resp def HttpRequest(self, method, url, data=None, headers=None, maxRetry=3, chkProxy=False): """ Sending http request to server On error, sleep 10 and maxRetry times. Return the output buffer or None. """ LogIfVerbose("HTTP Req: {0} {1}".format(method, url)) LogIfVerbose("HTTP Req: Data={0}".format(data)) LogIfVerbose("HTTP Req: Header={0}".format(headers)) try: host, port, secure, path = self._ParseUrl(url) except ValueError as e: Error("Failed to parse url:{0}".format(url)) return None # Check proxy proxyHost, proxyPort = (None, None) if chkProxy: proxyHost, proxyPort = self.GetHttpProxy(secure) # If httplib/httpclient module is not built with ssl support. Fallback to http if secure and not hasattr(httpclient, "HTTPSConnection"): Warn("httplib/httpclient is not built with ssl support") secure = False proxyHost, proxyPort = self.GetHttpProxy(secure) # If httplib/httpclient module doesn't support https tunnelling. 
Fallback to http if secure and \ proxyHost is not None and \ proxyPort is not None and \ not hasattr(httpclient.HTTPSConnection, "set_tunnel"): Warn("httplib/httpclient doesn't support https tunnelling(new in python 2.7)") secure = False proxyHost, proxyPort = self.GetHttpProxy(secure) resp = self._HttpRequest(method, host, path, port=port, data=data, secure=secure, headers=headers, proxyHost=proxyHost, proxyPort=proxyPort) for retry in range(0, maxRetry): if resp is not None and \ (resp.status == httpclient.OK or \ resp.status == httpclient.CREATED or \ resp.status == httpclient.ACCEPTED): return resp; if resp is not None and resp.status == httpclient.GONE: raise HttpResourceGoneError("Http resource gone.") Error("Retry={0}".format(retry)) Error("HTTP Req: {0} {1}".format(method, url)) Error("HTTP Req: Data={0}".format(data)) Error("HTTP Req: Header={0}".format(headers)) if resp is None: Error("HTTP Err: response is empty.".format(retry)) else: Error("HTTP Err: Status={0}".format(resp.status)) Error("HTTP Err: Reason={0}".format(resp.reason)) Error("HTTP Err: Header={0}".format(resp.getheaders())) Error("HTTP Err: Body={0}".format(resp.read())) time.sleep(self.__class__.RetryWaitingInterval) resp = self._HttpRequest(method, host, path, port=port, data=data, secure=secure, headers=headers, proxyHost=proxyHost, proxyPort=proxyPort) return None def HttpGet(self, url, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("GET", url, headers=headers, maxRetry=maxRetry, chkProxy=chkProxy) def HttpHead(self, url, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("HEAD", url, headers=headers, maxRetry=maxRetry, chkProxy=chkProxy) def HttpPost(self, url, data, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("POST", url, data=data, headers=headers, maxRetry=maxRetry, chkProxy=chkProxy) def HttpPut(self, url, data, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("PUT", url, data=data, headers=headers, 
__StorageVersion = "2014-02-14"


def GetBlobType(url):
    """
    HEAD the blob at 'url' and return its x-ms-blob-type header
    ("BlockBlob" or "PageBlob"), or None on failure.
    """
    restutil = Util()
    # Check blob type
    LogIfVerbose("Check blob type.")
    timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    blobPropResp = restutil.HttpHead(url, {
        "x-ms-date": timestamp,
        'x-ms-version': __StorageVersion
    }, chkProxy=True)
    blobType = None
    if blobPropResp is None:
        Error("Can't get status blob type.")
        return None
    blobType = blobPropResp.getheader("x-ms-blob-type")
    LogIfVerbose("Blob type={0}".format(blobType))
    return blobType


def PutBlockBlob(url, data):
    """
    Upload 'data' to 'url' as a block blob.
    Returns 0 on success, -1 on failure.
    """
    restutil = Util()
    LogIfVerbose("Upload block blob")
    timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    ret = restutil.HttpPut(url, data, {
        "x-ms-date": timestamp,
        "x-ms-blob-type": "BlockBlob",
        "Content-Length": str(len(data)),
        "x-ms-version": __StorageVersion
    }, chkProxy=True)
    if ret is None:
        Error("Failed to upload block blob for status.")
        return -1
    return 0


def PutPageBlob(url, data):
    """
    Upload 'data' to 'url' as a page blob: create/resize the blob,
    then upload 512-byte-aligned pages of at most 4MB each.
    Returns 0 on success, -1 on failure.
    """
    restutil = Util()
    LogIfVerbose("Replace old page blob")
    timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    # Align to 512 bytes (page blobs must be a multiple of the page size).
    # FIX: integer division -- true division yields a float on Python 3 and
    # an invalid x-ms-blob-content-length header value.
    pageBlobSize = ((len(data) + 511) // 512) * 512
    ret = restutil.HttpPut(url, "", {
        "x-ms-date": timestamp,
        "x-ms-blob-type": "PageBlob",
        "Content-Length": "0",
        "x-ms-blob-content-length": str(pageBlobSize),
        "x-ms-version": __StorageVersion
    }, chkProxy=True)
    if ret is None:
        Error("Failed to clean up page blob for status")
        return -1
    # FIX: str.find returns -1 when '?' is absent; str.index (as originally
    # written) never returns < 0 and raises ValueError instead.
    if url.find('?') < 0:
        url = "{0}?comp=page".format(url)
    else:
        url = "{0}&comp=page".format(url)
    LogIfVerbose("Upload page blob")
    pageMax = 4 * 1024 * 1024  # Max page size: 4MB
    start = 0
    end = 0
    while end < len(data):
        end = min(len(data), start + pageMax)
        contentSize = end - start
        # Align the end of the upload range to 512 bytes.
        pageEnd = ((end + 511) // 512) * 512
        bufSize = pageEnd - start
        buf = bytearray(bufSize)
        buf[0: contentSize] = data[start: end]
        # FIX: bytes(buf) replaces the Python2-only buffer() builtin.
        ret = restutil.HttpPut(url, bytes(buf), {
            "x-ms-date": timestamp,
            "x-ms-range": "bytes={0}-{1}".format(start, pageEnd - 1),
            "x-ms-page-write": "update",
            "x-ms-version": __StorageVersion,
            "Content-Length": str(pageEnd - start)
        }, chkProxy=True)
        if ret is None:
            Error("Failed to upload page blob for status")
            return -1
        start = end
    return 0


def UploadStatusBlob(url, data):
    """
    Upload 'data' to the status blob at 'url', dispatching on the
    blob's existing type.  Returns 0 on success, -1 on failure.
    """
    LogIfVerbose("Upload status blob")
    LogIfVerbose("Status={0}".format(data))
    blobType = GetBlobType(url)
    if blobType == "BlockBlob":
        return PutBlockBlob(url, data)
    elif blobType == "PageBlob":
        return PutPageBlob(url, data)
    else:
        Error("Unknown blob type: {0}".format(blobType))
        return -1
# --- tail of UploadStatusBlob(); the function's head is on the previous chunk line ---
        return PutPageBlob(url, data)
    else:
        Error("Unknown blob type: {0}".format(blobType))
        return -1


class TCPHandler():
    """
    Callback object for LoadBalancerProbeServer.
    Recv and send LB probe messages.
    """

    def __init__(self, lb_probe):
        # Keep a back-reference to the probe server so handle() can update
        # its shared probe counter.
        super(TCPHandler, self).__init__()
        self.lb_probe = lb_probe

    def GetHttpDateTimeNow(self):
        """
        Return formatted gmtime "Date: Fri, 25 Mar 2011 04:53:10 GMT"
        """
        return time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())

    def handle(self):
        """
        Log LB probe messages, read the socket buffer,
        send LB probe response back to server.
        """
        # Wrap the counter at 1,000,000 so it never grows unbounded.
        self.lb_probe.ProbeCounter = (self.lb_probe.ProbeCounter + 1) % 1000000
        # ThrottleLog() is used as a 0/1 index selecting the no-op logger or
        # the verbose logger -- NOTE(review): confirm its return contract.
        log = [NoLog, LogIfVerbose][ThrottleLog(self.lb_probe.ProbeCounter)]
        strCounter = str(self.lb_probe.ProbeCounter)
        if self.lb_probe.ProbeCounter == 1:
            Log("Receiving LB probes.")
        log("Received LB probe # " + strCounter)
        # NOTE(review): self.request is never assigned in this class; it
        # follows the SocketServer handler convention -- confirm the
        # dispatcher sets it before calling handle().
        self.request.recv(1024)
        self.request.send(
            "HTTP/1.1 200 OK\r\nContent-Length: 2\r\nContent-Type: text/html\r\nDate: " + self.GetHttpDateTimeNow() + "\r\n\r\nOK")


class LoadBalancerProbeServer(object):
    """
    Threaded object to receive and send LB probe messages.
    Load Balancer messages must be recv'd by the load balancing
    server, or this node may be shut-down.
""" def __init__(self, port): pass def shutdown(self): pass def get_ip(self): return None class ConfigurationProvider(object): """ Parse amd store key:values in waagent.conf """ def __init__(self, walaConfigFile): self.values = dict() if walaConfigFile is None: walaConfigFile = MyDistro.getConfigurationPath() if os.path.isfile(walaConfigFile) == False: raise Exception("Missing configuration in {0}".format(walaConfigFile)) try: for line in GetFileContents(walaConfigFile).split('\n'): if not line.startswith("#") and "=" in line: parts = line.split()[0].split('=') value = parts[1].strip("\" ") if value != "None": self.values[parts[0]] = value else: self.values[parts[0]] = None except: Error("Unable to parse {0}".format(walaConfigFile)) raise return def get(self, key): return self.values.get(key) def yes(self, key): configValue = self.get(key) if (configValue is not None and configValue.lower().startswith("y")): return True else: return False def no(self, key): configValue = self.get(key) if (configValue is not None and configValue.lower().startswith("n")): return True else: return False class EnvMonitor(object): """ Montor changes to dhcp and hostname. If dhcp clinet process re-start has occurred, reset routes, dhcp with fabric. """ def __init__(self): self.shutdown = False self.HostName = socket.gethostname() self.server_thread = threading.Thread(target=self.monitor) self.server_thread.setDaemon(True) self.server_thread.start() self.published = False def monitor(self): """ Monitor dhcp client pid and hostname. If dhcp clinet process re-start has occurred, reset routes, dhcp with fabric. 
""" publish = Config.get("Provisioning.MonitorHostName") dhcpcmd = MyDistro.getpidcmd + ' ' + MyDistro.getDhcpClientName() dhcppid = RunGetOutput(dhcpcmd)[1] while not self.shutdown: for a in RulesFiles: if os.path.isfile(a): if os.path.isfile(GetLastPathElement(a)): os.remove(GetLastPathElement(a)) shutil.move(a, ".") Log("EnvMonitor: Moved " + a + " -> " + LibDir) MyDistro.setScsiDiskTimeout() if publish != None and publish.lower().startswith("y"): try: if socket.gethostname() != self.HostName: Log("EnvMonitor: Detected host name change: " + self.HostName + " -> " + socket.gethostname()) self.HostName = socket.gethostname() WaAgent.UpdateAndPublishHostName(self.HostName) dhcppid = RunGetOutput(dhcpcmd)[1] self.published = True except: pass else: self.published = True pid = "" if not os.path.isdir("/proc/" + dhcppid.strip()): pid = RunGetOutput(dhcpcmd)[1] if pid != "" and pid != dhcppid: Log("EnvMonitor: Detected dhcp client restart. Restoring routing table.") WaAgent.RestoreRoutes() dhcppid = pid for child in Children: if child.poll() != None: Children.remove(child) time.sleep(5) def SetHostName(self, name): """ Generic call to MyDistro.setHostname(name). Complian to Log on error. """ if socket.gethostname() == name: self.published = True elif MyDistro.setHostname(name): Error("Error: SetHostName: Cannot set hostname to " + name) return ("Error: SetHostName: Cannot set hostname to " + name) def IsHostnamePublished(self): """ Return self.published """ return self.published def ShutdownService(self): """ Stop server comminucation and join the thread to main thread. """ self.shutdown = True self.server_thread.join() class Certificates(object): """ Object containing certificates of host and provisioned user. Parses and splits certificates into files. """ # # 2010-12-15 # 2 # Pkcs7BlobWithPfxContents # MIILTAY... 
# # def __init__(self): self.reinitialize() def reinitialize(self): """ Reset the Role, Incarnation """ self.Incarnation = None self.Role = None def Parse(self, xmlText): """ Parse multiple certificates into seperate files. """ self.reinitialize() SetFileContents("Certificates.xml", xmlText) dom = xml.dom.minidom.parseString(xmlText) for a in ["CertificateFile", "Version", "Incarnation", "Format", "Data", ]: if not dom.getElementsByTagName(a): Error("Certificates.Parse: Missing " + a) return None node = dom.childNodes[0] if node.localName != "CertificateFile": Error("Certificates.Parse: root not CertificateFile") return None SetFileContents("Certificates.p7m", "MIME-Version: 1.0\n" + "Content-Disposition: attachment; filename=\"Certificates.p7m\"\n" + "Content-Type: application/x-pkcs7-mime; name=\"Certificates.p7m\"\n" + "Content-Transfer-Encoding: base64\n\n" + GetNodeTextData(dom.getElementsByTagName("Data")[0])) if Run( Openssl + " cms -decrypt -in Certificates.p7m -inkey TransportPrivate.pem -recip TransportCert.pem | " + Openssl + " pkcs12 -nodes -password pass: -out Certificates.pem"): Error("Certificates.Parse: Failed to extract certificates from CMS message.") return self # There may be multiple certificates in this package. Split them. 
# --- continuation of Certificates.Parse(): split Certificates.pem into
# individual numbered .crt / .prv files, then key them by thumbprint ---
        file = open("Certificates.pem")
        pindex = 1
        cindex = 1
        output = open("temp.pem", "w")
        # Copy lines into temp.pem, cutting at every END KEY / END
        # CERTIFICATE marker.  NOTE(review): 'file' shadows the builtin and
        # is never closed.
        for line in file.readlines():
            output.write(line)
            if re.match(r'[-]+END .*?(KEY|CERTIFICATE)[-]+$', line):
                output.close()
                if re.match(r'[-]+END .*?KEY[-]+$', line):
                    os.rename("temp.pem", str(pindex) + ".prv")
                    pindex += 1
                else:
                    os.rename("temp.pem", str(cindex) + ".crt")
                    cindex += 1
                output = open("temp.pem", "w")
        output.close()
        os.remove("temp.pem")
        keys = dict()
        index = 1
        filename = str(index) + ".crt"
        while os.path.isfile(filename):
            # Rename each certificate to <THUMBPRINT>.crt and remember its
            # public key so private keys can be matched up below.
            thumbprint = \
                (RunGetOutput(Openssl + " x509 -in " + filename + " -fingerprint -noout")[1]).rstrip().split('=')[
                    1].replace(':', '').upper()
            pubkey = RunGetOutput(Openssl + " x509 -in " + filename + " -pubkey -noout")[1]
            keys[pubkey] = thumbprint
            os.rename(filename, thumbprint + ".crt")
            os.chmod(thumbprint + ".crt", 0o600)
            MyDistro.setSelinuxContext(thumbprint + '.crt', 'unconfined_u:object_r:ssh_home_t:s0')
            index += 1
            filename = str(index) + ".crt"
        index = 1
        filename = str(index) + ".prv"
        while os.path.isfile(filename):
            # Match each private key to its certificate via the public key.
            pubkey = RunGetOutput(Openssl + " rsa -in " + filename + " -pubout 2> /dev/null ")[1]
            os.rename(filename, keys[pubkey] + ".prv")
            os.chmod(keys[pubkey] + ".prv", 0o600)
            MyDistro.setSelinuxContext(keys[pubkey] + '.prv', 'unconfined_u:object_r:ssh_home_t:s0')
            index += 1
            filename = str(index) + ".prv"
        return self


class SharedConfig(object):
    """
    Parse role endpoint server and goal state config.
    """
    # (Example SharedConfig XML was elided by extraction.)

    def __init__(self):
        self.reinitialize()

    def reinitialize(self):
        """
        Reset members.
        """
        self.RdmaMacAddress = None
        self.RdmaIPv4Address = None
        self.xmlText = None

    def Parse(self, xmlText):
        """
        Parse and write configuration to file SharedConfig.xml.
""" LogIfVerbose(xmlText) self.reinitialize() self.xmlText = xmlText dom = xml.dom.minidom.parseString(xmlText) for a in ["SharedConfig", "Deployment", "Service", "ServiceInstance", "Incarnation", "Role", ]: if not dom.getElementsByTagName(a): Error("SharedConfig.Parse: Missing " + a) node = dom.childNodes[0] if node.localName != "SharedConfig": Error("SharedConfig.Parse: root not SharedConfig") nodes = dom.getElementsByTagName("Instance") if nodes is not None and len(nodes) != 0: node = nodes[0] if node.hasAttribute("rdmaMacAddress"): addr = node.getAttribute("rdmaMacAddress") self.RdmaMacAddress = addr[0:2] for i in range(1, 6): self.RdmaMacAddress += ":" + addr[2 * i: 2 * i + 2] if node.hasAttribute("rdmaIPv4Address"): self.RdmaIPv4Address = node.getAttribute("rdmaIPv4Address") return self def Save(self): LogIfVerbose("Save SharedConfig.xml") SetFileContents("SharedConfig.xml", self.xmlText) def InvokeTopologyConsumer(self): program = Config.get("Role.TopologyConsumer") if program != None: try: Children.append(subprocess.Popen([program, LibDir + "/SharedConfig.xml"])) except OSError as e: ErrorWithPrefix('Agent.Run', 'Exception: ' + str(e) + ' occured launching ' + program) def Process(self): global rdma_configured if not rdma_configured and self.RdmaMacAddress is not None and self.RdmaIPv4Address is not None: handler = RdmaHandler(self.RdmaMacAddress, self.RdmaIPv4Address) handler.start() rdma_configured = True self.InvokeTopologyConsumer() rdma_configured = False class RdmaConfig(object): """ configurations """ wrapper_package_name = 'msft-rdma-drivers' rmda_package_name = 'msft-lis-rdma-kmp-default' """ error code definitions """ process_success = 0 common_failed = 1 check_install_hv_utils_failed = 2 nd_driver_detect_error = 3 driver_version_not_found = 4 unknown_error = 5 package_not_found = 6 package_install_failed = 7 hv_kvp_daemon_not_started = 8 """ check_rdma_result """ UpToDate = 0 OutOfDate = 1 DriverVersionNotFound = 3 Unknown = -1 class 
RdmaError(Exception):
    # Exception carrying one of the RdmaConfig error codes.
    def __init__(self, error_code=RdmaConfig.process_success):
        self.error_code = error_code


class RdmaHandler(object):
    """
    Handle rdma configuration.
    """

    def __init__(self, mac, ip_addr, dev="/dev/hvnd_rdma",
                 dat_conf_files=['/etc/dat.conf', '/etc/rdma/dat.conf', '/usr/local/etc/dat.conf']):
        # NOTE(review): mutable default argument; harmless here because the
        # list is only iterated, never mutated.
        self.mac = mac
        self.ip_addr = ip_addr
        self.dev = dev
        self.dat_conf_files = dat_conf_files
        # Config string written verbatim to the rdma device node.
        self.data = ('rdmaMacAddress="{0}" rdmaIPv4Address="{1}"'
                     '').format(self.mac, self.ip_addr)

    def start(self):
        """
        Start a new thread to process rdma
        """
        threading.Thread(target=self.process).start()

    def process(self):
        # Full configuration pipeline; RdmaError from any step is logged.
        try:
            self.set_dat_conf()
            self.set_rdma_dev()
            self.set_rdma_ip()
        except RdmaError as e:
            Error("Failed to config rdma device: {0}".format(e))

    def set_dat_conf(self):
        """
        Agent needs to search all possible locations for dat.conf
        """
        Log("Set dat.conf")
        for dat_conf_file in self.dat_conf_files:
            if not os.path.isfile(dat_conf_file):
                continue
            try:
                self.write_dat_conf(dat_conf_file)
            except IOError as e:
                raise RdmaError("Failed to write to dat.conf: {0}".format(e))

    def write_dat_conf(self, dat_conf_file):
        # Rewrite the ofa-v2-ib0 entry so its address field carries this
        # role's RDMA IPv4 address.
        Log("Write config to {0}".format(dat_conf_file))
        old = (r"ofa-v2-ib0 u2.0 nonthreadsafe default libdaplofa.so.2 "
               r"dapl.2.0 \"\S+ 0\"")
        new = ("ofa-v2-ib0 u2.0 nonthreadsafe default libdaplofa.so.2 "
               "dapl.2.0 \"{0} 0\"").format(self.ip_addr)
        lines = GetFileContents(dat_conf_file)
        lines = re.sub(old, new, lines)
        SetFileContents(dat_conf_file, lines)

    def set_rdma_dev(self):
        """
        Write config string to /dev/hvnd_rdma
        """
        Log("Set /dev/hvnd_rdma")
        self.wait_rdma_dev()
        self.write_rdma_dev_conf()

    def write_rdma_dev_conf(self):
        Log("Write rdma config to {0}: {1}".format(self.dev, self.data))
        try:
            with open(self.dev, "w") as c:
                c.write(self.data)
        except IOError as e:
            raise RdmaError("Error writing {0}, {1}".format(self.dev, e))

    def wait_rdma_dev(self):
        # Poll up to 120s for the device node to appear; the timeout
        # RdmaError is raised in the continuation on the next chunk line.
        Log("Wait for /dev/hvnd_rdma")
        retry = 0
        while retry < 120:
            if os.path.exists(self.dev):
                return
            time.sleep(1)
            retry += 1
raise RdmaError("The device doesn't show up in 120 seconds") def set_rdma_ip(self): Log("Set ip addr for rdma") try: if_name = MyDistro.getInterfaceNameByMac(self.mac) # Azure is using 12 bits network mask for infiniband. MyDistro.configIpV4(if_name, self.ip_addr, 12) except Exception as e: raise RdmaError("Failed to config rdma device: {0}".format(e)) class ExtensionsConfig(object): """ Parse ExtensionsConfig, downloading and unpacking them to /var/lib/waagent. Install if true, remove if it is set to false. """ # # # # # # # {"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"1BE9A13AA1321C7C515EF109746998BAB6D86FD1", # "protectedSettings":"MIIByAYJKoZIhvcNAQcDoIIBuTCCAbUCAQAxggFxMIIBbQIBADBVMEExPzA9BgoJkiaJk/IsZAEZFi9XaW5kb3dzIEF6dXJlIFNlcnZpY2UgTWFuYWdlbWVudCBmb3IgR # Xh0ZW5zaW9ucwIQZi7dw+nhc6VHQTQpCiiV2zANBgkqhkiG9w0BAQEFAASCAQCKr09QKMGhwYe+O4/a8td+vpB4eTR+BQso84cV5KCAnD6iUIMcSYTrn9aveY6v6ykRLEw8GRKfri2d6 # tvVDggUrBqDwIgzejGTlCstcMJItWa8Je8gHZVSDfoN80AEOTws9Fp+wNXAbSuMJNb8EnpkpvigAWU2v6pGLEFvSKC0MCjDTkjpjqciGMcbe/r85RG3Zo21HLl0xNOpjDs/qqikc/ri43Y76E/X # v1vBSHEGMFprPy/Hwo3PqZCnulcbVzNnaXN3qi/kxV897xGMPPC3IrO7Nc++AT9qRLFI0841JLcLTlnoVG1okPzK9w6ttksDQmKBSHt3mfYV+skqs+EOMDsGCSqGSIb3DQEHATAUBggqh # kiG9w0DBwQITgu0Nu3iFPuAGD6/QzKdtrnCI5425fIUy7LtpXJGmpWDUA==","publicSettings":{"port":"3000"}}}]} # # # https://ostcextensions.blob.core.test-cint.azure-test.net/vhds/eg-plugin7-vm.eg-plugin7-vm.eg-plugin7-vm.status?sr=b&sp=rw& # se=9999-01-01&sk=key1&sv=2012-02-12&sig=wRUIDN1x2GC06FWaetBP9sjjifOWvRzS2y2XBB4qoBU%3D def __init__(self): self.reinitialize() def reinitialize(self): """ Reset members. """ self.Extensions = None self.Plugins = None self.Util = None def Parse(self, xmlText): """ Write configuration to file ExtensionsConfig.xml. Log plugin specific activity to /var/log/azure/.//CommandExecution.log. 
If state is enabled: if the plugin is installed: if the new plugin's version is higher if DisallowMajorVersionUpgrade is false or if true, the version is a minor version do upgrade: download the new archive do the updateCommand. disable the old plugin and remove enable the new plugin if the new plugin's version is the same or lower: create the new .settings file from the configuration received do the enableCommand if the plugin is not installed: download/unpack archive and call the installCommand/Enable if state is disabled: call disableCommand if state is uninstall: call uninstallCommand remove old plugin directory. """ self.reinitialize() self.Util = Util() dom = xml.dom.minidom.parseString(xmlText) LogIfVerbose(xmlText) self.plugin_log_dir = '/var/log/azure' if not os.path.exists(self.plugin_log_dir): os.mkdir(self.plugin_log_dir) try: self.Extensions = dom.getElementsByTagName("Extensions") pg = dom.getElementsByTagName("Plugins") if len(pg) > 0: self.Plugins = pg[0].getElementsByTagName("Plugin") else: self.Plugins = [] incarnation = self.Extensions[0].getAttribute("goalStateIncarnation") SetFileContents('ExtensionsConfig.' + incarnation + '.xml', xmlText) except Exception as e: Error('ERROR: Error parsing ExtensionsConfig: {0}.'.format(e)) return None for p in self.Plugins: if len(p.getAttribute("location")) < 1: # this plugin is inside the PluginSettings continue p.setAttribute('restricted', 'false') previous_version = None version = p.getAttribute("version") name = p.getAttribute("name") plog_dir = self.plugin_log_dir + '/' + name + '/' + version if not os.path.exists(plog_dir): os.makedirs(plog_dir) p.plugin_log = plog_dir + '/CommandExecution.log' handler = name + '-' + version if p.getAttribute("isJson") != 'true': Error("Plugin " + name + " version: " + version + " is not a JSON Extension. 
Skipping.") continue Log("Found Plugin: " + name + ' version: ' + version) if p.getAttribute("state") == 'disabled' or p.getAttribute("state") == 'uninstall': # disable zip_dir = LibDir + "/" + name + '-' + version mfile = None for root, dirs, files in os.walk(zip_dir): for f in files: if f in ('HandlerManifest.json'): mfile = os.path.join(root, f) if mfile != None: break if mfile == None: Error('HandlerManifest.json not found.') continue manifest = GetFileContents(mfile) p.setAttribute('manifestdata', manifest) if self.launchCommand(p.plugin_log, name, version, 'disableCommand') == None: self.SetHandlerState(handler, 'Enabled') Error('Unable to disable ' + name) SimpleLog(p.plugin_log, 'ERROR: Unable to disable ' + name) else: self.SetHandlerState(handler, 'Disabled') Log(name + ' is disabled') SimpleLog(p.plugin_log, name + ' is disabled') # uninstall if needed if p.getAttribute("state") == 'uninstall': if self.launchCommand(p.plugin_log, name, version, 'uninstallCommand') == None: self.SetHandlerState(handler, 'Installed') Error('Unable to uninstall ' + name) SimpleLog(p.plugin_log, 'Unable to uninstall ' + name) else: self.SetHandlerState(handler, 'NotInstalled') Log(name + ' uninstallCommand completed .') # remove the plugin Run('rm -rf ' + LibDir + '/' + name + '-' + version + '*') Log(name + '-' + version + ' extension files deleted.') SimpleLog(p.plugin_log, name + '-' + version + ' extension files deleted.') continue # state is enabled # if the same plugin exists and the version is newer or # does not exist then download and unzip the new plugin plg_dir = None latest_version_installed = LooseVersion("0.0") for item in os.listdir(LibDir): itemPath = os.path.join(LibDir, item) if os.path.isdir(itemPath) and name in item: try: # Split plugin dir name with '-' to get intalled plugin name and version sperator = item.rfind('-') if sperator < 0: continue installed_plg_name = item[0:sperator] installed_plg_version = LooseVersion(item[sperator + 1:]) # Check 
installed plugin name and compare installed version to get the latest version installed if installed_plg_name == name and installed_plg_version > latest_version_installed: plg_dir = itemPath previous_version = str(installed_plg_version) latest_version_installed = installed_plg_version except Exception as e: Warn("Invalid plugin dir name: {0} {1}".format(item, e)) continue if plg_dir == None or LooseVersion(version) > LooseVersion(previous_version): location = p.getAttribute("location") Log("Downloading plugin manifest: " + name + " from " + location) SimpleLog(p.plugin_log, "Downloading plugin manifest: " + name + " from " + location) self.Util.Endpoint = location.split('/')[2] Log("Plugin server is: " + self.Util.Endpoint) SimpleLog(p.plugin_log, "Plugin server is: " + self.Util.Endpoint) manifest = self.Util.HttpGetWithoutHeaders(location, chkProxy=True) if manifest == None: Error( "Unable to download plugin manifest" + name + " from primary location. Attempting with failover location.") SimpleLog(p.plugin_log, "Unable to download plugin manifest" + name + " from primary location. Attempting with failover location.") failoverlocation = p.getAttribute("failoverlocation") self.Util.Endpoint = failoverlocation.split('/')[2] Log("Plugin failover server is: " + self.Util.Endpoint) SimpleLog(p.plugin_log, "Plugin failover server is: " + self.Util.Endpoint) manifest = self.Util.HttpGetWithoutHeaders(failoverlocation, chkProxy=True) # if failoverlocation also fail what to do then? if manifest == None: AddExtensionEvent(name, WALAEventOperation.Download, False, 0, version, "Download mainfest fail " + failoverlocation) Log("Plugin manifest " + name + " downloading failed from failover location.") SimpleLog(p.plugin_log, "Plugin manifest " + name + " downloading failed from failover location.") filepath = LibDir + "/" + name + '.' 
+ incarnation + '.manifest' if os.path.splitext(location)[-1] == '.xml': # if this is an xml file we may have a BOM if ord(manifest[0]) > 128 and ord(manifest[1]) > 128 and ord(manifest[2]) > 128: manifest = manifest[3:] SetFileContents(filepath, manifest) # Get the bundle url from the manifest p.setAttribute('manifestdata', manifest) man_dom = xml.dom.minidom.parseString(manifest) bundle_uri = "" for mp in man_dom.getElementsByTagName("Plugin"): if GetNodeTextData(mp.getElementsByTagName("Version")[0]) == version: bundle_uri = GetNodeTextData(mp.getElementsByTagName("Uri")[0]) break if len(mp.getElementsByTagName("DisallowMajorVersionUpgrade")): if GetNodeTextData(mp.getElementsByTagName("DisallowMajorVersionUpgrade")[ 0]) == 'true' and previous_version != None and previous_version.split('.')[ 0] != version.split('.')[0]: Log('DisallowMajorVersionUpgrade is true, this major version is restricted from upgrade.') SimpleLog(p.plugin_log, 'DisallowMajorVersionUpgrade is true, this major version is restricted from upgrade.') p.setAttribute('restricted', 'true') continue if len(bundle_uri) < 1: Error("Unable to fetch Bundle URI from manifest for " + name + " v " + version) SimpleLog(p.plugin_log, "Unable to fetch Bundle URI from manifest for " + name + " v " + version) continue Log("Bundle URI = " + bundle_uri) SimpleLog(p.plugin_log, "Bundle URI = " + bundle_uri) # Download the zipfile archive and save as '.zip' bundle = self.Util.HttpGetWithoutHeaders(bundle_uri, chkProxy=True) if bundle == None: AddExtensionEvent(name, WALAEventOperation.Download, True, 0, version, "Download zip fail " + bundle_uri) Error("Unable to download plugin bundle" + bundle_uri) SimpleLog(p.plugin_log, "Unable to download plugin bundle" + bundle_uri) continue AddExtensionEvent(name, WALAEventOperation.Download, True, 0, version, "Download Success") b = bytearray(bundle) filepath = LibDir + "/" + os.path.basename(bundle_uri) + '.zip' SetFileContents(filepath, b) Log("Plugin bundle" + 
bundle_uri + "downloaded successfully length = " + str(len(bundle))) SimpleLog(p.plugin_log, "Plugin bundle" + bundle_uri + "downloaded successfully length = " + str(len(bundle))) # unpack the archive z = zipfile.ZipFile(filepath) zip_dir = LibDir + "/" + name + '-' + version z.extractall(zip_dir) Log('Extracted ' + bundle_uri + ' to ' + zip_dir) SimpleLog(p.plugin_log, 'Extracted ' + bundle_uri + ' to ' + zip_dir) # zip no file perms in .zip so set all the scripts to +x Run("find " + zip_dir + " -type f | xargs chmod u+x ") # write out the base64 config data so the plugin can process it. mfile = None for root, dirs, files in os.walk(zip_dir): for f in files: if f in ('HandlerManifest.json'): mfile = os.path.join(root, f) if mfile != None: break if mfile == None: Error('HandlerManifest.json not found.') SimpleLog(p.plugin_log, 'HandlerManifest.json not found.') continue manifest = GetFileContents(mfile) p.setAttribute('manifestdata', manifest) # create the status and config dirs Run('mkdir -p ' + root + '/status') Run('mkdir -p ' + root + '/config') # write out the configuration data to goalStateIncarnation.settings file in the config path. 
config = '' seqNo = '0' if len(dom.getElementsByTagName("PluginSettings")) != 0: pslist = dom.getElementsByTagName("PluginSettings")[0].getElementsByTagName("Plugin") for ps in pslist: if name == ps.getAttribute("name") and version == ps.getAttribute("version"): Log("Found RuntimeSettings for " + name + " V " + version) SimpleLog(p.plugin_log, "Found RuntimeSettings for " + name + " V " + version) config = GetNodeTextData(ps.getElementsByTagName("RuntimeSettings")[0]) seqNo = ps.getElementsByTagName("RuntimeSettings")[0].getAttribute("seqNo") break if config == '': Log("No RuntimeSettings for " + name + " V " + version) SimpleLog(p.plugin_log, "No RuntimeSettings for " + name + " V " + version) SetFileContents(root + "/config/" + seqNo + ".settings", config) # create HandlerEnvironment.json handler_env = '[{ "name": "' + name + '", "seqNo": "' + seqNo + '", "version": 1.0, "handlerEnvironment": { "logFolder": "' + os.path.dirname( p.plugin_log) + '", "configFolder": "' + root + '/config", "statusFolder": "' + root + '/status", "heartbeatFile": "' + root + '/heartbeat.log"}}]' SetFileContents(root + '/HandlerEnvironment.json', handler_env) self.SetHandlerState(handler, 'NotInstalled') cmd = '' getcmd = 'installCommand' if plg_dir != None and previous_version != None and LooseVersion(version) > LooseVersion( previous_version): previous_handler = name + '-' + previous_version if self.GetHandlerState(previous_handler) != 'NotInstalled': getcmd = 'updateCommand' # disable the old plugin if it exists if self.launchCommand(p.plugin_log, name, previous_version, 'disableCommand') == None: self.SetHandlerState(previous_handler, 'Enabled') Error('Unable to disable old plugin ' + name + ' version ' + previous_version) SimpleLog(p.plugin_log, 'Unable to disable old plugin ' + name + ' version ' + previous_version) else: self.SetHandlerState(previous_handler, 'Disabled') Log(name + ' version ' + previous_version + ' is disabled') SimpleLog(p.plugin_log, name + ' version ' + 
previous_version + ' is disabled') try: Log("Copy status file from old plugin dir to new") old_plg_dir = plg_dir new_plg_dir = os.path.join(LibDir, "{0}-{1}".format(name, version)) old_ext_status_dir = os.path.join(old_plg_dir, "status") new_ext_status_dir = os.path.join(new_plg_dir, "status") if os.path.isdir(old_ext_status_dir): for status_file in os.listdir(old_ext_status_dir): status_file_path = os.path.join(old_ext_status_dir, status_file) if os.path.isfile(status_file_path): shutil.copy2(status_file_path, new_ext_status_dir) mrseq_file = os.path.join(old_plg_dir, "mrseq") if os.path.isfile(mrseq_file): shutil.copy(mrseq_file, new_plg_dir) except Exception as e: Error("Failed to copy status file.") isupgradeSuccess = True if getcmd == 'updateCommand': if self.launchCommand(p.plugin_log, name, version, getcmd, previous_version) == None: Error('Update failed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Update failed for ' + name + '-' + version) isupgradeSuccess = False else: Log('Update complete' + name + '-' + version) SimpleLog(p.plugin_log, 'Update complete' + name + '-' + version) # if we updated - call unistall for the old plugin if self.launchCommand(p.plugin_log, name, previous_version, 'uninstallCommand') == None: self.SetHandlerState(previous_handler, 'Installed') Error('Uninstall failed for ' + name + '-' + previous_version) SimpleLog(p.plugin_log, 'Uninstall failed for ' + name + '-' + previous_version) isupgradeSuccess = False else: self.SetHandlerState(previous_handler, 'NotInstalled') Log('Uninstall complete' + previous_handler) SimpleLog(p.plugin_log, 'Uninstall complete' + name + '-' + previous_version) try: # rm old plugin dir if os.path.isdir(plg_dir): shutil.rmtree(plg_dir) Log(name + '-' + previous_version + ' extension files deleted.') SimpleLog(p.plugin_log, name + '-' + previous_version + ' extension files deleted.') except Exception as e: Error("Failed to remove old plugin directory") AddExtensionEvent(name, 
WALAEventOperation.Upgrade, isupgradeSuccess, 0, previous_version) else: # run install if self.launchCommand(p.plugin_log, name, version, getcmd) == None: self.SetHandlerState(handler, 'NotInstalled') Error('Installation failed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Installation failed for ' + name + '-' + version) else: self.SetHandlerState(handler, 'Installed') Log('Installation completed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Installation completed for ' + name + '-' + version) # end if plg_dir == none or version > = prev # change incarnation of settings file so it knows how to name status... zip_dir = LibDir + "/" + name + '-' + version mfile = None for root, dirs, files in os.walk(zip_dir): for f in files: if f in ('HandlerManifest.json'): mfile = os.path.join(root, f) if mfile != None: break if mfile == None: Error('HandlerManifest.json not found.') SimpleLog(p.plugin_log, 'HandlerManifest.json not found.') continue manifest = GetFileContents(mfile) p.setAttribute('manifestdata', manifest) config = '' seqNo = '0' if len(dom.getElementsByTagName("PluginSettings")) != 0: try: pslist = dom.getElementsByTagName("PluginSettings")[0].getElementsByTagName("Plugin") except: Error('Error parsing ExtensionsConfig.') SimpleLog(p.plugin_log, 'Error parsing ExtensionsConfig.') continue for ps in pslist: if name == ps.getAttribute("name") and version == ps.getAttribute("version"): Log("Found RuntimeSettings for " + name + " V " + version) SimpleLog(p.plugin_log, "Found RuntimeSettings for " + name + " V " + version) config = GetNodeTextData(ps.getElementsByTagName("RuntimeSettings")[0]) seqNo = ps.getElementsByTagName("RuntimeSettings")[0].getAttribute("seqNo") break if config == '': Error("No RuntimeSettings for " + name + " V " + version) SimpleLog(p.plugin_log, "No RuntimeSettings for " + name + " V " + version) SetFileContents(root + "/config/" + seqNo + ".settings", config) # state is still enable if (self.GetHandlerState(handler) == 
'NotInstalled'): # run install first if true if self.launchCommand(p.plugin_log, name, version, 'installCommand') == None: self.SetHandlerState(handler, 'NotInstalled') Error('Installation failed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Installation failed for ' + name + '-' + version) else: self.SetHandlerState(handler, 'Installed') Log('Installation completed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Installation completed for ' + name + '-' + version) if (self.GetHandlerState(handler) != 'NotInstalled'): if self.launchCommand(p.plugin_log, name, version, 'enableCommand') == None: self.SetHandlerState(handler, 'Installed') Error('Enable failed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Enable failed for ' + name + '-' + version) else: self.SetHandlerState(handler, 'Enabled') Log('Enable completed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Enable completed for ' + name + '-' + version) # this plugin processing is complete Log('Processing completed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Processing completed for ' + name + '-' + version) # end plugin processing loop Log('Finished processing ExtensionsConfig.xml') try: SimpleLog(p.plugin_log, 'Finished processing ExtensionsConfig.xml') except: pass return self def launchCommand(self, plugin_log, name, version, command, prev_version=None): commandToEventOperation = { "installCommand": WALAEventOperation.Install, "uninstallCommand": WALAEventOperation.UnIsntall, "updateCommand": WALAEventOperation.Upgrade, "enableCommand": WALAEventOperation.Enable, "disableCommand": WALAEventOperation.Disable, } isSuccess = True start = datetime.datetime.now() r = self.__launchCommandWithoutEventLog(plugin_log, name, version, command, prev_version) if r == None: isSuccess = False Duration = int((datetime.datetime.now() - start).seconds) if commandToEventOperation.get(command): AddExtensionEvent(name, commandToEventOperation[command], isSuccess, Duration, version) 
return r def __launchCommandWithoutEventLog(self, plugin_log, name, version, command, prev_version=None): # get the manifest and read the command mfile = None zip_dir = LibDir + "/" + name + '-' + version for root, dirs, files in os.walk(zip_dir): for f in files: if f in ('HandlerManifest.json'): mfile = os.path.join(root, f) if mfile != None: break if mfile == None: Error('HandlerManifest.json not found.') SimpleLog(plugin_log, 'HandlerManifest.json not found.') return None manifest = GetFileContents(mfile) try: jsn = json.loads(manifest) except: Error('Error parsing HandlerManifest.json.') SimpleLog(plugin_log, 'Error parsing HandlerManifest.json.') return None if type(jsn) == list: jsn = jsn[0] if jsn.has_key('handlerManifest'): cmd = jsn['handlerManifest'][command] else: Error('Key handlerManifest not found. Handler cannot be installed.') SimpleLog(plugin_log, 'Key handlerManifest not found. Handler cannot be installed.') if len(cmd) == 0: Error('Unable to read ' + command) SimpleLog(plugin_log, 'Unable to read ' + command) return None # for update we send the path of the old installation arg = '' if prev_version != None: arg = ' ' + LibDir + '/' + name + '-' + prev_version dirpath = os.path.dirname(mfile) LogIfVerbose('Command is ' + dirpath + '/' + cmd) # launch pid = None try: child = subprocess.Popen(dirpath + '/' + cmd + arg, shell=True, cwd=dirpath, stdout=subprocess.PIPE) except Exception as e: Error('Exception launching ' + cmd + str(e)) SimpleLog(plugin_log, 'Exception launching ' + cmd + str(e)) pid = child.pid if pid == None or pid < 1: ExtensionChildren.append((-1, root)) Error('Error launching ' + cmd + '.') SimpleLog(plugin_log, 'Error launching ' + cmd + '.') else: ExtensionChildren.append((pid, root)) Log("Spawned " + cmd + " PID " + str(pid)) SimpleLog(plugin_log, "Spawned " + cmd + " PID " + str(pid)) # wait until install/upgrade is finished timeout = 300 # 5 minutes retry = timeout / 5 while retry > 0 and child.poll() == None: 
LogIfVerbose(cmd + ' still running with PID ' + str(pid)) time.sleep(5) retry -= 1 if retry == 0: Error('Process exceeded timeout of ' + str(timeout) + ' seconds. Terminating process ' + str(pid)) SimpleLog(plugin_log, 'Process exceeded timeout of ' + str(timeout) + ' seconds. Terminating process ' + str(pid)) os.kill(pid, 9) return None code = child.wait() if code == None or code != 0: Error('Process ' + str(pid) + ' returned non-zero exit code (' + str(code) + ')') SimpleLog(plugin_log, 'Process ' + str(pid) + ' returned non-zero exit code (' + str(code) + ')') return None Log(command + ' completed.') SimpleLog(plugin_log, command + ' completed.') return 0 def ReportHandlerStatus(self): """ Collect all status reports. """ # { "version": "1.0", "timestampUTC": "2014-03-31T21:28:58Z", # "aggregateStatus": { # "guestAgentStatus": { "version": "2.0.4PRE", "status": "Ready", "formattedMessage": { "lang": "en-US", "message": "GuestAgent is running and accepting new configurations." } }, # "handlerAggregateStatus": [{ # "handlerName": "ExampleHandlerLinux", "handlerVersion": "1.0", "status": "Ready", "runtimeSettingsStatus": { # "sequenceNumber": "2", "settingsStatus": { "timestampUTC": "2014-03-31T23:46:00Z", "status": { "name": "ExampleHandlerLinux", "operation": "Command Execution Finished", "configurationAppliedTime": "2014-03-31T23:46:00Z", "status": "success", "formattedMessage": { "lang": "en-US", "message": "Finished executing command" }, # "substatus": [ # { "name": "StdOut", "status": "success", "formattedMessage": { "lang": "en-US", "message": "Goodbye world!" } }, # { "name": "StdErr", "status": "success", "formattedMessage": { "lang": "en-US", "message": "" } } # ] # } } } } # ] # }} try: incarnation = self.Extensions[0].getAttribute("goalStateIncarnation") except: Error('Error parsing ExtensionsConfig. 
    def ReportHandlerStatus(self):
        """
        Collect per-handler status files, aggregate them into the JSON
        aggregateStatus document and upload it to the StatusUploadBlob URI
        from ExtensionsConfig. Returns the upload result, or -1 when
        ExtensionsConfig cannot be parsed.
        """
        # Expected output shape (abridged):
        # { "version": "1.0", "timestampUTC": "...",
        #   "aggregateStatus": {
        #     "guestAgentStatus": { "version": ..., "status": "Ready", "formattedMessage": {...} },
        #     "handlerAggregateStatus": [ { "handlerName": ..., "handlerVersion": ...,
        #       "status": ..., "runtimeSettingsStatus": {...} }, ... ] } }
        try:
            incarnation = self.Extensions[0].getAttribute("goalStateIncarnation")
        except:
            Error('Error parsing ExtensionsConfig. Unable to send status reports')
            return -1
        status = ''
        statuses = ''
        for p in self.Plugins:
            # skip handlers being removed or restricted from reporting
            if p.getAttribute("state") == 'uninstall' or p.getAttribute("restricted") == 'true':
                continue
            version = p.getAttribute("version")
            name = p.getAttribute("name")
            if p.getAttribute("isJson") != 'true':
                LogIfVerbose("Plugin " + name + " version: " + version + " is not a JSON Extension. Skipping.")
                continue
            # heartbeat reporting is opt-in via the handler manifest
            reportHeartbeat = False
            if len(p.getAttribute("manifestdata")) < 1:
                Error("Failed to get manifestdata.")
            else:
                reportHeartbeat = json.loads(p.getAttribute("manifestdata"))[0]['handlerManifest'][
                    'reportHeartbeat']
            if len(statuses) > 0:
                statuses += ','
            statuses += self.GenerateAggStatus(name, version, reportHeartbeat)
        tstamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        # agent state: reflects provisioning progress in the report header
        if provisioned == False:
            if provisionError == None:
                agent_state = 'Provisioning'
                agent_msg = 'Guest Agent is starting.'
            else:
                agent_state = 'Provisioning Error.'
                agent_msg = provisionError
        else:
            agent_state = 'Ready'
            agent_msg = 'GuestAgent is running and accepting new configurations.'
        # status document is assembled by string concatenation, not json.dumps;
        # agent_msg/GuestAgentVersion are assumed not to contain quotes
        status = '{"version":"1.0","timestampUTC":"' + tstamp + '","aggregateStatus":{"guestAgentStatus":{"version":"' + GuestAgentVersion + '","status":"' + agent_state + '","formattedMessage":{"lang":"en-US","message":"' + agent_msg + '"}},"handlerAggregateStatus":[' + statuses + ']}}'
        try:
            # NOTE(review): this replace looks like it was originally
            # .replace('&amp;', '&') (entity un-escaping) and lost the
            # entity text during extraction -- confirm against the repo.
            uri = GetNodeTextData(self.Extensions[0].getElementsByTagName("StatusUploadBlob")[0]).replace('&', '&')
        except:
            Error('Error parsing ExtensionsConfig. Unable to send status reports')
            return -1
        LogIfVerbose('Status report ' + status + ' sent to ' + uri)
        return UploadStatusBlob(uri, status.encode("utf-8"))
Unable to send status reports') return -1 LogIfVerbose('Status report ' + status + ' sent to ' + uri) return UploadStatusBlob(uri, status.encode("utf-8")) def GetCurrentSequenceNumber(self, plugin_base_dir): """ Get the settings file with biggest file number in config folder """ config_dir = os.path.join(plugin_base_dir, 'config') seq_no = 0 for subdir, dirs, files in os.walk(config_dir): for file in files: try: cur_seq_no = int(os.path.basename(file).split('.')[0]) if cur_seq_no > seq_no: seq_no = cur_seq_no except ValueError: continue return str(seq_no) def GenerateAggStatus(self, name, version, reportHeartbeat=False): """ Generate the status which Azure can understand by the status and heartbeat reported by extension """ plugin_base_dir = LibDir + '/' + name + '-' + version + '/' current_seq_no = self.GetCurrentSequenceNumber(plugin_base_dir) status_file = os.path.join(plugin_base_dir, 'status/', current_seq_no + '.status') heartbeat_file = os.path.join(plugin_base_dir, 'heartbeat.log') handler_state_file = os.path.join(plugin_base_dir, 'config', 'HandlerState') agg_state = 'NotReady' handler_state = None status_obj = None status_code = None formatted_message = None localized_message = None if os.path.exists(handler_state_file): handler_state = GetFileContents(handler_state_file).lower() if HandlerStatusToAggStatus.has_key(handler_state): agg_state = HandlerStatusToAggStatus[handler_state] if reportHeartbeat: if os.path.exists(heartbeat_file): d = int(time.time() - os.stat(heartbeat_file).st_mtime) if d > 600: # not updated for more than 10 min agg_state = 'Unresponsive' else: try: heartbeat = json.loads(GetFileContents(heartbeat_file))[0]["heartbeat"] agg_state = heartbeat.get("status") status_code = heartbeat.get("code") formatted_message = heartbeat.get("formattedMessage") localized_message = heartbeat.get("message") except: Error("Incorrect heartbeat file. Ignore it. 
") else: agg_state = 'Unresponsive' # get status file reported by extension if os.path.exists(status_file): # raw status generated by extension is an array, get the first item and remove the unnecessary element try: status_obj = json.loads(GetFileContents(status_file))[0] del status_obj["version"] except: Error("Incorrect status file. Will NOT settingsStatus in settings. ") agg_status_obj = {"handlerName": name, "handlerVersion": version, "status": agg_state, "runtimeSettingsStatus": {"sequenceNumber": current_seq_no}} if status_obj: agg_status_obj["runtimeSettingsStatus"]["settingsStatus"] = status_obj if status_code != None: agg_status_obj["code"] = status_code if formatted_message: agg_status_obj["formattedMessage"] = formatted_message if localized_message: agg_status_obj["message"] = localized_message agg_status_string = json.dumps(agg_status_obj) LogIfVerbose("Handler Aggregated Status:" + agg_status_string) return agg_status_string def SetHandlerState(self, handler, state=''): zip_dir = LibDir + "/" + handler mfile = None for root, dirs, files in os.walk(zip_dir): for f in files: if f in ('HandlerManifest.json'): mfile = os.path.join(root, f) if mfile != None: break if mfile == None: Error('SetHandlerState(): HandlerManifest.json not found, cannot set HandlerState.') return None Log("SetHandlerState: " + handler + ", " + state) return SetFileContents(os.path.dirname(mfile) + '/config/HandlerState', state) def GetHandlerState(self, handler): handlerState = GetFileContents(handler + '/config/HandlerState') if (handlerState): return handlerState.rstrip('\r\n') else: return 'NotInstalled' class HostingEnvironmentConfig(object): """ Parse Hosting enviromnet config and store in HostingEnvironmentConfig.xml """ # # # # # # # # # # # # # # # # # # # # # # # # # # def __init__(self): self.reinitialize() def reinitialize(self): """ Reset Members. 
""" self.StoredCertificates = None self.Deployment = None self.Incarnation = None self.Role = None self.HostingEnvironmentSettings = None self.ApplicationSettings = None self.Certificates = None self.ResourceReferences = None def Parse(self, xmlText): """ Parse and create HostingEnvironmentConfig.xml. """ self.reinitialize() SetFileContents("HostingEnvironmentConfig.xml", xmlText) dom = xml.dom.minidom.parseString(xmlText) for a in ["HostingEnvironmentConfig", "Deployment", "Service", "ServiceInstance", "Incarnation", "Role", ]: if not dom.getElementsByTagName(a): Error("HostingEnvironmentConfig.Parse: Missing " + a) return None node = dom.childNodes[0] if node.localName != "HostingEnvironmentConfig": Error("HostingEnvironmentConfig.Parse: root not HostingEnvironmentConfig") return None self.ApplicationSettings = dom.getElementsByTagName("Setting") self.Certificates = dom.getElementsByTagName("StoredCertificate") return self def DecryptPassword(self, e): """ Return decrypted password. """ SetFileContents("password.p7m", "MIME-Version: 1.0\n" + "Content-Disposition: attachment; filename=\"password.p7m\"\n" + "Content-Type: application/x-pkcs7-mime; name=\"password.p7m\"\n" + "Content-Transfer-Encoding: base64\n\n" + textwrap.fill(e, 64)) return RunGetOutput(Openssl + " cms -decrypt -in password.p7m -inkey Certificates.pem -recip Certificates.pem")[ 1] def ActivateResourceDisk(self): return MyDistro.ActivateResourceDisk() def Process(self): """ Execute ActivateResourceDisk in separate thread. Create the user account. Launch ConfigurationConsumer if specified in the config. 
""" no_thread = False if DiskActivated == False: for m in inspect.getmembers(MyDistro): if 'ActivateResourceDiskNoThread' in m: no_thread = True break if no_thread == True: MyDistro.ActivateResourceDiskNoThread() else: diskThread = threading.Thread(target=self.ActivateResourceDisk) diskThread.start() User = None Pass = None Expiration = None Thumbprint = None for b in self.ApplicationSettings: sname = b.getAttribute("name") svalue = b.getAttribute("value") if User != None and Pass != None: if User != "root" and User != "" and Pass != "": CreateAccount(User, Pass, Expiration, Thumbprint) else: Error("Not creating user account: " + User) for c in self.Certificates: csha1 = c.getAttribute("certificateId").split(':')[1].upper() if os.path.isfile(csha1 + ".prv"): Log("Private key with thumbprint: " + csha1 + " was retrieved.") if os.path.isfile(csha1 + ".crt"): Log("Public cert with thumbprint: " + csha1 + " was retrieved.") program = Config.get("Role.ConfigurationConsumer") if program != None: try: Children.append(subprocess.Popen([program, LibDir + "/HostingEnvironmentConfig.xml"])) except OSError as e: ErrorWithPrefix('HostingEnvironmentConfig.Process', 'Exception: ' + str(e) + ' occured launching ' + program) class GoalState(Util): """ Primary container for all configuration except OvfXml. Encapsulates http communication with endpoint server. 
Initializes and populates: self.HostingEnvironmentConfig self.SharedConfig self.ExtensionsConfig self.Certificates """ # # # 2010-12-15 # 1 # # Started # # 16001 # # # # c6d5526c-5ac2-4200-b6e2-56f2b70c5ab2 # # # MachineRole_IN_0 # Started # # http://10.115.153.40:80/machine/c6d5526c-5ac2-4200-b6e2-56f2b70c5ab2/MachineRole%5FIN%5F0?comp=config&type=hostingEnvironmentConfig&incarnation=1 # http://10.115.153.40:80/machine/c6d5526c-5ac2-4200-b6e2-56f2b70c5ab2/MachineRole%5FIN%5F0?comp=config&type=sharedConfig&incarnation=1 # http://10.115.153.40:80/machine/c6d5526c-5ac2-4200-b6e2-56f2b70c5ab2/MachineRole%5FIN%5F0?comp=certificates&incarnation=1 # http://100.67.238.230:80/machine/9c87aa94-3bda-45e3-b2b7-0eb0fca7baff/1552dd64dc254e6884f8d5b8b68aa18f.eg%2Dplug%2Dvm?comp=config&type=extensionsConfig&incarnation=2 # http://100.67.238.230:80/machine/9c87aa94-3bda-45e3-b2b7-0eb0fca7baff/1552dd64dc254e6884f8d5b8b68aa18f.eg%2Dplug%2Dvm?comp=config&type=fullConfig&incarnation=2 # # # # # # # There is only one Role for VM images. # # Of primary interest is: # LBProbePorts -- an http server needs to run here # We also note Container/ContainerID and RoleInstance/InstanceId to form the health report. 
    # And of course, Incarnation
    #
    def __init__(self, Agent):
        # keep a back-reference to the agent for endpoint/cert access
        self.Agent = Agent
        self.Endpoint = Agent.Endpoint
        self.TransportCert = Agent.TransportCert
        self.reinitialize()

    def reinitialize(self):
        """Reset every parsed field to None before a (re)parse."""
        self.Incarnation = None  # integer
        self.ExpectedState = None  # "Started"
        self.HostingEnvironmentConfigUrl = None
        self.HostingEnvironmentConfigXml = None
        self.HostingEnvironmentConfig = None
        self.SharedConfigUrl = None
        self.SharedConfigXml = None
        self.SharedConfig = None
        self.CertificatesUrl = None
        self.CertificatesXml = None
        self.Certificates = None
        self.ExtensionsConfigUrl = None
        self.ExtensionsConfigXml = None
        self.ExtensionsConfig = None
        self.RoleInstanceId = None
        self.ContainerId = None
        self.LoadBalancerProbePort = None  # integer, ?list of integers

    def Parse(self, xmlText):
        """
        Request configuration data from endpoint server.
        Parse and populate contained configuration objects.
        Calls Certificates().Parse()
        Calls SharedConfig().Parse
        Calls ExtensionsConfig().Parse
        Calls HostingEnvironmentConfig().Parse
        Returns self, or None when a required field is absent.
        """
        self.reinitialize()
        LogIfVerbose(xmlText)
        node = xml.dom.minidom.parseString(xmlText).childNodes[0]
        if node.localName != "GoalState":
            Error("GoalState.Parse: root not GoalState")
            return None
        # walk the GoalState DOM: Incarnation, Machine, Container
        for a in node.childNodes:
            if a.nodeType == node.ELEMENT_NODE:
                if a.localName == "Incarnation":
                    self.Incarnation = GetNodeTextData(a)
                elif a.localName == "Machine":
                    for b in a.childNodes:
                        if b.nodeType == node.ELEMENT_NODE:
                            if b.localName == "ExpectedState":
                                self.ExpectedState = GetNodeTextData(b)
                                Log("ExpectedState: " + self.ExpectedState)
                            elif b.localName == "LBProbePorts":
                                for c in b.childNodes:
                                    if c.nodeType == node.ELEMENT_NODE and c.localName == "Port":
                                        self.LoadBalancerProbePort = int(GetNodeTextData(c))
                elif a.localName == "Container":
                    for b in a.childNodes:
                        if b.nodeType == node.ELEMENT_NODE:
                            if b.localName == "ContainerId":
                                self.ContainerId = GetNodeTextData(b)
                                Log("ContainerId: " + self.ContainerId)
                            elif b.localName == "RoleInstanceList":
                                for c in b.childNodes:
                                    if c.localName == "RoleInstance":
                                        for d in c.childNodes:
                                            if d.nodeType == node.ELEMENT_NODE:
                                                if d.localName == "InstanceId":
                                                    self.RoleInstanceId = GetNodeTextData(d)
                                                    Log("RoleInstanceId: " + self.RoleInstanceId)
                                                elif d.localName == "State":
                                                    pass
                                                elif d.localName == "Configuration":
                                                    # each sub-config is fetched from its URL and
                                                    # parsed by its dedicated class
                                                    for e in d.childNodes:
                                                        if e.nodeType == node.ELEMENT_NODE:
                                                            LogIfVerbose(e.localName)
                                                            if e.localName == "HostingEnvironmentConfig":
                                                                self.HostingEnvironmentConfigUrl = GetNodeTextData(e)
                                                                LogIfVerbose(
                                                                    "HostingEnvironmentConfigUrl:" + self.HostingEnvironmentConfigUrl)
                                                                self.HostingEnvironmentConfigXml = self.HttpGetWithHeaders(
                                                                    self.HostingEnvironmentConfigUrl)
                                                                self.HostingEnvironmentConfig = HostingEnvironmentConfig().Parse(
                                                                    self.HostingEnvironmentConfigXml)
                                                            elif e.localName == "SharedConfig":
                                                                self.SharedConfigUrl = GetNodeTextData(e)
                                                                LogIfVerbose("SharedConfigUrl:" + self.SharedConfigUrl)
                                                                self.SharedConfigXml = self.HttpGetWithHeaders(
                                                                    self.SharedConfigUrl)
                                                                self.SharedConfig = SharedConfig().Parse(
                                                                    self.SharedConfigXml)
                                                                self.SharedConfig.Save()
                                                            elif e.localName == "ExtensionsConfig":
                                                                # only the XML is fetched here; parsing
                                                                # happens later in the agent loop
                                                                self.ExtensionsConfigUrl = GetNodeTextData(e)
                                                                LogIfVerbose(
                                                                    "ExtensionsConfigUrl:" + self.ExtensionsConfigUrl)
                                                                self.ExtensionsConfigXml = self.HttpGetWithHeaders(
                                                                    self.ExtensionsConfigUrl)
                                                            elif e.localName == "Certificates":
                                                                # certificates require the transport cert
                                                                # for a secure (encrypted) fetch
                                                                self.CertificatesUrl = GetNodeTextData(e)
                                                                LogIfVerbose("CertificatesUrl:" + self.CertificatesUrl)
                                                                self.CertificatesXml = self.HttpSecureGetWithHeaders(
                                                                    self.CertificatesUrl, self.TransportCert)
                                                                self.Certificates = Certificates().Parse(
                                                                    self.CertificatesXml)
        # mandatory fields: bail out if any is missing
        if self.Incarnation == None:
            Error("GoalState.Parse: Incarnation missing")
            return None
        if self.ExpectedState == None:
            Error("GoalState.Parse: ExpectedState missing")
            return None
        if self.RoleInstanceId == None:
            Error("GoalState.Parse: RoleInstanceId missing")
            return None
        if self.ContainerId == None:
            Error("GoalState.Parse: ContainerId missing")
            return None
        SetFileContents("GoalState." + self.Incarnation + ".xml", xmlText)
        return self
    def Process(self):
        """
        Calls HostingEnvironmentConfig.Process() and SharedConfig.Process().
        """
        LogIfVerbose("Process goalstate")
        self.HostingEnvironmentConfig.Process()
        self.SharedConfig.Process()


class OvfEnv(object):
    """
    Read, and process provisioning info from provisioning file OvfEnv.xml
    (hostname, user/password, custom data, SSH public keys and key pairs).
    """

    def __init__(self):
        self.reinitialize()

    def reinitialize(self):
        """
        Reset members.
        """
        self.WaNs = "http://schemas.microsoft.com/windowsazure"
        self.OvfNs = "http://schemas.dmtf.org/ovf/environment/1"
        self.MajorVersion = 1
        self.MinorVersion = 0
        self.ComputerName = None
        self.AdminPassword = None
        self.UserName = None
        self.UserPassword = None
        self.CustomData = None
        self.DisableSshPasswordAuthentication = True
        self.SshPublicKeys = []
        self.SshKeyPairs = []

    def Parse(self, xmlText, isDeprovision=False):
        """
        Parse xml tree, retreiving user and ssh key information.
        Return self, or None when no compatible ProvisioningSection exists.
        """
        self.reinitialize()
        # redact element contents when echoing the document to the log
        LogIfVerbose(re.sub(".*?<", "*<", xmlText))
        dom = xml.dom.minidom.parseString(xmlText)
        if len(dom.getElementsByTagNameNS(self.OvfNs, "Environment")) != 1:
            Error("Unable to parse OVF XML.")
        section = None
        newer = False
        # accept only sections whose major version matches ours; remember
        # whether anything newer was seen so we can warn about it
        for p in dom.getElementsByTagNameNS(self.WaNs, "ProvisioningSection"):
            for n in p.childNodes:
                if n.localName == "Version":
                    verparts = GetNodeTextData(n).split('.')
                    major = int(verparts[0])
                    minor = int(verparts[1])
                    if major > self.MajorVersion:
                        newer = True
                    if major != self.MajorVersion:
                        break
                    if minor > self.MinorVersion:
                        newer = True
                    section = p
        if newer == True:
            Warn("Newer provisioning configuration detected. Please consider updating waagent.")
        if section == None:
            Error("Could not find ProvisioningSection with major version=" + str(self.MajorVersion))
            return None
        self.ComputerName = GetNodeTextData(section.getElementsByTagNameNS(self.WaNs, "HostName")[0])
        self.UserName = GetNodeTextData(section.getElementsByTagNameNS(self.WaNs, "UserName")[0])
        # deprovisioning only needs hostname and username
        if isDeprovision == True:
            return self
        try:
            self.UserPassword = GetNodeTextData(section.getElementsByTagNameNS(self.WaNs, "UserPassword")[0])
        except:
            pass
        CDSection = None
        try:
            CDSection = section.getElementsByTagNameNS(self.WaNs, "CustomData")
            if len(CDSection) > 0:
                self.CustomData = GetNodeTextData(CDSection[0])
                if len(self.CustomData) > 0:
                    SetFileContents(LibDir + '/CustomData',
                                    bytearray(MyDistro.translateCustomData(self.CustomData)))
                    Log('Wrote ' + LibDir + '/CustomData')
                else:
                    # NOTE(review): this message likely began with the literal
                    # '<CustomData>' tag, stripped during extraction -- confirm.
                    Error(' contains no data!')
        except Exception as e:
            Error(str(e) + ' occured creating ' + LibDir + '/CustomData')
        disableSshPass = section.getElementsByTagNameNS(self.WaNs, "DisableSshPasswordAuthentication")
        if len(disableSshPass) != 0:
            self.DisableSshPasswordAuthentication = (GetNodeTextData(disableSshPass[0]).lower() == "true")
        # collect [fingerprint, path] pairs for public keys and key pairs
        for pkey in section.getElementsByTagNameNS(self.WaNs, "PublicKey"):
            LogIfVerbose(repr(pkey))
            fp = None
            path = None
            for c in pkey.childNodes:
                if c.localName == "Fingerprint":
                    fp = GetNodeTextData(c).upper()
                    LogIfVerbose(fp)
                if c.localName == "Path":
                    path = GetNodeTextData(c)
                    LogIfVerbose(path)
            self.SshPublicKeys += [[fp, path]]
        for keyp in section.getElementsByTagNameNS(self.WaNs, "KeyPair"):
            fp = None
            path = None
            LogIfVerbose(repr(keyp))
            for c in keyp.childNodes:
                if c.localName == "Fingerprint":
                    fp = GetNodeTextData(c).upper()
                    LogIfVerbose(fp)
                if c.localName == "Path":
                    path = GetNodeTextData(c)
                    LogIfVerbose(path)
            self.SshKeyPairs += [[fp, path]]
        return self
""" home = MyDistro.GetHome() # Expand HOME variable if present in path path = os.path.normpath(filepath.replace("$HOME", home)) if (path.startswith("/") == False) or (path.endswith("/") == True): return None dir = path.rsplit('/', 1)[0] if dir != "": CreateDir(dir, "root", 0o700) if path.startswith(os.path.normpath(home + "/" + self.UserName + "/")): ChangeOwner(dir, self.UserName) return path def NumberToBytes(self, i): """ Pack number into bytes. Retun as string. """ result = [] while i: result.append(chr(i & 0xFF)) i >>= 8 result.reverse() return ''.join(result) def BitsToString(self, a): """ Return string representation of bits in a. """ index = 7 s = "" c = 0 for bit in a: c = c | (bit << index) index = index - 1 if index == -1: s = s + struct.pack('>B', c) c = 0 index = 7 return s def OpensslToSsh(self, file): """ Return base-64 encoded key appropriate for ssh. """ from pyasn1.codec.der import decoder as der_decoder try: f = open(file).read().replace('\n', '').split("KEY-----")[1].split('-')[0] k = der_decoder.decode(self.BitsToString(der_decoder.decode(base64.b64decode(f))[0][1]))[0] n = k[0] e = k[1] keydata = "" keydata += struct.pack('>I', len("ssh-rsa")) keydata += "ssh-rsa" keydata += struct.pack('>I', len(self.NumberToBytes(e))) keydata += self.NumberToBytes(e) keydata += struct.pack('>I', len(self.NumberToBytes(n)) + 1) keydata += "\0" keydata += self.NumberToBytes(n) except Exception as e: print("OpensslToSsh: Exception " + str(e)) return None return "ssh-rsa " + base64.b64encode(keydata) + "\n" def Process(self): """ Process all certificate and key info. DisableSshPasswordAuthentication if configured. CreateAccount(user) Wait for WaAgent.EnvMonitor.IsHostnamePublished(). Restart ssh service. 
""" error = None if self.ComputerName == None: return "Error: Hostname missing" error = WaAgent.EnvMonitor.SetHostName(self.ComputerName) if error: return error if self.DisableSshPasswordAuthentication: filepath = "/etc/ssh/sshd_config" # Disable RFC 4252 and RFC 4256 authentication schemes. ReplaceFileContentsAtomic(filepath, "\n".join(filter(lambda a: not (a.startswith("PasswordAuthentication") or a.startswith("ChallengeResponseAuthentication")), GetFileContents(filepath).split( '\n'))) + "\nPasswordAuthentication no\nChallengeResponseAuthentication no\n") Log("Disabled SSH password-based authentication methods.") if self.AdminPassword != None: MyDistro.changePass('root', self.AdminPassword) if self.UserName != None: error = MyDistro.CreateAccount(self.UserName, self.UserPassword, None, None) sel = MyDistro.isSelinuxRunning() if sel: MyDistro.setSelinuxEnforce(0) home = MyDistro.GetHome() for pkey in self.SshPublicKeys: Log("Deploy public key:{0}".format(pkey[0])) if not os.path.isfile(pkey[0] + ".crt"): Error("PublicKey not found: " + pkey[0]) error = "Failed to deploy public key (0x09)." continue path = self.PrepareDir(pkey[1]) if path == None: Error("Invalid path: " + pkey[1] + " for PublicKey: " + pkey[0]) error = "Invalid path for public key (0x03)." continue Run(Openssl + " x509 -in " + pkey[0] + ".crt -noout -pubkey > " + pkey[0] + ".pub") MyDistro.setSelinuxContext(pkey[0] + '.pub', 'unconfined_u:object_r:ssh_home_t:s0') MyDistro.sshDeployPublicKey(pkey[0] + '.pub', path) MyDistro.setSelinuxContext(path, 'unconfined_u:object_r:ssh_home_t:s0') if path.startswith(os.path.normpath(home + "/" + self.UserName + "/")): ChangeOwner(path, self.UserName) for keyp in self.SshKeyPairs: Log("Deploy key pair:{0}".format(keyp[0])) if not os.path.isfile(keyp[0] + ".prv"): Error("KeyPair not found: " + keyp[0]) error = "Failed to deploy key pair (0x0A)." 
class WALAEvent(object):
    """Base telemetry event: serializes its public attributes to the wire
    XML format and spools events to LibDir/events for later upload."""

    def __init__(self):
        self.providerId = ""
        self.eventId = 1
        self.OpcodeName = ""
        self.KeywordName = ""
        self.TaskName = ""
        self.TenantName = ""
        self.RoleName = ""
        self.RoleInstanceName = ""
        self.ContainerId = ""
        self.ExecutionMode = "IAAS"
        self.OSVersion = ""
        self.GAVersion = ""
        self.RAM = 0
        self.Processors = 0

    def ToXml(self):
        """Render this event's attributes as telemetry XML, typed per value.

        NOTE(review): the template literals below appear to have lost their
        XML content (e.g. '<Event id="{0}"/>', '<Param Name=... Value=...
        T=... />') during extraction -- restore from the repository before
        relying on this text.
        """
        strEventid = u''.format(self.eventId)
        strProviderid = u''.format(self.providerId)
        strRecordFormat = u''
        strRecordNoQuoteFormat = u''
        strMtStr = u'mt:wstr'
        strMtUInt64 = u'mt:uint64'
        strMtBool = u'mt:bool'
        strMtFloat = u'mt:float64'
        strEventsData = u""
        for attName in self.__dict__:
            # these are emitted separately, not as Param records
            if attName in ["eventId", "filedCount", "providerId"]:
                continue
            attValue = self.__dict__[attName]
            if type(attValue) is int:
                strEventsData += strRecordFormat.format(attName, attValue, strMtUInt64)
                continue
            if type(attValue) is str:
                attValue = xml.sax.saxutils.quoteattr(attValue)
                strEventsData += strRecordNoQuoteFormat.format(attName, attValue, strMtStr)
                continue
            # Python 2 unicode values, detected without naming the type
            # (so this file still parses on Python 3)
            if str(type(attValue)).count("'unicode'") > 0:
                attValue = xml.sax.saxutils.quoteattr(attValue)
                strEventsData += strRecordNoQuoteFormat.format(attName, attValue, strMtStr)
                continue
            if type(attValue) is bool:
                strEventsData += strRecordFormat.format(attName, attValue, strMtBool)
                continue
            if type(attValue) is float:
                strEventsData += strRecordFormat.format(attName, attValue, strMtFloat)
                continue
            Log("Warning: property " + attName + ":" + str(type(attValue)) + ":type" + str(
                type(attValue)) + "Can't convert to events data:" + ":type not supported")
        return u"{0}{1}{2}".format(strProviderid, strEventid, strEventsData)

    def Save(self):
        """Spool this event to LibDir/events as a .tld file (written via a
        .tmp file + rename for atomicity). Raises when the folder already
        holds more than 1000 files."""
        eventfolder = LibDir + "/events"
        if not os.path.exists(eventfolder):
            os.mkdir(eventfolder)
            os.chmod(eventfolder, 0o700)
        if len(os.listdir(eventfolder)) > 1000:
            raise Exception("WriteToFolder:Too many file under " + eventfolder + " exit")
        # microsecond timestamp gives a unique, sortable file name
        filename = os.path.join(eventfolder, str(int(time.time() * 1000000)))
        with open(filename + ".tmp", 'wb+') as hfile:
            hfile.write(self.ToXml().encode("utf-8"))
        os.rename(filename + ".tmp", filename + ".tld")


class WALAEventOperation:
    # Operation names reported with telemetry events.
    HeartBeat = "HeartBeat"
    Provision = "Provision"
    Install = "Install"
    UnIsntall = "UnInstall"  # attribute name is misspelled but referenced as-is elsewhere
    Disable = "Disable"
    Enable = "Enable"
    Download = "Download"
    Upgrade = "Upgrade"
    Update = "Update"


def AddExtensionEvent(name, op, isSuccess, duration=0, version="1.0", message="", type="", isInternal=False):
    """Build and spool an ExtensionEvent; failures to save are logged, never raised."""
    event = ExtensionEvent()
    event.Name = name
    event.Version = version
    event.IsInternal = isInternal
    event.Operation = op
    event.OperationSuccess = isSuccess
    event.Message = message
    event.Duration = duration
    event.ExtensionType = type
    try:
        event.Save()
    except:
        Error("Error " + traceback.format_exc())


class ExtensionEvent(WALAEvent):
    """Telemetry event describing one extension-handler operation."""

    def __init__(self):
        WALAEvent.__init__(self)
        self.eventId = 1
        # fixed provider GUID for extension telemetry
        self.providerId = "69B669B9-4AF8-4C50-BDC4-6006FA76E975"
        self.Name = ""
        self.Version = ""
        self.IsInternal = False
        self.Operation = ""
        self.OperationSuccess = True
        self.ExtensionType = ""
        self.Message = ""
        self.Duration = 0


class WALAEventMonitor(WALAEvent):
    """Background telemetry pump: collects spooled events from LibDir/events
    and posts them to the wire server via the supplied post method."""

    def __init__(self, postMethod):
        WALAEvent.__init__(self)
        self.post = postMethod
        self.sysInfo = {}
        self.eventdir = LibDir + "/events"
        self.issysteminfoinitilized = False
def StartEventsLoop(self): eventThread = threading.Thread(target=self.EventsLoop) eventThread.setDaemon(True) eventThread.start() def EventsLoop(self): LastReportHeartBeatTime = datetime.datetime.min try: while (True): if (datetime.datetime.now() - LastReportHeartBeatTime) > datetime.timedelta(hours=12): LastReportHeartBeatTime = datetime.datetime.now() AddExtensionEvent(op=WALAEventOperation.HeartBeat, name="WALA", isSuccess=True) self.postNumbersInOneLoop = 0 self.CollectAndSendWALAEvents() time.sleep(60) except: Error("Exception in events loop:" + traceback.format_exc()) def SendEvent(self, providerid, events): dataFormat = u'{1}' \ '' data = dataFormat.format(providerid, events) self.post("/machine/?comp=telemetrydata", data) def CollectAndSendWALAEvents(self): if not os.path.exists(self.eventdir): return # Throtting, can't send more than 3 events in 15 seconds eventSendNumber = 0 eventFiles = os.listdir(self.eventdir) events = {} for file in eventFiles: if not file.endswith(".tld"): continue with open(os.path.join(self.eventdir, file), "rb") as hfile: # if fail to open or delete the file, throw exception xmlStr = hfile.read().decode("utf-8", 'ignore') os.remove(os.path.join(self.eventdir, file)) params = "" eventid = "" providerid = "" # if exception happen during process an event, catch it and continue try: xmlStr = self.AddSystemInfo(xmlStr) for node in xml.dom.minidom.parseString(xmlStr.encode("utf-8")).childNodes[0].childNodes: if node.tagName == "Param": params += node.toxml() if node.tagName == "Event": eventid = node.getAttribute("id") if node.tagName == "Provider": providerid = node.getAttribute("id") except: Error(traceback.format_exc()) continue if len(params) == 0 or len(eventid) == 0 or len(providerid) == 0: Error("Empty filed in params:" + params + " event id:" + eventid + " provider id:" + providerid) continue eventstr = u''.format(eventid, params) if not events.get(providerid): events[providerid] = "" if len(events[providerid]) > 0 and 
len(events.get(providerid) + eventstr) >= 63 * 1024: eventSendNumber += 1 self.SendEvent(providerid, events.get(providerid)) if eventSendNumber % 3 == 0: time.sleep(15) events[providerid] = "" if len(eventstr) >= 63 * 1024: Error("Signle event too large abort " + eventstr[:300]) continue events[providerid] = events.get(providerid) + eventstr for key in events.keys(): if len(events[key]) > 0: eventSendNumber += 1 self.SendEvent(key, events[key]) if eventSendNumber % 3 == 0: time.sleep(15) def AddSystemInfo(self, eventData): if not self.issysteminfoinitilized: self.issysteminfoinitilized = True try: self.sysInfo["OSVersion"] = platform.system() + ":" + "-".join(DistInfo(1)) + ":" + platform.release() self.sysInfo["GAVersion"] = GuestAgentVersion self.sysInfo["RAM"] = MyDistro.getTotalMemory() self.sysInfo["Processors"] = MyDistro.getProcessorCores() sharedConfig = xml.dom.minidom.parse("/var/lib/waagent/SharedConfig.xml").childNodes[0] hostEnvConfig = xml.dom.minidom.parse("/var/lib/waagent/HostingEnvironmentConfig.xml").childNodes[0] gfiles = RunGetOutput("ls -t /var/lib/waagent/GoalState.*.xml")[1] goalStateConfi = xml.dom.minidom.parse(gfiles.split("\n")[0]).childNodes[0] self.sysInfo["TenantName"] = hostEnvConfig.getElementsByTagName("Deployment")[0].getAttribute("name") self.sysInfo["RoleName"] = hostEnvConfig.getElementsByTagName("Role")[0].getAttribute("name") self.sysInfo["RoleInstanceName"] = sharedConfig.getElementsByTagName("Instance")[0].getAttribute("id") self.sysInfo["ContainerId"] = goalStateConfi.getElementsByTagName("ContainerId")[0].childNodes[ 0].nodeValue except: Error(traceback.format_exc()) eventObject = xml.dom.minidom.parseString(eventData.encode("utf-8")).childNodes[0] for node in eventObject.childNodes: if node.tagName == "Param": name = node.getAttribute("Name") if self.sysInfo.get(name): node.setAttribute("Value", xml.sax.saxutils.escape(str(self.sysInfo[name]))) return eventObject.toxml() class Agent(Util): """ Primary object container 
    for the provisioning process.
    """

    def __init__(self):
        self.GoalState = None              # GoalState object, set by UpdateGoalState()
        self.Endpoint = None               # wire server IP discovered via DHCP option 245
        self.LoadBalancerProbeServer = None
        self.HealthReportCounter = 0
        self.TransportCert = ""
        self.EnvMonitor = None
        self.SendData = None               # last DHCP request bytes (replayed by RestoreRoutes)
        self.DhcpResponse = None           # last DHCP response bytes (replayed by RestoreRoutes)

    def CheckVersions(self):
        """
        Query endpoint server for wire protocol version.
        Fail if our desired protocol version is not seen.

        Returns True when the negotiation succeeded (even if only with a
        warning), False when the response root element is not <Versions>.
        """
        # (A sample of the Versions XML returned by the fabric appeared here
        # in the original; it was garbled during extraction and removed.)
        global ProtocolVersion
        protocolVersionSeen = False
        node = xml.dom.minidom.parseString(self.HttpGetWithoutHeaders("/?comp=versions")).childNodes[0]
        if node.localName != "Versions":
            Error("CheckVersions: root not Versions")
            return False
        for a in node.childNodes:
            if a.nodeType == node.ELEMENT_NODE and a.localName == "Supported":
                for b in a.childNodes:
                    if b.nodeType == node.ELEMENT_NODE and b.localName == "Version":
                        v = GetNodeTextData(b)
                        LogIfVerbose("Fabric supported wire protocol version: " + v)
                        if v == ProtocolVersion:
                            protocolVersionSeen = True
            if a.nodeType == node.ELEMENT_NODE and a.localName == "Preferred":
                v = GetNodeTextData(a.getElementsByTagName("Version")[0])
                Log("Fabric preferred wire protocol version: " + v)
        if not protocolVersionSeen:
            # Only warn: the agent continues with its own version regardless.
            Warn("Agent supported wire protocol version: " + ProtocolVersion + " was not advertised by Fabric.")
        else:
            Log("Negotiated wire protocol version: " + ProtocolVersion)
        return True

    def Unpack(self, buffer, offset, range):
        """
        Unpack bytes into python values.

        'range' (which shadows the builtin) is the sequence of byte indices
        to fold in, most-significant byte first.
        """
        result = 0
        for i in range:
            result = (result << 8) | Ord(buffer[offset + i])
        return result

    def UnpackLittleEndian(self, buffer, offset, length):
        """
        Unpack little endian bytes into python values.
        """
        return self.Unpack(buffer, offset, list(range(length - 1, -1, -1)))

    def UnpackBigEndian(self, buffer, offset, length):
        """
        Unpack big endian bytes into python values.
        """
        return self.Unpack(buffer, offset, list(range(0, length)))

    def HexDump3(self, buffer, offset, length):
        """
        Dump range of buffer in formatted hex.
        """
        return ''.join(['%02X' % Ord(char) for char in buffer[offset:offset + length]])

    def HexDump2(self, buffer):
        """
        Dump buffer in formatted hex.
        """
        return self.HexDump3(buffer, 0, len(buffer))

    def BuildDhcpRequest(self):
        """
        Build a raw 244-byte DHCP DISCOVER request.

        Layout (from the original C-style comment):
          UINT8 Opcode;                     /* op: BOOTREQUEST or BOOTREPLY */
          UINT8 HardwareAddressType;        /* htype: ethernet */
          UINT8 HardwareAddressLength;      /* hlen: 6 (48 bit mac address) */
          UINT8 Hops;                       /* hops: 0 */
          UINT8 TransactionID[4];           /* xid: random */
          UINT8 Seconds[2];                 /* secs: 0 */
          UINT8 Flags[2];                   /* flags: 0 or 0x8000 for broadcast */
          UINT8 ClientIpAddress[4];         /* ciaddr: 0 */
          UINT8 YourIpAddress[4];           /* yiaddr: 0 */
          UINT8 ServerIpAddress[4];         /* siaddr: 0 */
          UINT8 RelayAgentIpAddress[4];     /* giaddr: 0 */
          UINT8 ClientHardwareAddress[16];  /* chaddr: 6 byte ethernet MAC address */
          UINT8 ServerName[64];             /* sname: 0 */
          UINT8 BootFileName[128];          /* file: 0 */
          UINT8 MagicCookie[4];             /* 99 130 83 99 (0x63 0x82 0x53 0x63) */
          /* options -- hard code ours */
          UINT8 MessageTypeCode;            /* 53 */
          UINT8 MessageTypeLength;          /* 1 */
          UINT8 MessageType;                /* 1 for DISCOVER */
          UINT8 End;                        /* 255 */
        """
        # tuple of 244 zeros
        # (struct.pack_into would be good here, but requires Python 2.5)
        sendData = [0] * 244
        transactionID = os.urandom(4)
        macAddress = MyDistro.GetMacAddress()
        # Opcode = 1
        # HardwareAddressType = 1 (ethernet/MAC)
        # HardwareAddressLength = 6 (ethernet/MAC/48 bits)
        for a in range(0, 3):
            sendData[a] = [1, 1, 6][a]
        # fill in transaction id (random number to ensure response matches request)
        for a in range(0, 4):
            sendData[4 + a] = Ord(transactionID[a])
        LogIfVerbose("BuildDhcpRequest: transactionId:%s,%04X" % (
            self.HexDump2(transactionID), self.UnpackBigEndian(sendData, 4, 4)))
        # fill in ClientHardwareAddress
        for a in range(0, 6):
            sendData[0x1C + a] = Ord(macAddress[a])
        # DHCP Magic Cookie: 99, 130, 83, 99
        # MessageTypeCode = 53 DHCP Message Type
        # MessageTypeLength = 1
        # MessageType = DHCPDISCOVER
        # End = 255 DHCP_END
        for a in range(0, 8):
            sendData[0xEC + a] = [99, 130, 83, 99, 53, 1, 1, 255][a]
        return array.array("B", sendData)

    def IntegerToIpAddressV4String(self, a):
        """
        Convert a 32-bit integer to its dotted-quad IPv4 string form.
        (The original docstring was a copy-paste of BuildDhcpRequest's.)
        """
        return "%u.%u.%u.%u" % ((a >> 24) & 0xFF, (a >> 16) & 0xFF, (a >> 8) & 0xFF, a & 0xFF)

    def RouteAdd(self, net, mask, gateway):
        """
        Add specified route using /sbin/route add -net.
        Arguments are 32-bit integers; converted to dotted-quad here.
        """
        net = self.IntegerToIpAddressV4String(net)
        mask = self.IntegerToIpAddressV4String(mask)
        gateway = self.IntegerToIpAddressV4String(gateway)
        Log("Route add: net={0}, mask={1}, gateway={2}".format(net, mask, gateway))
        MyDistro.routeAdd(net, mask, gateway)

    def SetDefaultGateway(self, gateway):
        """
        Set default gateway (32-bit integer argument).
        """
        gateway = self.IntegerToIpAddressV4String(gateway)
        Log("Set default gateway: {0}".format(gateway))
        MyDistro.setDefaultGateway(gateway)

    def HandleDhcpResponse(self, sendData, receiveBuffer):
        """
        Parse DHCP response:
        Set default gateway.
        Set default routes.
        Retrieve endpoint server.
        Returns endpoint server or None on error.
""" LogIfVerbose("HandleDhcpResponse") bytesReceived = len(receiveBuffer) if bytesReceived < 0xF6: Error("HandleDhcpResponse: Too few bytes received " + str(bytesReceived)) return None LogIfVerbose("BytesReceived: " + hex(bytesReceived)) LogWithPrefixIfVerbose("DHCP response:", HexDump(receiveBuffer, bytesReceived)) # check transactionId, cookie, MAC address # cookie should never mismatch # transactionId and MAC address may mismatch if we see a response meant from another machine for offsets in [list(range(4, 4 + 4)), list(range(0x1C, 0x1C + 6)), list(range(0xEC, 0xEC + 4))]: for offset in offsets: sentByte = Ord(sendData[offset]) receivedByte = Ord(receiveBuffer[offset]) if sentByte != receivedByte: LogIfVerbose("HandleDhcpResponse: sent cookie:" + self.HexDump3(sendData, 0xEC, 4)) LogIfVerbose("HandleDhcpResponse: rcvd cookie:" + self.HexDump3(receiveBuffer, 0xEC, 4)) LogIfVerbose("HandleDhcpResponse: sent transactionID:" + self.HexDump3(sendData, 4, 4)) LogIfVerbose("HandleDhcpResponse: rcvd transactionID:" + self.HexDump3(receiveBuffer, 4, 4)) LogIfVerbose("HandleDhcpResponse: sent ClientHardwareAddress:" + self.HexDump3(sendData, 0x1C, 6)) LogIfVerbose( "HandleDhcpResponse: rcvd ClientHardwareAddress:" + self.HexDump3(receiveBuffer, 0x1C, 6)) LogIfVerbose("HandleDhcpResponse: transactionId, cookie, or MAC address mismatch") return None endpoint = None # # Walk all the returned options, parsing out what we need, ignoring the others. # We need the custom option 245 to find the the endpoint we talk to, # as well as, to handle some Linux DHCP client incompatibilities, # options 3 for default gateway and 249 for routes. And 255 is end. 
# i = 0xF0 # offset to first option while i < bytesReceived: option = Ord(receiveBuffer[i]) length = 0 if (i + 1) < bytesReceived: length = Ord(receiveBuffer[i + 1]) LogIfVerbose("DHCP option " + hex(option) + " at offset:" + hex(i) + " with length:" + hex(length)) if option == 255: LogIfVerbose("DHCP packet ended at offset " + hex(i)) break elif option == 249: # http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx LogIfVerbose("Routes at offset:" + hex(i) + " with length:" + hex(length)) if length < 5: Error("Data too small for option " + str(option)) j = i + 2 while j < (i + length + 2): maskLengthBits = Ord(receiveBuffer[j]) maskLengthBytes = (((maskLengthBits + 7) & ~7) >> 3) mask = 0xFFFFFFFF & (0xFFFFFFFF << (32 - maskLengthBits)) j += 1 net = self.UnpackBigEndian(receiveBuffer, j, maskLengthBytes) net <<= (32 - maskLengthBytes * 8) net &= mask j += maskLengthBytes gateway = self.UnpackBigEndian(receiveBuffer, j, 4) j += 4 self.RouteAdd(net, mask, gateway) if j != (i + length + 2): Error("HandleDhcpResponse: Unable to parse routes") elif option == 3 or option == 245: if i + 5 < bytesReceived: if length != 4: Error("HandleDhcpResponse: Endpoint or Default Gateway not 4 bytes") return None gateway = self.UnpackBigEndian(receiveBuffer, i + 2, 4) IpAddress = self.IntegerToIpAddressV4String(gateway) if option == 3: self.SetDefaultGateway(gateway) name = "DefaultGateway" else: endpoint = IpAddress name = "Azure wire protocol endpoint" LogIfVerbose(name + ": " + IpAddress + " at " + hex(i)) else: Error("HandleDhcpResponse: Data too small for option " + str(option)) else: LogIfVerbose("Skipping DHCP option " + hex(option) + " at " + hex(i) + " with length " + hex(length)) i += length + 2 return endpoint def DoDhcpWork(self): """ Discover the wire server via DHCP option 245. And workaround incompatibility with Azure DHCP servers. """ ShortSleep = False # Sleep 1 second before retrying DHCP queries. 
ifname = None sleepDurations = [0, 10, 30, 60, 60] maxRetry = len(sleepDurations) lastTry = (maxRetry - 1) for retry in range(0, maxRetry): try: # Open DHCP port if iptables is enabled. Run("iptables -D INPUT -p udp --dport 68 -j ACCEPT", chk_err=False) # We supress error logging on error. Run("iptables -I INPUT -p udp --dport 68 -j ACCEPT", chk_err=False) # We supress error logging on error. strRetry = str(retry) prefix = "DoDhcpWork: try=" + strRetry LogIfVerbose(prefix) sendData = self.BuildDhcpRequest() LogWithPrefixIfVerbose("DHCP request:", HexDump(sendData, len(sendData))) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) missingDefaultRoute = True try: if DistInfo()[0] == 'FreeBSD': missingDefaultRoute = True else: routes = RunGetOutput("route -n")[1] for line in routes.split('\n'): if line.startswith("0.0.0.0 ") or line.startswith("default "): missingDefaultRoute = False except: pass if missingDefaultRoute: # This is required because sending after binding to 0.0.0.0 fails with # network unreachable when the default gateway is not set up. 
ifname = MyDistro.GetInterfaceName() Log("DoDhcpWork: Missing default route - adding broadcast route for DHCP.") if DistInfo()[0] == 'FreeBSD': Run("route add -net 255.255.255.255 -iface " + ifname, chk_err=False) else: Run("route add 255.255.255.255 dev " + ifname, chk_err=False) if MyDistro.isDHCPEnabled(): MyDistro.stopDHCP() sock.bind(("0.0.0.0", 68)) sock.sendto(sendData, ("", 67)) sock.settimeout(10) Log("DoDhcpWork: Setting socket.timeout=10, entering recv") receiveBuffer = sock.recv(1024) endpoint = self.HandleDhcpResponse(sendData, receiveBuffer) if endpoint == None: LogIfVerbose("DoDhcpWork: No endpoint found") if endpoint != None or retry == lastTry: if endpoint != None: self.SendData = sendData self.DhcpResponse = receiveBuffer if retry == lastTry: LogIfVerbose("DoDhcpWork: try=" + strRetry) return endpoint sleepDuration = [sleepDurations[retry % len(sleepDurations)], 1][ShortSleep] LogIfVerbose("DoDhcpWork: sleep=" + str(sleepDuration)) time.sleep(sleepDuration) except Exception as e: ErrorWithPrefix(prefix, str(e)) ErrorWithPrefix(prefix, traceback.format_exc()) finally: sock.close() if missingDefaultRoute: # We added this route - delete it Log("DoDhcpWork: Removing broadcast route for DHCP.") if DistInfo()[0] == 'FreeBSD': Run("route del -net 255.255.255.255 -iface " + ifname, chk_err=False) else: Run("route del 255.255.255.255 dev " + ifname, chk_err=False) # We supress error logging on error. if MyDistro.isDHCPEnabled(): MyDistro.startDHCP() return None def UpdateAndPublishHostName(self, name): """ Set hostname locally and publish to iDNS """ Log("Setting host name: " + name) MyDistro.publishHostname(name) ethernetInterface = MyDistro.GetInterfaceName() MyDistro.RestartInterface(ethernetInterface) self.RestoreRoutes() def RestoreRoutes(self): """ If there is a DHCP response, then call HandleDhcpResponse. 
        """
        if self.SendData != None and self.DhcpResponse != None:
            self.HandleDhcpResponse(self.SendData, self.DhcpResponse)

    def UpdateGoalState(self):
        """
        Retrieve goal state information from endpoint server.
        Parse xml and initialize Agent.GoalState object.
        Return object or None on error.
        """
        goalStateXml = None
        maxRetry = 9
        log = NoLog  # first attempt is silent; retries log
        for retry in range(1, maxRetry + 1):
            strRetry = str(retry)
            log("retry UpdateGoalState,retry=" + strRetry)
            goalStateXml = self.HttpGetWithHeaders("/machine/?comp=goalstate")
            if goalStateXml != None:
                break
            log = Log
            time.sleep(retry)  # linear backoff: sleep 'retry' seconds
        if not goalStateXml:
            Error("UpdateGoalState failed.")
            return
        Log("Retrieved GoalState from Azure Fabric.")
        self.GoalState = GoalState(self).Parse(goalStateXml)
        return self.GoalState

    def ReportReady(self):
        """
        Send health report 'Ready' to server.
        This signals the fabric that our provision is completed,
        and the host is ready for operation.
        NOTE(review): the XML markup of the health-report body was lost in
        extraction; the empty string literals below originally carried the
        element tags (and presumably used 'counter', which now looks
        unused) -- verify against upstream before editing.
        """
        counter = (self.HealthReportCounter + 1) % 1000000
        self.HealthReportCounter = counter
        healthReport = (
                "" + self.GoalState.Incarnation + "" + self.GoalState.ContainerId + "" +
                self.GoalState.RoleInstanceId + "Ready")
        a = self.HttpPostWithHeaders("/machine?comp=health", healthReport)
        if a != None:
            # The fabric echoes back the newest goal-state incarnation.
            return a.getheader("x-ms-latest-goal-state-incarnation-number")
        return None

    def ReportNotReady(self, status, desc):
        """
        Send health report 'Provisioning' to server.
        This signals the fabric that our provision is starting.
        NOTE(review): XML markup lost in extraction (see ReportReady).
        """
        healthReport = (
                "" + self.GoalState.Incarnation + "" + self.GoalState.ContainerId + "" +
                self.GoalState.RoleInstanceId + "NotReady" +
                "" + status + "" + desc + "" + "" + "")
        a = self.HttpPostWithHeaders("/machine?comp=health", healthReport)
        if a != None:
            return a.getheader("x-ms-latest-goal-state-incarnation-number")
        return None

    def ReportRoleProperties(self, thumbprint):
        """
        Send roleProperties and thumbprint to server.
        NOTE(review): 'thumbprint' no longer appears in the payload -- the
        XML markup (which carried it) was lost in extraction.
        """
        roleProperties = (
                "" + "" + self.GoalState.ContainerId + "" + "" + "" +
                self.GoalState.RoleInstanceId + "" + "" + "")
        a = self.HttpPostWithHeaders("/machine?comp=roleProperties", roleProperties)
        Log("Posted Role Properties. CertificateThumbprint=" + thumbprint)
        return a

    def LoadBalancerProbeServer_Shutdown(self):
        """
        Shutdown the LoadBalancerProbeServer.
        """
        if self.LoadBalancerProbeServer != None:
            self.LoadBalancerProbeServer.shutdown()
            self.LoadBalancerProbeServer = None

    def GenerateTransportCert(self):
        """
        Create ssl certificate for https communication with endpoint server.
        Returns the certificate body with the PEM BEGIN/END lines stripped
        and the base64 joined into a single line.
        """
        Run(
            Openssl + " req -x509 -nodes -subj /CN=LinuxTransport -days 32768 -newkey rsa:2048 -keyout TransportPrivate.pem -out TransportCert.pem")
        cert = ""
        for line in GetFileContents("TransportCert.pem").split('\n'):
            if not "CERTIFICATE" in line:
                cert += line.rstrip()
        return cert

    def DoVmmStartup(self):
        """
        Spawn the VMM startup script.
        Does not return: exits the agent process after handing off to VMM.
        """
        Log("Starting Microsoft System Center VMM Initialization Process")
        pid = subprocess.Popen(
            ["/bin/bash", "/mnt/cdrom/secure/" + VMM_STARTUP_SCRIPT_NAME, "-p /mnt/cdrom/secure/ "]).pid
        time.sleep(5)
        sys.exit(0)

    def TryUnloadAtapiix(self):
        """
        If global modloaded is True, then we loaded the ata_piix kernel
        module, unload it.
        """
        if modloaded:
            Run("rmmod ata_piix.ko", chk_err=False)
            Log("Unloaded ata_piix.ko driver for ATAPI CD-ROM")

    def TryLoadAtapiix(self):
        """
        Load the ata_piix kernel module if it exists.
        If successful, set global modloaded to True.
        If unable to load module leave modloaded False.
        """
        # NOTE(review): return values mix int status codes and an error
        # string; callers appear to treat any truthy value as failure.
        global modloaded
        modloaded = False
        retcode, krn = RunGetOutput('uname -r')
        krn_pth = '/lib/modules/' + krn.strip('\n') + '/kernel/drivers/ata/ata_piix.ko'
        if Run("lsmod | grep ata_piix", chk_err=False) == 0:
            # Module already loaded: nothing to do (modloaded stays False so
            # TryUnloadAtapiix will not remove a module we did not load).
            Log("Module " + krn_pth + " driver for ATAPI CD-ROM is already present.")
            return 0
        if retcode:
            Error("Unable to provision: Failed to call uname -r")
            return "Unable to provision: Failed to call uname"
        if os.path.isfile(krn_pth):
            retcode, output = RunGetOutput("insmod " + krn_pth, chk_err=False)
        else:
            Log("Module " + krn_pth + " driver for ATAPI CD-ROM does not exist.")
            return 1
        if retcode != 0:
            Error('Error calling insmod for ' + krn_pth + ' driver for ATAPI CD-ROM')
            return retcode
        time.sleep(1)
        # check 3 times if the mod is loaded
        for i in range(3):
            if Run('lsmod | grep ata_piix'):
                continue
            else:
                modloaded = True
                break
        if not modloaded:
            Error('Unable to load ' + krn_pth + ' driver for ATAPI CD-ROM')
            return 1
        Log("Loaded " + krn_pth + " driver for ATAPI CD-ROM")
        # we have succeeded loading the ata_piix mod if it can be done.

    def SearchForVMMStartup(self):
        """
        Search for a DVD/CDROM containing VMM's VMM_CONFIG_FILE_NAME.
        Call TryLoadAtapiix in case we must load the ata_piix module first.
        If VMM_CONFIG_FILE_NAME is found, call DoVmmStartup.
        Else, return to Azure Provisioning process.
        """
        self.TryLoadAtapiix()
        if os.path.exists('/mnt/cdrom/secure') == False:
            CreateDir("/mnt/cdrom/secure", "root", 0o700)
        mounted = False
        # Candidate optical devices: sr*, hdc..hdz, cdrom*, cd*.
        for dvds in [re.match(r'(sr[0-9]|hd[c-z]|cdrom[0-9]|cd[0-9]?)', x) for x in os.listdir('/dev/')]:
            if dvds == None:
                continue
            dvd = '/dev/' + dvds.group(0)
            if Run("LC_ALL=C fdisk -l " + dvd + " | grep Disk", chk_err=False):
                continue  # Not mountable
            else:
                # Mount loop: up to 5 tries, 5 second wait between attempts.
                for retry in range(1, 6):
                    retcode, output = RunGetOutput("mount -v " + dvd + " /mnt/cdrom/secure")
                    Log(output[:-1])
                    if retcode == 0:
                        Log("mount succeeded on attempt #" + str(retry))
                        mounted = True
                        break
                    if 'is already mounted on /mnt/cdrom/secure' in output:
                        Log("Device " + dvd + " is already mounted on /mnt/cdrom/secure." + str(retry))
                        mounted = True
                        break
                    Log("mount failed on attempt #" + str(retry))
                    Log("mount loop sleeping 5...")
                    time.sleep(5)
                if not mounted:
                    # unable to mount
                    continue
                if not os.path.isfile("/mnt/cdrom/secure/" + VMM_CONFIG_FILE_NAME):
                    # nope - mount the next drive
                    if mounted:
                        Run("umount " + dvd, chk_err=False)
                        mounted = False
                    continue
                else:
                    # it is the vmm startup
                    self.DoVmmStartup()
        Log("VMM Init script not found. Provisioning for Azure")
        return

    def Provision(self):
        """
        Responsible for:
        Regenerate ssh keys,
        Mount, read, and parse ovfenv.xml from provisioning dvd rom
        Process the ovfenv.xml info
        Call ReportRoleProperties
        If configured, delete root password.
        Return None on success, error string on error.
        """
        enabled = Config.get("Provisioning.Enabled")
        if enabled != None and enabled.lower().startswith("n"):
            # Provisioning explicitly disabled in waagent.conf.
            return
        Log("Provisioning image started.")
        type = Config.get("Provisioning.SshHostKeyPairType")
        if type == None:
            type = "rsa"
        regenerateKeys = Config.get("Provisioning.RegenerateSshHostKeyPair")
        if regenerateKeys == None or regenerateKeys.lower().startswith("y"):
            # Wipe image-baked host keys and generate fresh ones.
            Run("rm -f /etc/ssh/ssh_host_*key*")
            Run("ssh-keygen -N '' -t " + type + " -f /etc/ssh/ssh_host_" + type + "_key")
            MyDistro.restartSshService()
        # SetFileContents(LibDir + "/provisioned", "")
        dvd = None
        # Last matching optical device wins.
        for dvds in [re.match(r'(sr[0-9]|hd[c-z]|cdrom[0-9]|cd[0-9]?)', x) for x in os.listdir('/dev/')]:
            if dvds == None:
                continue
            dvd = '/dev/' + dvds.group(0)
        if dvd == None:
            # No DVD device detected
            Error("No DVD device detected, unable to provision.")
            return "No DVD device detected, unable to provision."
        if MyDistro.mediaHasFilesystem(dvd) is False:
            out = MyDistro.load_ata_piix()
            if out:
                return out
        for i in range(10):
            # we may have to wait for the device node to appear
            if os.path.exists(dvd):
                break
            Log("Waiting for DVD - sleeping 1 - " + str(i + 1) + " try...")
            time.sleep(1)
        if os.path.exists('/mnt/cdrom/secure') == False:
            CreateDir("/mnt/cdrom/secure", "root", 0o700)
        # begin mount loop - 5 tries - 5 sec wait between
        for retry in range(1, 6):
            location = '/mnt/cdrom/secure'
            retcode, output = MyDistro.mountDVD(dvd, location)
            Log(output[:-1])
            if retcode == 0:
                Log("mount succeeded on attempt #" + str(retry))
                break
            if 'is already mounted on /mnt/cdrom/secure' in output:
                Log("Device " + dvd + " is already mounted on /mnt/cdrom/secure." + str(retry))
                break
            Log("mount failed on attempt #" + str(retry))
            Log("mount loop sleeping 5...")
            time.sleep(5)
        if not os.path.isfile("/mnt/cdrom/secure/ovf-env.xml"):
            Error("Unable to provision: Missing ovf-env.xml on DVD.")
            return "Failed to retrieve provisioning data (0x02)."
        # use unicode here to ensure correct codec gets used.
        ovfxml = (GetFileContents(u"/mnt/cdrom/secure/ovf-env.xml", asbin=False))
        if ord(ovfxml[0]) > 128 and ord(ovfxml[1]) > 128 and ord(ovfxml[2]) > 128:
            # BOM is not stripped. First three bytes are > 128 and not unicode chars so we ignore them.
            ovfxml = ovfxml[3:]
        ovfxml = ovfxml.strip(chr(0x00))  # we may have NULLs.
        # NOTE(review): the following line was garbled during extraction (the
        # original slice / re.sub markup containing '<' was stripped).  It is
        # preserved as found -- reconcile against upstream before relying on it.
        ovfxml = ovfxml[ovfxml.find('.*?<", "*<", ovfxml))
        Run("umount " + dvd, chk_err=False)
        MyDistro.unload_ata_piix()
        error = None
        if ovfxml != None:
            Log("Provisioning image using OVF settings in the DVD.")
            ovfobj = OvfEnv().Parse(ovfxml)
            if ovfobj != None:
                error = ovfobj.Process()
                if error:
                    Error("Provisioning image FAILED " + error)
                    return ("Provisioning image FAILED " + error)
                Log("Ovf XML process finished")
        # This is done here because regenerated SSH host key pairs may be potentially overwritten when processing the ovfxml
        fingerprint = RunGetOutput("ssh-keygen -lf /etc/ssh/ssh_host_" + type + "_key.pub")[1].rstrip().split()[
            1].replace(':', '')
        self.ReportRoleProperties(fingerprint)
        delRootPass = Config.get("Provisioning.DeleteRootPassword")
        if delRootPass != None and delRootPass.lower().startswith("y"):
            MyDistro.deleteRootPassword()
        Log("Provisioning image completed.")
        return error

    def Run(self):
        """
        Called by 'waagent -daemon.'  Main loop to process the goal state.
        State is posted every 25 seconds when provisioning has been completed.

        Search for VMM environment, start VMM script if found.
        Perform DHCP and endpoint server discovery by calling DoDhcpWork().
        Check wire protocol versions.
        Set SCSI timeout on root device.
        Call GenerateTransportCert() to create ssl certs for server communication.
        Call UpdateGoalState().
        If not provisioned, call ReportNotReady("Provisioning", "Starting")
        Call Provision(), set global provisioned = True if successful.
        Call goalState.Process()
        Start LBProbeServer if indicated in waagent.conf.
        Start the StateConsumer if indicated in waagent.conf.
        ReportReady if provisioning is complete.
        If provisioning failed, call ReportNotReady("ProvisioningFailed", provisionError)
        """
        SetFileContents("/var/run/waagent.pid", str(os.getpid()) + "\n")
        reportHandlerStatusCount = 0
        # Determine if we are in VMM. Spawn VMM_STARTUP_SCRIPT_NAME if found.
        self.SearchForVMMStartup()
        ipv4 = ''
        while ipv4 == '' or ipv4 == '0.0.0.0':
            ipv4 = MyDistro.GetIpv4Address()
            if ipv4 == '' or ipv4 == '0.0.0.0':
                Log("Waiting for network.")
                time.sleep(10)
        Log("IPv4 address: " + ipv4)
        mac = ''
        mac = MyDistro.GetMacAddress()
        if len(mac) > 0:
            Log("MAC address: " + ":".join(["%02X" % Ord(a) for a in mac]))
        # Consume Entropy in ACPI table provided by Hyper-V
        try:
            SetFileContents("/dev/random", GetFileContents("/sys/firmware/acpi/tables/OEM0"))
        except:
            pass
        Log("Probing for Azure environment.")
        self.Endpoint = self.DoDhcpWork()
        while self.Endpoint == None:
            Log("Retry environment detection in 60 seconds")
            time.sleep(60)
            self.Endpoint = self.DoDhcpWork()
        Log("Discovered Azure endpoint: " + self.Endpoint)
        if not self.CheckVersions():
            Error("Agent.CheckVersions failed")
            sys.exit(1)
        self.EnvMonitor = EnvMonitor()
        # Set SCSI timeout on SCSI disks
        MyDistro.initScsiDiskTimeout()
        global provisioned
        global provisionError
        global Openssl
        Openssl = Config.get("OS.OpensslPath")
        if Openssl == None:
            Openssl = "openssl"
        self.TransportCert = self.GenerateTransportCert()
        eventMonitor = None
        incarnation = None  # goalStateIncarnationFromHealthReport
        currentPort = None  # loadBalancerProbePort
        goalState = None  # self.GoalState, instance of GoalState
        provisioned = os.path.exists(LibDir + "/provisioned")
        program = Config.get("Role.StateConsumer")
        provisionError = None
        lbProbeResponder = True
        lbProbeResponderNo = Config.no("LBProbeResponder")
        if lbProbeResponderNo:
            lbProbeResponder = False
        # Optional RDMA driver update/check; failures are logged, not fatal.
        try:
            updateRdmaDriverConfigured = Config.yes("OS.UpdateRdmaDriver")
            updateRdmaRepository = Config.get("OS.RdmaRepository")
            if (updateRdmaDriverConfigured):
                MyDistro.rdmaUpdate(updateRdmaRepository)
            else:
                Log("OS.UpdateRdmaDriver configured to " + str(
                    updateRdmaDriverConfigured) + " so skip the rdma update.")
            checkRdmaDriverConfigured = Config.yes("OS.CheckRdmaDriver")
            if (checkRdmaDriverConfigured):
                checkRdmaResult = MyDistro.checkRDMA()
                Log("Rdma check result is " + str(checkRdmaResult))
            else:
                Log("OS.CheckRdmaDriver configured to " + str(
                    checkRdmaDriverConfigured) + " so skip the rdma check.")
        except Exception as e:
            errMsg = 'check or update Rdma driver failed with error: %s, stack trace: %s' % (
                str(e), traceback.format_exc())
            Error(errMsg)
        while True:
            # Re-fetch and process the goal state whenever the incarnation
            # changes (or we have none yet).
            if (goalState == None) or (incarnation == None) or (goalState.Incarnation != incarnation):
                try:
                    goalState = self.UpdateGoalState()
                except HttpResourceGoneError as e:
                    Warn("Incarnation is out of date:{0}".format(e))
                    incarnation = None
                    continue
                if goalState == None:
                    Warn("Failed to fetch goalstate")
                    continue
                if provisioned == False:
                    self.ReportNotReady("Provisioning", "Starting")
                goalState.Process()
                if provisioned == False:
                    provisionError = self.Provision()
                    if provisionError == None:
                        provisioned = True
                        SetFileContents(LibDir + "/provisioned", "")
                        lastCtime = "NOTFIND"
                        try:
                            walaConfigFile = MyDistro.getConfigurationPath()
                            lastCtime = time.ctime(os.path.getctime(walaConfigFile))
                        except:
                            pass
                        # Get Ctime of wala config, can help identify the base image of this VM
                        AddExtensionEvent(name="WALA", op=WALAEventOperation.Provision, isSuccess=True,
                                          message="WALA Config Ctime:" + lastCtime)
                executeCustomData = Config.get("Provisioning.ExecuteCustomData")
                if executeCustomData != None and executeCustomData.lower().startswith("y"):
                    if os.path.exists(LibDir + '/CustomData'):
                        Run('chmod +x ' + LibDir + '/CustomData')
                        Run(LibDir + '/CustomData')
                    else:
                        Error(LibDir + '/CustomData does not exist.')
                #
                # only one port supported
                # restart server if new port is different than old port
                # stop server if no longer a port
                #
                goalPort = goalState.LoadBalancerProbePort
                if currentPort != goalPort:
                    try:
                        self.LoadBalancerProbeServer_Shutdown()
                        currentPort = goalPort
                        if currentPort != None and lbProbeResponder == True:
                            self.LoadBalancerProbeServer = LoadBalancerProbeServer(currentPort)
                            if self.LoadBalancerProbeServer == None:
                                lbProbeResponder = False
                                Log("Unable to create LBProbeResponder.")
                    except Exception as e:
                        Error("Failed to launch LBProbeResponder: {0}".format(e))
                        currentPort = None
                # Report SSH key fingerprint
                type = Config.get("Provisioning.SshHostKeyPairType")
                if type == None:
                    type = "rsa"
                host_key_path = "/etc/ssh/ssh_host_" + type + "_key.pub"
                if (MyDistro.waitForSshHostKey(host_key_path)):
                    fingerprint = \
                        RunGetOutput("ssh-keygen -lf /etc/ssh/ssh_host_" + type + "_key.pub")[1].rstrip().split()[
                            1].replace(':', '')
                    self.ReportRoleProperties(fingerprint)
                # Launch the StateConsumer program once, after disk activation.
                if program != None and DiskActivated == True:
                    try:
                        Children.append(subprocess.Popen([program, "Ready"]))
                    except OSError as e:
                        ErrorWithPrefix('SharedConfig.Parse', 'Exception: ' + str(e) + ' occured launching ' + program)
                    program = None
            sleepToReduceAccessDenied = 3
            time.sleep(sleepToReduceAccessDenied)
            if provisionError != None:
                incarnation = self.ReportNotReady("ProvisioningFailed", provisionError)
            else:
                incarnation = self.ReportReady()
            # Process our extensions.
            if goalState.ExtensionsConfig == None and goalState.ExtensionsConfigXml != None:
                reportHandlerStatusCount = 0  # Reset count when new goal state comes
                goalState.ExtensionsConfig = ExtensionsConfig().Parse(goalState.ExtensionsConfigXml)
            # report the status/heartbeat results of extension processing
            if goalState.ExtensionsConfig != None:
                ret = goalState.ExtensionsConfig.ReportHandlerStatus()
                if ret != 0:
                    Error("Failed to report handler status")
                elif reportHandlerStatusCount % 1000 == 0:
                    # Agent report handler status every 25 seconds. Reduce the log entries by adding a count
                    Log("Successfully reported handler status")
                reportHandlerStatusCount += 1
            global LinuxDistro
            if LinuxDistro == "redhat":
                DoInstallRHUIRPM()
            if not eventMonitor:
                eventMonitor = WALAEventMonitor(self.HttpPostWithHeaders)
                eventMonitor.StartEventsLoop()
            time.sleep(25 - sleepToReduceAccessDenied)


# logrotate(8) stanza installed to /etc/logrotate.d/waagent by Install().
WaagentLogrotate = """\
/var/log/waagent.log {
    monthly
    rotate 6
    notifempty
    missingok
}
"""


def GetMountPoint(mountlist, device):
    """
    Return the mount point (third whitespace-separated token) of the line
    in 'mountlist' that matches 'device', or None.

    Example of mountlist:
        /dev/sda1 on / type ext4 (rw)
        proc on /proc type proc (rw)
        sysfs on /sys type sysfs (rw)
        devpts on /dev/pts type devpts (rw,gid=5,mode=620)
        tmpfs on /dev/shm type tmpfs (rw,rootcontext="system_u:object_r:tmpfs_t:s0")
        none on /proc/sys/fs/binfmt_misc type binfmt_misc (rw)
        /dev/sdb1 on /mnt/resource type ext4 (rw)
    """
    if (mountlist and device):
        for entry in mountlist.split('\n'):
            if (re.search(device, entry)):
                tokens = entry.split()
                # Return the 3rd column of this line
                return tokens[2] if len(tokens) > 2 else None
    return None


def FindInLinuxKernelCmdline(option):
    """
    Return match object if 'option' is present in the kernel boot options
    of the grub configuration.
    """
    m = None
    matchs = r'^.*?' + MyDistro.grubKernelBootOptionsLine + r'.*?' + option + r'.*$'
    try:
        m = FindStringInFile(MyDistro.grubKernelBootOptionsFile, matchs)
    except IOError as e:
        Error(
            'FindInLinuxKernelCmdline: Exception opening ' + MyDistro.grubKernelBootOptionsFile + 'Exception:' + str(
                e))
    return m


def AppendToLinuxKernelCmdline(option):
    """
    Add 'option' to the kernel boot options of the grub configuration.
    """
    if not FindInLinuxKernelCmdline(option):
        src = r'^(.*?' \
+ MyDistro.grubKernelBootOptionsLine + r')(.*?)("?)$' rep = r'\1\2 ' + option + r'\3' try: ReplaceStringInFile(MyDistro.grubKernelBootOptionsFile, src, rep) except IOError as e: Error( 'AppendToLinuxKernelCmdline: Exception opening ' + MyDistro.grubKernelBootOptionsFile + 'Exception:' + str( e)) return 1 Run("update-grub", chk_err=False) return 0 def RemoveFromLinuxKernelCmdline(option): """ Remove 'option' to the kernel boot options of the grub configuration. """ if FindInLinuxKernelCmdline(option): src = r'^(.*?' + MyDistro.grubKernelBootOptionsLine + r'.*?)(' + option + r')(.*?)("?)$' rep = r'\1\3\4' try: ReplaceStringInFile(MyDistro.grubKernelBootOptionsFile, src, rep) except IOError as e: Error( 'RemoveFromLinuxKernelCmdline: Exception opening ' + MyDistro.grubKernelBootOptionsFile + 'Exception:' + str( e)) return 1 Run("update-grub", chk_err=False) return 0 def FindStringInFile(fname, matchs): """ Return match object if found in file. """ try: ms = re.compile(matchs) for l in (open(fname, 'r')).readlines(): m = re.search(ms, l) if m: return m except: raise return None def ReplaceStringInFile(fname, src, repl): """ Replace 'src' with 'repl' in file. """ try: sr = re.compile(src) if FindStringInFile(fname, src): updated = '' for l in (open(fname, 'r')).readlines(): n = re.sub(sr, repl, l) updated += n ReplaceFileContentsAtomic(fname, updated) except: raise return def ApplyVNUMAWorkaround(): """ If kernel version has NUMA bug, add 'numa=off' to kernel boot options. """ VersionParts = platform.release().replace('-', '.').split('.') if int(VersionParts[0]) > 2: return if int(VersionParts[1]) > 6: return if int(VersionParts[2]) > 37: return if AppendToLinuxKernelCmdline("numa=off") == 0: Log("Your kernel version " + platform.release() + " has a NUMA-related bug: NUMA has been disabled.") else: "Error adding 'numa=off'. NUMA has not been disabled." def RevertVNUMAWorkaround(): """ Remove 'numa=off' from kernel boot options. 
""" if RemoveFromLinuxKernelCmdline("numa=off") == 0: Log('NUMA has been re-enabled') else: Log('NUMA has not been re-enabled') def Install(): """ Install the agent service. Check dependencies. Create /etc/waagent.conf and move old version to /etc/waagent.conf.old Copy RulesFiles to /var/lib/waagent Create /etc/logrotate.d/waagent Set /etc/ssh/sshd_config ClientAliveInterval to 180 Call ApplyVNUMAWorkaround() """ if MyDistro.checkDependencies(): return 1 os.chmod(sys.argv[0], 0o755) SwitchCwd() for a in RulesFiles: if os.path.isfile(a): if os.path.isfile(GetLastPathElement(a)): os.remove(GetLastPathElement(a)) shutil.move(a, ".") Warn("Moved " + a + " -> " + LibDir + "/" + GetLastPathElement(a)) MyDistro.registerAgentService() if os.path.isfile("/etc/waagent.conf"): try: os.remove("/etc/waagent.conf.old") except: pass try: os.rename("/etc/waagent.conf", "/etc/waagent.conf.old") Warn("Existing /etc/waagent.conf has been renamed to /etc/waagent.conf.old") except: pass SetFileContents("/etc/waagent.conf", MyDistro.waagent_conf_file) SetFileContents("/etc/logrotate.d/waagent", WaagentLogrotate) filepath = "/etc/ssh/sshd_config" ReplaceFileContentsAtomic(filepath, "\n".join(filter(lambda a: not a.startswith("ClientAliveInterval"), GetFileContents(filepath).split( '\n'))) + "\nClientAliveInterval 180\n") Log("Configured SSH client probing to keep connections alive.") ApplyVNUMAWorkaround() return 0 def GetMyDistro(dist_class_name=''): """ Return MyDistro object. NOTE: Logging is not initialized at this point. """ if dist_class_name == '': if 'Linux' in platform.system(): Distro = DistInfo()[0] else: # I know this is not Linux! if 'FreeBSD' in platform.system(): Distro = platform.system() Distro = Distro.strip('"') Distro = Distro.strip(' ') dist_class_name = Distro + 'Distro' else: Distro = dist_class_name if dist_class_name not in globals(): msg = Distro + ' is not a supported distribution. 
Reverting to DefaultDistro to support scenarios in ' \ 'unknown/unsupported distribution.' print(msg) Log(msg) return DefaultDistro() # the distro class inside this module. Check the implementations of AbstractDistro return globals()[dist_class_name]() def DistInfo(fullname=0): if 'FreeBSD' in platform.system(): release = re.sub(r'\-.*\Z', '', str(platform.release())) distinfo = ['FreeBSD', release] return distinfo if 'linux_distribution' in dir(platform): distinfo = list(platform.linux_distribution(full_distribution_name=fullname)) distinfo[0] = distinfo[0].strip() # remove trailing whitespace in distro name if not distinfo[0]: distinfo = dist_info_SLES15() if not distinfo[0]: distinfo = dist_info_opensuse15() return distinfo else: return platform.dist() def dist_info_SLES15(): os_release_filepath = "/etc/os-release" if not os.path.isfile(os_release_filepath): return ["","",""] info = open(os_release_filepath).readlines() found_name_sles = False found_id_sles = False version_id = "" for line in info: if "NAME=\"SLES\"" in line: found_name_sles = True if "ID=\"sles\"" in line: found_id_sles = True if "VERSION_ID" in line: match = re.match(r'VERSION_ID="([.0-9]+)"', line) if match: version_id = match.group(1) if found_name_sles and found_id_sles and version_id: return "SuSE", version_id, "suse" return ["","",""] def dist_info_opensuse15(): os_release_filepath = "/etc/os-release" if not os.path.isfile(os_release_filepath): return ["","",""] info = open(os_release_filepath).readlines() found_name_opensuse_leap = False found_id_opensuse_leap = False version_id = "" for line in info: if "NAME=\"openSUSE" in line and "Leap" in line: found_name_opensuse_leap = True if "ID=\"opensuse-leap\"" in line: found_id_opensuse_leap = True if "VERSION_ID" in line: match = re.match(r'VERSION_ID="([.0-9]+)"', line) if match: version_id = match.group(1) if found_name_opensuse_leap and found_id_opensuse_leap and version_id: return "SuSE", version_id, "suse" return ["","",""] def 
PackagedInstall(buildroot): """ Called from setup.py for use by RPM. Generic implementation Creates directories and files /etc/waagent.conf, /etc/init.d/waagent, /usr/sbin/waagent, /etc/logrotate.d/waagent, /etc/sudoers.d/waagent under buildroot. Copies generated files waagent.conf, into place and exits. """ MyDistro = GetMyDistro() if MyDistro == None: sys.exit(1) MyDistro.packagedInstall(buildroot) def LibraryInstall(buildroot): pass def Uninstall(): """ Uninstall the agent service. Copy RulesFiles back to original locations. Delete agent-related files. Call RevertVNUMAWorkaround(). """ SwitchCwd() for a in RulesFiles: if os.path.isfile(GetLastPathElement(a)): try: shutil.move(GetLastPathElement(a), a) Warn("Moved " + LibDir + "/" + GetLastPathElement(a) + " -> " + a) except: pass MyDistro.unregisterAgentService() MyDistro.uninstallDeleteFiles() RevertVNUMAWorkaround() return 0 def Deprovision(force, deluser): """ Remove user accounts created by provisioning. Disables root password if Provisioning.DeleteRootPassword = 'y' Stop agent service. Remove SSH host keys if they were generated by the provision. Set hostname to 'localhost.localdomain'. Delete cached system configuration files in /var/lib and /var/lib/waagent. """ # Append blank line at the end of file, so the ctime of this file is changed every time Run("echo ''>>" + MyDistro.getConfigurationPath()) SwitchCwd() ovfxml = GetFileContents(LibDir + "/ovf-env.xml") ovfobj = None if ovfxml != None: ovfobj = OvfEnv().Parse(ovfxml, True) print("WARNING! The waagent service will be stopped.") print("WARNING! All SSH host key pairs will be deleted.") print("WARNING! Cached DHCP leases will be deleted.") MyDistro.deprovisionWarnUser() delRootPass = Config.get("Provisioning.DeleteRootPassword") if delRootPass != None and delRootPass.lower().startswith("y"): print("WARNING! root password will be disabled. You will not be able to login as root.") if ovfobj != None and deluser == True: print("WARNING! 
" + ovfobj.UserName + " account and entire home directory will be deleted.") if force == False and not raw_input('Do you want to proceed (y/n)? ').startswith('y'): return 1 MyDistro.stopAgentService() # Remove SSH host keys regenerateKeys = Config.get("Provisioning.RegenerateSshHostKeyPair") if regenerateKeys == None or regenerateKeys.lower().startswith("y"): Run("rm -f /etc/ssh/ssh_host_*key*") # Remove root password if delRootPass != None and delRootPass.lower().startswith("y"): MyDistro.deleteRootPassword() # Remove distribution specific networking configuration MyDistro.publishHostname('localhost.localdomain') MyDistro.deprovisionDeleteFiles() if deluser == True: MyDistro.DeleteAccount(ovfobj.UserName) return 0 def SwitchCwd(): """ Switch to cwd to /var/lib/waagent. Create if not present. """ CreateDir(LibDir, "root", 0o700) os.chdir(LibDir) def Usage(): """ Print the arguments to waagent. """ print("usage: " + sys.argv[ 0] + " [-verbose] [-force] [-help|-install|-uninstall|-deprovision[+user]|-version|-serialconsole|-daemon]") return 0 def main(): """ Instantiate MyDistro, exit if distro class is not defined. Parse command-line arguments, exit with usage() on error. Instantiate ConfigurationProvider. Call appropriate non-daemon methods and exit. If daemon mode, enter Agent.Run() loop. """ if GuestAgentVersion == "": print("WARNING! 
This is a non-standard agent that does not include a valid version string.") if len(sys.argv) == 1: sys.exit(Usage()) LoggerInit('/var/log/waagent.log', '/dev/console') global LinuxDistro LinuxDistro = DistInfo()[0] global MyDistro MyDistro = GetMyDistro() if MyDistro == None: sys.exit(1) args = [] conf_file = None global force force = False for a in sys.argv[1:]: if re.match(r"^([-/]*)(help|usage|\?)", a): sys.exit(Usage()) elif re.match("^([-/]*)version", a): print(GuestAgentVersion + " running on " + LinuxDistro) sys.exit(0) elif re.match("^([-/]*)verbose", a): myLogger.verbose = True elif re.match("^([-/]*)force", a): force = True elif re.match("^(?:[-/]*)conf=.+", a): conf_file = re.match("^(?:[-/]*)conf=(.+)", a).groups()[0] elif re.match("^([-/]*)(setup|install)", a): sys.exit(MyDistro.Install()) elif re.match("^([-/]*)(uninstall)", a): sys.exit(Uninstall()) else: args.append(a) global Config Config = ConfigurationProvider(conf_file) logfile = Config.get("Logs.File") if logfile is not None: myLogger.file_path = logfile logconsole = Config.get("Logs.Console") if logconsole is not None and logconsole.lower().startswith("n"): myLogger.con_path = None verbose = Config.get("Logs.Verbose") if verbose != None and verbose.lower().startswith("y"): myLogger.verbose = True global daemon daemon = False for a in args: if re.match(r"^([-/]*)deprovision\+user", a): sys.exit(Deprovision(force, True)) elif re.match("^([-/]*)deprovision", a): sys.exit(Deprovision(force, False)) elif re.match("^([-/]*)daemon", a): daemon = True elif re.match("^([-/]*)serialconsole", a): AppendToLinuxKernelCmdline("console=ttyS0 earlyprintk=ttyS0") Log("Configured kernel to use ttyS0 as the boot console.") sys.exit(0) else: print("Invalid command line parameter:" + a) sys.exit(1) if daemon == False: sys.exit(Usage()) global modloaded modloaded = False while True: try: SwitchCwd() Log(GuestAgentLongName + " Version: " + GuestAgentVersion) if IsLinux(): Log("Linux Distribution Detected : " + 
LinuxDistro) global WaAgent WaAgent = Agent() WaAgent.Run() except Exception as e: Error(traceback.format_exc()) Error("Exception: " + str(e)) Log("Restart agent in 15 seconds") time.sleep(15) if __name__ == '__main__': main() ================================================ FILE: Common/libpsutil/py2.6-glibc-2.12-pre/psutil/__init__.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # Copyright (c) 2009, Giampaolo Rodola'. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. """psutil is a cross-platform library for retrieving information on running processes and system utilization (CPU, memory, disks, network) in Python. """ from __future__ import division __author__ = "Giampaolo Rodola'" __version__ = "2.2.1" version_info = tuple([int(num) for num in __version__.split('.')]) __all__ = [ # exceptions "Error", "NoSuchProcess", "AccessDenied", "TimeoutExpired", # constants "version_info", "__version__", "STATUS_RUNNING", "STATUS_IDLE", "STATUS_SLEEPING", "STATUS_DISK_SLEEP", "STATUS_STOPPED", "STATUS_TRACING_STOP", "STATUS_ZOMBIE", "STATUS_DEAD", "STATUS_WAKING", "STATUS_LOCKED", "STATUS_WAITING", "STATUS_LOCKED", "CONN_ESTABLISHED", "CONN_SYN_SENT", "CONN_SYN_RECV", "CONN_FIN_WAIT1", "CONN_FIN_WAIT2", "CONN_TIME_WAIT", "CONN_CLOSE", "CONN_CLOSE_WAIT", "CONN_LAST_ACK", "CONN_LISTEN", "CONN_CLOSING", "CONN_NONE", # classes "Process", "Popen", # functions "pid_exists", "pids", "process_iter", "wait_procs", # proc "virtual_memory", "swap_memory", # memory "cpu_times", "cpu_percent", "cpu_times_percent", "cpu_count", # cpu "net_io_counters", "net_connections", # network "disk_io_counters", "disk_partitions", "disk_usage", # disk "users", "boot_time", # others ] import collections import errno import functools import os import signal import subprocess import sys import time import warnings try: import pwd except ImportError: pwd = None from psutil._common 
import memoize from psutil._compat import callable, long from psutil._compat import PY3 as _PY3 from psutil._common import (deprecated_method as _deprecated_method, deprecated as _deprecated, sdiskio as _nt_sys_diskio, snetio as _nt_sys_netio) from psutil._common import (STATUS_RUNNING, # NOQA STATUS_SLEEPING, STATUS_DISK_SLEEP, STATUS_STOPPED, STATUS_TRACING_STOP, STATUS_ZOMBIE, STATUS_DEAD, STATUS_WAKING, STATUS_LOCKED, STATUS_IDLE, # bsd STATUS_WAITING, # bsd STATUS_LOCKED) # bsd from psutil._common import (CONN_ESTABLISHED, CONN_SYN_SENT, CONN_SYN_RECV, CONN_FIN_WAIT1, CONN_FIN_WAIT2, CONN_TIME_WAIT, CONN_CLOSE, CONN_CLOSE_WAIT, CONN_LAST_ACK, CONN_LISTEN, CONN_CLOSING, CONN_NONE) if sys.platform.startswith("linux"): import psutil._pslinux as _psplatform from psutil._pslinux import (phymem_buffers, # NOQA cached_phymem) from psutil._pslinux import (IOPRIO_CLASS_NONE, # NOQA IOPRIO_CLASS_RT, IOPRIO_CLASS_BE, IOPRIO_CLASS_IDLE) # Linux >= 2.6.36 if _psplatform.HAS_PRLIMIT: from _psutil_linux import (RLIM_INFINITY, # NOQA RLIMIT_AS, RLIMIT_CORE, RLIMIT_CPU, RLIMIT_DATA, RLIMIT_FSIZE, RLIMIT_LOCKS, RLIMIT_MEMLOCK, RLIMIT_NOFILE, RLIMIT_NPROC, RLIMIT_RSS, RLIMIT_STACK) # Kinda ugly but considerably faster than using hasattr() and # setattr() against the module object (we are at import time: # speed matters). 
        import _psutil_linux
        # Each RLIMIT_* constant exists only on sufficiently new kernels /
        # C-extension builds, so probe for them one by one and silently
        # skip any that this build does not expose.
        try:
            RLIMIT_MSGQUEUE = _psutil_linux.RLIMIT_MSGQUEUE
        except AttributeError:
            pass
        try:
            RLIMIT_NICE = _psutil_linux.RLIMIT_NICE
        except AttributeError:
            pass
        try:
            RLIMIT_RTPRIO = _psutil_linux.RLIMIT_RTPRIO
        except AttributeError:
            pass
        try:
            RLIMIT_RTTIME = _psutil_linux.RLIMIT_RTTIME
        except AttributeError:
            pass
        try:
            RLIMIT_SIGPENDING = _psutil_linux.RLIMIT_SIGPENDING
        except AttributeError:
            pass
        # drop the module reference; the constants above are all we need
        del _psutil_linux
elif sys.platform.startswith("win32"):
    import psutil._pswindows as _psplatform
    from _psutil_windows import (ABOVE_NORMAL_PRIORITY_CLASS,  # NOQA
                                 BELOW_NORMAL_PRIORITY_CLASS,
                                 HIGH_PRIORITY_CLASS,
                                 IDLE_PRIORITY_CLASS,
                                 NORMAL_PRIORITY_CLASS,
                                 REALTIME_PRIORITY_CLASS)
    from psutil._pswindows import CONN_DELETE_TCB  # NOQA
elif sys.platform.startswith("darwin"):
    import psutil._psosx as _psplatform
elif sys.platform.startswith("freebsd"):
    import psutil._psbsd as _psplatform
elif sys.platform.startswith("sunos"):
    import psutil._pssunos as _psplatform
    from psutil._pssunos import (CONN_IDLE,  # NOQA
                                 CONN_BOUND)
else:
    raise NotImplementedError('platform %s is not supported' % sys.platform)

__all__.extend(_psplatform.__extra__all__)

# Module-level globals shared by the functions below.
_TOTAL_PHYMEM = None  # lazily-cached total physical memory, in bytes
_POSIX = os.name == 'posix'
_WINDOWS = os.name == 'nt'
# time.monotonic (3.3+) is immune to system clock adjustments; fall
# back to time.time on older interpreters.
_timer = getattr(time, 'monotonic', time.time)

# Sanity check in case the user messed up with psutil installation
# or did something weird with sys.path. In this case we might end
# up importing a python module using a C extension module which
# was compiled for a different version of psutil.
# We want to prevent that by failing sooner rather than later.
# See: https://github.com/giampaolo/psutil/issues/564 if (int(__version__.replace('.', '')) != getattr(_psplatform.cext, 'version', None)): msg = "version conflict: %r C extension module was built for another " \ "version of psutil (different than %s)" % (_psplatform.cext.__file__, __version__) raise ImportError(msg) # ===================================================================== # --- exceptions # ===================================================================== class Error(Exception): """Base exception class. All other psutil exceptions inherit from this one. """ class NoSuchProcess(Error): """Exception raised when a process with a certain PID doesn't or no longer exists (zombie). """ def __init__(self, pid, name=None, msg=None): Error.__init__(self) self.pid = pid self.name = name self.msg = msg if msg is None: if name: details = "(pid=%s, name=%s)" % (self.pid, repr(self.name)) else: details = "(pid=%s)" % self.pid self.msg = "process no longer exists " + details def __str__(self): return self.msg class AccessDenied(Error): """Exception raised when permission to perform an action is denied.""" def __init__(self, pid=None, name=None, msg=None): Error.__init__(self) self.pid = pid self.name = name self.msg = msg if msg is None: if (pid is not None) and (name is not None): self.msg = "(pid=%s, name=%s)" % (pid, repr(name)) elif (pid is not None): self.msg = "(pid=%s)" % self.pid else: self.msg = "" def __str__(self): return self.msg class TimeoutExpired(Error): """Raised on Process.wait(timeout) if timeout expires and process is still alive. 
""" def __init__(self, seconds, pid=None, name=None): Error.__init__(self) self.seconds = seconds self.pid = pid self.name = name self.msg = "timeout after %s seconds" % seconds if (pid is not None) and (name is not None): self.msg += " (pid=%s, name=%s)" % (pid, repr(name)) elif (pid is not None): self.msg += " (pid=%s)" % self.pid def __str__(self): return self.msg # push exception classes into platform specific module namespace _psplatform.NoSuchProcess = NoSuchProcess _psplatform.AccessDenied = AccessDenied _psplatform.TimeoutExpired = TimeoutExpired # ===================================================================== # --- Process class # ===================================================================== def _assert_pid_not_reused(fun): """Decorator which raises NoSuchProcess in case a process is no longer running or its PID has been reused. """ @functools.wraps(fun) def wrapper(self, *args, **kwargs): if not self.is_running(): raise NoSuchProcess(self.pid, self._name) return fun(self, *args, **kwargs) return wrapper class Process(object): """Represents an OS process with the given PID. If PID is omitted current process PID (os.getpid()) is used. Raise NoSuchProcess if PID does not exist. Note that most of the methods of this class do not make sure the PID of the process being queried has been reused over time. That means you might end up retrieving an information referring to another process in case the original one this instance refers to is gone in the meantime. 
The only exceptions for which process identity is pre-emptively checked and guaranteed are: - parent() - children() - nice() (set) - ionice() (set) - rlimit() (set) - cpu_affinity (set) - suspend() - resume() - send_signal() - terminate() - kill() To prevent this problem for all other methods you can: - use is_running() before querying the process - if you're continuously iterating over a set of Process instances use process_iter() which pre-emptively checks process identity for every yielded instance """ def __init__(self, pid=None): self._init(pid) def _init(self, pid, _ignore_nsp=False): if pid is None: pid = os.getpid() else: if not _PY3 and not isinstance(pid, (int, long)): raise TypeError('pid must be an integer (got %r)' % pid) if pid < 0: raise ValueError('pid must be a positive integer (got %s)' % pid) self._pid = pid self._name = None self._exe = None self._create_time = None self._gone = False self._hash = None # used for caching on Windows only (on POSIX ppid may change) self._ppid = None # platform-specific modules define an _psplatform.Process # implementation class self._proc = _psplatform.Process(pid) self._last_sys_cpu_times = None self._last_proc_cpu_times = None # cache creation time for later use in is_running() method try: self.create_time() except AccessDenied: # we should never get here as AFAIK we're able to get # process creation time on all platforms even as a # limited user pass except NoSuchProcess: if not _ignore_nsp: msg = 'no process found with pid %s' % pid raise NoSuchProcess(pid, None, msg) else: self._gone = True # This pair is supposed to indentify a Process instance # univocally over time (the PID alone is not enough as # it might refer to a process whose PID has been reused). # This will be used later in __eq__() and is_running(). 
        self._ident = (self.pid, self._create_time)

    def __str__(self):
        try:
            pid = self.pid
            name = repr(self.name())
        except NoSuchProcess:
            details = "(pid=%s (terminated))" % self.pid
        except AccessDenied:
            details = "(pid=%s)" % (self.pid)
        else:
            details = "(pid=%s, name=%s)" % (pid, name)
        return "%s.%s%s" % (self.__class__.__module__,
                            self.__class__.__name__, details)

    def __repr__(self):
        return "<%s at %s>" % (self.__str__(), id(self))

    def __eq__(self, other):
        # Test for equality with another Process object based
        # on PID and creation time.
        if not isinstance(other, Process):
            return NotImplemented
        return self._ident == other._ident

    def __ne__(self, other):
        # delegate to __eq__ so the two operators stay consistent
        return not self == other

    def __hash__(self):
        if self._hash is None:
            # cache the hash; _ident never changes for a given instance
            self._hash = hash(self._ident)
        return self._hash

    # --- utility methods

    def as_dict(self, attrs=None, ad_value=None):
        """Utility method returning process information as a
        hashable dictionary.

        If 'attrs' is specified it must be a list of strings
        reflecting available Process class' attribute names
        (e.g. ['cpu_times', 'name']) else all public (read
        only) attributes are assumed.

        'ad_value' is the value which gets assigned in case
        AccessDenied exception is raised when retrieving that
        particular process information.
        """
        excluded_names = set(
            ['send_signal', 'suspend', 'resume', 'terminate', 'kill', 'wait',
             'is_running', 'as_dict', 'parent', 'children', 'rlimit'])
        retdict = dict()
        ls = set(attrs or [x for x in dir(self) if not x.startswith('get')])
        for name in ls:
            if name.startswith('_'):
                continue
            if name.startswith('set_'):
                continue
            if name.startswith('get_'):
                # map deprecated get_* accessors onto the new names,
                # warning the caller about the rename
                msg = "%s() is deprecated; use %s() instead" % (name, name[4:])
                warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
                name = name[4:]
                if name in ls:
                    continue
            if name == 'getcwd':
                msg = "getcwd() is deprecated; use cwd() instead"
                warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
                name = 'cwd'
                if name in ls:
                    continue
            if name in excluded_names:
                continue
            try:
                attr = getattr(self, name)
                if callable(attr):
                    ret = attr()
                else:
                    ret = attr
            except AccessDenied:
                ret = ad_value
            except NotImplementedError:
                # in case of not implemented functionality (may happen
                # on old or exotic systems) we want to crash only if
                # the user explicitly asked for that particular attr
                if attrs:
                    raise
                continue
            retdict[name] = ret
        return retdict

    def parent(self):
        """Return the parent process as a Process object pre-emptively
        checking whether PID has been reused. If no parent is known
        return None.
        """
        ppid = self.ppid()
        if ppid is not None:
            try:
                parent = Process(ppid)
                if parent.create_time() <= self.create_time():
                    return parent
                # ...else ppid has been reused by another process
            except NoSuchProcess:
                pass

    def is_running(self):
        """Return whether this process is running.
        It also checks if PID has been reused by another process in
        which case return False.
        """
        if self._gone:
            return False
        try:
            # Checking if PID is alive is not enough as the PID might
            # have been reused by another process: we also want to
            # check process identity.
            # Process identity / uniqueness over time is guaranteed by
            # (PID + creation time) and that is verified in __eq__.
            return self == Process(self.pid)
        except NoSuchProcess:
            # PID disappeared between checks; remember it so later
            # calls can return immediately
            self._gone = True
            return False

    # --- actual API

    @property
    def pid(self):
        """The process PID."""
        return self._pid

    def ppid(self):
        """The process parent PID.
        On Windows the return value is cached after first call.
        """
        # On POSIX we don't want to cache the ppid as it may unexpectedly
        # change to 1 (init) in case this process turns into a zombie:
        # https://github.com/giampaolo/psutil/issues/321
        # http://stackoverflow.com/questions/356722/
        # XXX should we check creation time here rather than in
        # Process.parent()?
        if _POSIX:
            return self._proc.ppid()
        else:
            if self._ppid is None:
                self._ppid = self._proc.ppid()
            return self._ppid

    def name(self):
        """The process name. The return value is cached after first call."""
        if self._name is None:
            name = self._proc.name()
            if _POSIX and len(name) >= 15:
                # On UNIX the name gets truncated to the first 15 characters.
                # If it matches the first part of the cmdline we return that
                # one instead because it's usually more explicative.
                # Examples are "gnome-keyring-d" vs. "gnome-keyring-daemon".
                try:
                    cmdline = self.cmdline()
                except AccessDenied:
                    pass
                else:
                    if cmdline:
                        extended_name = os.path.basename(cmdline[0])
                        if extended_name.startswith(name):
                            name = extended_name
            self._proc._name = name
            self._name = name
        return self._name

    def exe(self):
        """The process executable as an absolute path.
        May also be an empty string.
        The return value is cached after first call.
        """
        def guess_it(fallback):
            # try to guess exe from cmdline[0] in absence of a native
            # exe representation
            cmdline = self.cmdline()
            if cmdline and hasattr(os, 'access') and hasattr(os, 'X_OK'):
                exe = cmdline[0]  # the possible exe
                # Attempt to guess only in case of an absolute path.
                # It is not safe otherwise as the process might have
                # changed cwd.
                if (os.path.isabs(exe) and os.path.isfile(exe) and
                        os.access(exe, os.X_OK)):
                    return exe
            if isinstance(fallback, AccessDenied):
                raise fallback
            return fallback

        if self._exe is None:
            try:
                exe = self._proc.exe()
            except AccessDenied as err:
                return guess_it(fallback=err)
            else:
                if not exe:
                    # underlying implementation can legitimately return an
                    # empty string; if that's the case we don't want to
                    # raise AD while guessing from the cmdline
                    try:
                        exe = guess_it(fallback=exe)
                    except AccessDenied:
                        pass
                self._exe = exe
        return self._exe

    def cmdline(self):
        """The command line this process has been called with."""
        return self._proc.cmdline()

    def status(self):
        """The process current status as a STATUS_* constant."""
        return self._proc.status()

    def username(self):
        """The name of the user that owns the process.
        On UNIX this is calculated by using *real* process uid.
        """
        if _POSIX:
            if pwd is None:
                # might happen if python was installed from sources
                raise ImportError(
                    "requires pwd module shipped with standard python")
            real_uid = self.uids().real
            try:
                return pwd.getpwuid(real_uid).pw_name
            except KeyError:
                # the uid can't be resolved by the system
                return str(real_uid)
        else:
            return self._proc.username()

    def create_time(self):
        """The process creation time as a floating point number
        expressed in seconds since the epoch, in UTC.
        The return value is cached after first call.
        """
        if self._create_time is None:
            self._create_time = self._proc.create_time()
        return self._create_time

    def cwd(self):
        """Process current working directory as an absolute path."""
        return self._proc.cwd()

    def nice(self, value=None):
        """Get or set process niceness (priority)."""
        if value is None:
            return self._proc.nice_get()
        else:
            # setter pre-emptively checks PID identity (see class docstring)
            if not self.is_running():
                raise NoSuchProcess(self.pid, self._name)
            self._proc.nice_set(value)

    if _POSIX:
        def uids(self):
            """Return process UIDs as a (real, effective, saved)
            namedtuple.
            """
            return self._proc.uids()

        def gids(self):
            """Return process GIDs as a (real, effective, saved)
            namedtuple.
            """
            return self._proc.gids()

        def terminal(self):
            """The terminal associated with this process, if any,
            else None.
            """
            return self._proc.terminal()

        def num_fds(self):
            """Return the number of file descriptors opened by this
            process (POSIX only).
            """
            return self._proc.num_fds()

    # Linux, BSD and Windows only
    if hasattr(_psplatform.Process, "io_counters"):
        def io_counters(self):
            """Return process I/O statistics as a
            (read_count, write_count, read_bytes, write_bytes)
            namedtuple.
            Those are the number of read/write calls performed and the
            amount of bytes read and written by the process.
            """
            return self._proc.io_counters()

    # Linux and Windows >= Vista only
    if hasattr(_psplatform.Process, "ionice_get"):
        def ionice(self, ioclass=None, value=None):
            """Get or set process I/O niceness (priority).
            On Linux 'ioclass' is one of the IOPRIO_CLASS_* constants.
            'value' is a number which goes from 0 to 7. The higher the
            value, the lower the I/O priority of the process.
            On Windows only 'ioclass' is used and it can be set to 2
            (normal), 1 (low) or 0 (very low).
            Available on Linux and Windows > Vista only.
            """
            if ioclass is None:
                if value is not None:
                    raise ValueError("'ioclass' must be specified")
                return self._proc.ionice_get()
            else:
                return self._proc.ionice_set(ioclass, value)

    # Linux only
    if hasattr(_psplatform.Process, "rlimit"):
        def rlimit(self, resource, limits=None):
            """Get or set process resource limits as a (soft, hard)
            tuple.
            'resource' is one of the RLIMIT_* constants.
            'limits' is supposed to be a (soft, hard) tuple.
            See "man prlimit" for further info.
            Available on Linux only.
            """
            if limits is None:
                return self._proc.rlimit(resource)
            else:
                return self._proc.rlimit(resource, limits)

    # Windows, Linux and BSD only
    if hasattr(_psplatform.Process, "cpu_affinity_get"):
        def cpu_affinity(self, cpus=None):
            """Get or set process CPU affinity.
            If specified 'cpus' must be a list of CPUs for which you
            want to set the affinity (e.g. [0, 1]).
            (Windows, Linux and BSD only).
            """
            if cpus is None:
                return self._proc.cpu_affinity_get()
            else:
                # setter path deliberately returns None
                self._proc.cpu_affinity_set(cpus)

    if _WINDOWS:
        def num_handles(self):
            """Return the number of handles opened by this process
            (Windows only).
            """
            return self._proc.num_handles()

    def num_ctx_switches(self):
        """Return the number of voluntary and involuntary context
        switches performed by this process.
        """
        return self._proc.num_ctx_switches()

    def num_threads(self):
        """Return the number of threads used by this process."""
        return self._proc.num_threads()

    def threads(self):
        """Return threads opened by process as a list of
        (id, user_time, system_time) namedtuples representing
        thread id and thread CPU times (user/system).
        """
        return self._proc.threads()

    @_assert_pid_not_reused
    def children(self, recursive=False):
        """Return the children of this process as a list of Process
        instances, pre-emptively checking whether PID has been reused.
        If recursive is True return all the parent descendants.

        Example (A == this process):

         A ─┐
            │
            ├─ B (child) ─┐
            │             └─ X (grandchild) ─┐
            │                                └─ Y (great grandchild)
            ├─ C (child)
            └─ D (child)

        >>> import psutil
        >>> p = psutil.Process()
        >>> p.children()
        B, C, D
        >>> p.children(recursive=True)
        B, X, Y, C, D

        Note that in the example above if process X disappears
        process Y won't be listed as the reference to process A
        is lost.
        """
        if hasattr(_psplatform, 'ppid_map'):
            # Windows only: obtain a {pid:ppid, ...} dict for all running
            # processes in one shot (faster).
ppid_map = _psplatform.ppid_map() else: ppid_map = None ret = [] if not recursive: if ppid_map is None: # 'slow' version, common to all platforms except Windows for p in process_iter(): try: if p.ppid() == self.pid: # if child happens to be older than its parent # (self) it means child's PID has been reused if self.create_time() <= p.create_time(): ret.append(p) except NoSuchProcess: pass else: # Windows only (faster) for pid, ppid in ppid_map.items(): if ppid == self.pid: try: child = Process(pid) # if child happens to be older than its parent # (self) it means child's PID has been reused if self.create_time() <= child.create_time(): ret.append(child) except NoSuchProcess: pass else: # construct a dict where 'values' are all the processes # having 'key' as their parent table = collections.defaultdict(list) if ppid_map is None: for p in process_iter(): try: table[p.ppid()].append(p) except NoSuchProcess: pass else: for pid, ppid in ppid_map.items(): try: p = Process(pid) table[ppid].append(p) except NoSuchProcess: pass # At this point we have a mapping table where table[self.pid] # are the current process' children. # Below, we look for all descendants recursively, similarly # to a recursive function call. checkpids = [self.pid] for pid in checkpids: for child in table[pid]: try: # if child happens to be older than its parent # (self) it means child's PID has been reused intime = self.create_time() <= child.create_time() except NoSuchProcess: pass else: if intime: ret.append(child) if child.pid not in checkpids: checkpids.append(child.pid) return ret def cpu_percent(self, interval=None): """Return a float representing the current process CPU utilization as a percentage. When interval is 0.0 or None (default) compares process times to system CPU times elapsed since last call, returning immediately (non-blocking). That means that the first time this is called it will return a meaningful 0.0 value. 
        When interval is > 0.0 compares process times to system CPU
        times elapsed before and after the interval (blocking). In this
        case is recommended for accuracy that this function be called
        with at least 0.1 seconds between calls.

        Examples:

          >>> import psutil
          >>> p = psutil.Process(os.getpid())
          >>> # blocking
          >>> p.cpu_percent(interval=1)
          2.0
          >>> # non-blocking (percentage since last call)
          >>> p.cpu_percent(interval=None)
          2.9
          >>>
        """
        blocking = interval is not None and interval > 0.0
        num_cpus = cpu_count()
        if _POSIX:
            # on POSIX scale the monotonic timer by the CPU count so the
            # system-time delta is comparable to the summed process times
            timer = lambda: _timer() * num_cpus
        else:
            timer = lambda: sum(cpu_times())
        if blocking:
            st1 = timer()
            pt1 = self._proc.cpu_times()
            time.sleep(interval)
            st2 = timer()
            pt2 = self._proc.cpu_times()
        else:
            st1 = self._last_sys_cpu_times
            pt1 = self._last_proc_cpu_times
            st2 = timer()
            pt2 = self._proc.cpu_times()
            if st1 is None or pt1 is None:
                # first call: no baseline yet, store one and return 0.0
                self._last_sys_cpu_times = st2
                self._last_proc_cpu_times = pt2
                return 0.0

        delta_proc = (pt2.user - pt1.user) + (pt2.system - pt1.system)
        delta_time = st2 - st1
        # reset values for next call in case of interval == None
        self._last_sys_cpu_times = st2
        self._last_proc_cpu_times = pt2

        try:
            # The utilization split between all CPUs.
            # Note: a percentage > 100 is legitimate as it can result
            # from a process with multiple threads running on different
            # CPU cores, see:
            # http://stackoverflow.com/questions/1032357
            # https://github.com/giampaolo/psutil/issues/474
            overall_percent = ((delta_proc / delta_time) * 100) * num_cpus
        except ZeroDivisionError:
            # interval was too low
            return 0.0
        else:
            return round(overall_percent, 1)

    def cpu_times(self):
        """Return a (user, system) namedtuple representing the
        accumulated process time, in seconds.
        This is the same as os.times() but per-process.
        """
        return self._proc.cpu_times()

    def memory_info(self):
        """Return a tuple representing RSS (Resident Set Size) and VMS
        (Virtual Memory Size) in bytes.

        On UNIX RSS and VMS are the same values shown by 'ps'.

        On Windows RSS and VMS refer to "Mem Usage" and "VM Size"
        columns of taskmgr.exe.
        """
        return self._proc.memory_info()

    def memory_info_ex(self):
        """Return a namedtuple with variable fields depending on the
        platform representing extended memory information about
        this process. All numbers are expressed in bytes.
        """
        return self._proc.memory_info_ex()

    def memory_percent(self):
        """Compare physical system memory to process resident memory
        (RSS) and calculate process memory utilization as a percentage.
        """
        rss = self._proc.memory_info()[0]
        # use cached value if available
        total_phymem = _TOTAL_PHYMEM or virtual_memory().total
        try:
            return (rss / float(total_phymem)) * 100
        except ZeroDivisionError:
            return 0.0

    def memory_maps(self, grouped=True):
        """Return process' mapped memory regions as a list of nameduples
        whose fields are variable depending on the platform.

        If 'grouped' is True the mapped regions with the same 'path'
        are grouped together and the different memory fields are summed.

        If 'grouped' is False every mapped region is shown as a single
        entity and the namedtuple will also include the mapped region's
        address space ('addr') and permission set ('perms').
        """
        it = self._proc.memory_maps()
        if grouped:
            d = {}
            for tupl in it:
                path = tupl[2]
                nums = tupl[3:]
                try:
                    # sum the numeric fields of regions sharing a path
                    d[path] = map(lambda x, y: x + y, d[path], nums)
                except KeyError:
                    d[path] = nums
            nt = _psplatform.pmmap_grouped
            return [nt(path, *d[path]) for path in d]  # NOQA
        else:
            nt = _psplatform.pmmap_ext
            return [nt(*x) for x in it]

    def open_files(self):
        """Return files opened by process as a list of
        (path, fd) namedtuples including the absolute file name
        and file descriptor number.
        """
        return self._proc.open_files()

    def connections(self, kind='inet'):
        """Return connections opened by process as a list of
        (fd, family, type, laddr, raddr, status) namedtuples.
        The 'kind' parameter filters for connections that match the
        following criteria:

        Kind Value      Connections using
        inet            IPv4 and IPv6
        inet4           IPv4
        inet6           IPv6
        tcp             TCP
        tcp4            TCP over IPv4
        tcp6            TCP over IPv6
        udp             UDP
        udp4            UDP over IPv4
        udp6            UDP over IPv6
        unix            UNIX socket (both UDP and TCP protocols)
        all             the sum of all the possible families and
                        protocols
        """
        return self._proc.connections(kind)

    if _POSIX:
        def _send_signal(self, sig):
            # XXX: according to "man 2 kill" PID 0 has a special
            # meaning as it refers to <<every process in the process
            # group of the calling process>>, so should we prevent
            # it here?
            try:
                os.kill(self.pid, sig)
            except OSError as err:
                if err.errno == errno.ESRCH:
                    # process is gone; remember that for is_running()
                    self._gone = True
                    raise NoSuchProcess(self.pid, self._name)
                if err.errno == errno.EPERM:
                    raise AccessDenied(self.pid, self._name)
                raise

    @_assert_pid_not_reused
    def send_signal(self, sig):
        """Send a signal to process pre-emptively checking whether
        PID has been reused (see signal module constants) .
        On Windows only SIGTERM is valid and is treated as an alias
        for kill().
        """
        if _POSIX:
            self._send_signal(sig)
        else:
            if sig == signal.SIGTERM:
                self._proc.kill()
            else:
                raise ValueError("only SIGTERM is supported on Windows")

    @_assert_pid_not_reused
    def suspend(self):
        """Suspend process execution with SIGSTOP pre-emptively checking
        whether PID has been reused.
        On Windows this has the effect of suspending all process threads.
        """
        if _POSIX:
            self._send_signal(signal.SIGSTOP)
        else:
            self._proc.suspend()

    @_assert_pid_not_reused
    def resume(self):
        """Resume process execution with SIGCONT pre-emptively checking
        whether PID has been reused.
        On Windows this has the effect of resuming all process threads.
        """
        if _POSIX:
            self._send_signal(signal.SIGCONT)
        else:
            self._proc.resume()

    @_assert_pid_not_reused
    def terminate(self):
        """Terminate the process with SIGTERM pre-emptively checking
        whether PID has been reused.
        On Windows this is an alias for kill().
        """
        if _POSIX:
            self._send_signal(signal.SIGTERM)
        else:
            self._proc.kill()

    @_assert_pid_not_reused
    def kill(self):
        """Kill the current process with SIGKILL pre-emptively checking
        whether PID has been reused.
        """
        if _POSIX:
            self._send_signal(signal.SIGKILL)
        else:
            self._proc.kill()

    def wait(self, timeout=None):
        """Wait for process to terminate and, if process is a children
        of os.getpid(), also return its exit code, else None.

        If the process is already terminated immediately return None
        instead of raising NoSuchProcess.

        If timeout (in seconds) is specified and process is still alive
        raise TimeoutExpired.

        To wait for multiple Process(es) use psutil.wait_procs().
        """
        if timeout is not None and not timeout >= 0:
            raise ValueError("timeout must be a positive integer")
        return self._proc.wait(timeout)

    # --- deprecated APIs
    # Each get_*/set_* shim below warns and forwards to the new name;
    # the 'in _locals' guards only define a shim when the corresponding
    # platform-dependent method exists on this class.

    _locals = set(locals())

    @_deprecated_method(replacement='children')
    def get_children(self):
        pass

    @_deprecated_method(replacement='connections')
    def get_connections(self):
        pass

    if "cpu_affinity" in _locals:
        @_deprecated_method(replacement='cpu_affinity')
        def get_cpu_affinity(self):
            pass

        @_deprecated_method(replacement='cpu_affinity')
        def set_cpu_affinity(self, cpus):
            pass

    @_deprecated_method(replacement='cpu_percent')
    def get_cpu_percent(self):
        pass

    @_deprecated_method(replacement='cpu_times')
    def get_cpu_times(self):
        pass

    @_deprecated_method(replacement='cwd')
    def getcwd(self):
        pass

    @_deprecated_method(replacement='memory_info_ex')
    def get_ext_memory_info(self):
        pass

    if "io_counters" in _locals:
        @_deprecated_method(replacement='io_counters')
        def get_io_counters(self):
            pass

    if "ionice" in _locals:
        @_deprecated_method(replacement='ionice')
        def get_ionice(self):
            pass

        @_deprecated_method(replacement='ionice')
        def set_ionice(self, ioclass, value=None):
            pass

    @_deprecated_method(replacement='memory_info')
    def get_memory_info(self):
        pass

    @_deprecated_method(replacement='memory_maps')
    def get_memory_maps(self):
        pass

    @_deprecated_method(replacement='memory_percent')
    def get_memory_percent(self):
        pass

    @_deprecated_method(replacement='nice')
    def get_nice(self):
        pass

    @_deprecated_method(replacement='num_ctx_switches')
    def get_num_ctx_switches(self):
        pass

    if 'num_fds' in _locals:
        @_deprecated_method(replacement='num_fds')
        def get_num_fds(self):
            pass

    if 'num_handles' in _locals:
        @_deprecated_method(replacement='num_handles')
        def get_num_handles(self):
            pass

    @_deprecated_method(replacement='num_threads')
    def get_num_threads(self):
        pass

    @_deprecated_method(replacement='open_files')
    def get_open_files(self):
        pass

    if "rlimit" in _locals:
        @_deprecated_method(replacement='rlimit')
        def get_rlimit(self):
            pass

        @_deprecated_method(replacement='rlimit')
        def set_rlimit(self, resource, limits):
            pass

    @_deprecated_method(replacement='threads')
    def get_threads(self):
        pass

    @_deprecated_method(replacement='nice')
    def set_nice(self, value):
        pass

    del _locals


# =====================================================================
# --- Popen class
# =====================================================================


class Popen(Process):
    """A more convenient interface to stdlib subprocess module.
    It starts a sub process and deals with it exactly as when using
    subprocess.Popen class but in addition also provides all the
    properties and methods of psutil.Process class as a unified
    interface:

      >>> import psutil
      >>> from subprocess import PIPE
      >>> p = psutil.Popen(["python", "-c", "print 'hi'"], stdout=PIPE)
      >>> p.name()
      'python'
      >>> p.uids()
      user(real=1000, effective=1000, saved=1000)
      >>> p.username()
      'giampaolo'
      >>> p.communicate()
      ('hi\n', None)
      >>> p.terminate()
      >>> p.wait(timeout=2)
      0
      >>>

    For method names common to both classes such as kill(), terminate()
    and wait(), psutil.Process implementation takes precedence.
    Unlike subprocess.Popen this class pre-emptively checks whether PID
    has been reused on send_signal(), terminate() and kill() so that
    you don't accidentally terminate another process, fixing
    http://bugs.python.org/issue6973.

    For a complete documentation refer to:
    http://docs.python.org/library/subprocess.html
    """

    def __init__(self, *args, **kwargs):
        # Explicitly avoid to raise NoSuchProcess in case the process
        # spawned by subprocess.Popen terminates too quickly, see:
        # https://github.com/giampaolo/psutil/issues/193
        self.__subproc = subprocess.Popen(*args, **kwargs)
        self._init(self.__subproc.pid, _ignore_nsp=True)

    def __dir__(self):
        return sorted(set(dir(Popen) + dir(subprocess.Popen)))

    def __getattribute__(self, name):
        # look the attribute up on this class first, falling back to
        # the wrapped subprocess.Popen instance
        try:
            return object.__getattribute__(self, name)
        except AttributeError:
            try:
                return object.__getattribute__(self.__subproc, name)
            except AttributeError:
                raise AttributeError("%s instance has no attribute '%s'"
                                     % (self.__class__.__name__, name))

    def wait(self, timeout=None):
        # delegate to psutil's wait() and mirror the exit status onto
        # subprocess.Popen's returncode attribute
        if self.__subproc.returncode is not None:
            return self.__subproc.returncode
        ret = super(Popen, self).wait(timeout)
        self.__subproc.returncode = ret
        return ret


# =====================================================================
# --- system processes related functions
# =====================================================================


def pids():
    """Return a list of current running PIDs."""
    return _psplatform.pids()


def pid_exists(pid):
    """Return True if given PID exists in the current process list.
    This is faster than doing "pid in psutil.pids()" and
    should be preferred.
    """
    if pid < 0:
        return False
    elif pid == 0 and _POSIX:
        # On POSIX we use os.kill() to determine PID existence.
        # According to "man 2 kill" PID 0 has a special meaning
        # though: it refers to <<every process in the process group
        # of the calling process>> and that is not what we want
        # to do here.
        return pid in pids()
    else:
        return _psplatform.pid_exists(pid)


# pid -> Process cache used by process_iter()
_pmap = {}


def process_iter():
    """Return a generator yielding a Process instance for all
    running processes.

    Every new Process instance is only created once and then cached
    into an internal table which is updated every time this is used.
Cached Process instances are checked for identity so that you're safe in case a PID has been reused by another process, in which case the cached instance is updated. The sorting order in which processes are yielded is based on their PIDs. """ def add(pid): proc = Process(pid) _pmap[proc.pid] = proc return proc def remove(pid): _pmap.pop(pid, None) a = set(pids()) b = set(_pmap.keys()) new_pids = a - b gone_pids = b - a for pid in gone_pids: remove(pid) for pid, proc in sorted(list(_pmap.items()) + list(dict.fromkeys(new_pids).items())): try: if proc is None: # new process yield add(pid) else: # use is_running() to check whether PID has been reused by # another process in which case yield a new Process instance if proc.is_running(): yield proc else: yield add(pid) except NoSuchProcess: remove(pid) except AccessDenied: # Process creation time can't be determined hence there's # no way to tell whether the pid of the cached process # has been reused. Just return the cached version. yield proc def wait_procs(procs, timeout=None, callback=None): """Convenience function which waits for a list of processes to terminate. Return a (gone, alive) tuple indicating which processes are gone and which ones are still alive. The gone ones will have a new 'returncode' attribute indicating process exit status (may be None). 'callback' is a function which gets called every time a process terminates (a Process instance is passed as callback argument). Function will return as soon as all processes terminate or when timeout occurs. Typical use case is: - send SIGTERM to a list of processes - give them some time to terminate - send SIGKILL to those ones which are still alive Example: >>> def on_terminate(proc): ... print("process {} terminated".format(proc)) ... >>> for p in procs: ... p.terminate() ... >>> gone, alive = wait_procs(procs, timeout=3, callback=on_terminate) >>> for p in alive: ... 
p.kill() """ def check_gone(proc, timeout): try: returncode = proc.wait(timeout=timeout) except TimeoutExpired: pass else: if returncode is not None or not proc.is_running(): proc.returncode = returncode gone.add(proc) if callback is not None: callback(proc) if timeout is not None and not timeout >= 0: msg = "timeout must be a positive integer, got %s" % timeout raise ValueError(msg) gone = set() alive = set(procs) if callback is not None and not callable(callback): raise TypeError("callback %r is not a callable" % callable) if timeout is not None: deadline = _timer() + timeout while alive: if timeout is not None and timeout <= 0: break for proc in alive: # Make sure that every complete iteration (all processes) # will last max 1 sec. # We do this because we don't want to wait too long on a # single process: in case it terminates too late other # processes may disappear in the meantime and their PID # reused. max_timeout = 1.0 / len(alive) if timeout is not None: timeout = min((deadline - _timer()), max_timeout) if timeout <= 0: break check_gone(proc, timeout) else: check_gone(proc, max_timeout) alive = alive - gone if alive: # Last attempt over processes survived so far. # timeout == 0 won't make this function wait any further. for proc in alive: check_gone(proc, 0) alive = alive - gone return (list(gone), list(alive)) # ===================================================================== # --- CPU related functions # ===================================================================== @memoize def cpu_count(logical=True): """Return the number of logical CPUs in the system (same as os.cpu_count() in Python 3.4). If logical is False return the number of physical cores only (hyper thread CPUs are excluded). Return None if undetermined. The return value is cached after first call. 
    If desired cache can be cleared like this:

    >>> psutil.cpu_count.cache_clear()
    """
    if logical:
        return _psplatform.cpu_count_logical()
    else:
        return _psplatform.cpu_count_physical()


def cpu_times(percpu=False):
    """Return system-wide CPU times as a namedtuple.
    Every CPU time represents the seconds the CPU has spent in the
    given mode. The namedtuple's fields availability varies depending
    on the platform:
     - user
     - system
     - idle
     - nice (UNIX)
     - iowait (Linux)
     - irq (Linux, FreeBSD)
     - softirq (Linux)
     - steal (Linux >= 2.6.11)
     - guest (Linux >= 2.6.24)
     - guest_nice (Linux >= 3.2.0)

    When percpu is True return a list of nameduples for each CPU.
    First element of the list refers to first CPU, second element
    to second CPU and so on.
    The order of the list is consistent across calls.
    """
    if not percpu:
        return _psplatform.cpu_times()
    else:
        return _psplatform.per_cpu_times()


# baseline samples used by cpu_percent() when interval is None/0.0
_last_cpu_times = cpu_times()
_last_per_cpu_times = cpu_times(percpu=True)


def cpu_percent(interval=None, percpu=False):
    """Return a float representing the current system-wide CPU
    utilization as a percentage.

    When interval is > 0.0 compares system CPU times elapsed before
    and after the interval (blocking).

    When interval is 0.0 or None compares system CPU times elapsed
    since last call or module import, returning immediately (non
    blocking). That means the first time this is called it will
    return a meaningless 0.0 value which you should ignore.
    In this case is recommended for accuracy that this function be
    called with at least 0.1 seconds between calls.

    When percpu is True returns a list of floats representing the
    utilization as a percentage for each CPU.
    First element of the list refers to first CPU, second element
    to second CPU and so on.
    The order of the list is consistent across calls.

    Examples:

      >>> # blocking, system-wide
      >>> psutil.cpu_percent(interval=1)
      2.0
      >>>
      >>> # blocking, per-cpu
      >>> psutil.cpu_percent(interval=1, percpu=True)
      [2.0, 1.0]
      >>>
      >>> # non-blocking (percentage since last call)
      >>> psutil.cpu_percent(interval=None)
      2.9
      >>>
    """
    global _last_cpu_times
    global _last_per_cpu_times
    blocking = interval is not None and interval > 0.0

    def calculate(t1, t2):
        t1_all = sum(t1)
        t1_busy = t1_all - t1.idle

        t2_all = sum(t2)
        t2_busy = t2_all - t2.idle

        # this usually indicates a float precision issue
        if t2_busy <= t1_busy:
            return 0.0

        busy_delta = t2_busy - t1_busy
        all_delta = t2_all - t1_all
        busy_perc = (busy_delta / all_delta) * 100
        return round(busy_perc, 1)

    # system-wide usage
    if not percpu:
        if blocking:
            t1 = cpu_times()
            time.sleep(interval)
        else:
            t1 = _last_cpu_times
        _last_cpu_times = cpu_times()
        return calculate(t1, _last_cpu_times)
    # per-cpu usage
    else:
        ret = []
        if blocking:
            tot1 = cpu_times(percpu=True)
            time.sleep(interval)
        else:
            tot1 = _last_per_cpu_times
        _last_per_cpu_times = cpu_times(percpu=True)
        for t1, t2 in zip(tot1, _last_per_cpu_times):
            ret.append(calculate(t1, t2))
        return ret


# Use separate global vars for cpu_times_percent() so that it's
# independent from cpu_percent() and they can both be used within
# the same program.
_last_cpu_times_2 = _last_cpu_times
_last_per_cpu_times_2 = _last_per_cpu_times


def cpu_times_percent(interval=None, percpu=False):
    """Same as cpu_percent() but provides utilization percentages
    for each specific CPU time as is returned by cpu_times().
    For instance, on Linux we'll get:

      >>> cpu_times_percent()
      cpupercent(user=4.8, nice=0.0, system=4.8, idle=90.5, iowait=0.0,
                 irq=0.0, softirq=0.0, steal=0.0, guest=0.0,
                 guest_nice=0.0)
      >>>

    interval and percpu arguments have the same meaning as in
    cpu_percent().
    """
    global _last_cpu_times_2
    global _last_per_cpu_times_2
    blocking = interval is not None and interval > 0.0

    def calculate(t1, t2):
        nums = []
        all_delta = sum(t2) - sum(t1)
        for field in t1._fields:
            field_delta = getattr(t2, field) - getattr(t1, field)
            try:
                field_perc = (100 * field_delta) / all_delta
            except ZeroDivisionError:
                field_perc = 0.0
            field_perc = round(field_perc, 1)
            if _WINDOWS:
                # XXX
                # Work around:
                # https://github.com/giampaolo/psutil/issues/392
                # CPU times are always supposed to increase over time
                # or at least remain the same and that's because time
                # cannot go backwards.
                # Surprisingly sometimes this might not be the case on
                # Windows where 'system' CPU time can be smaller
                # compared to the previous call, resulting in corrupted
                # percentages (< 0 or > 100).
                # I really don't know what to do about that except
                # forcing the value to 0 or 100.
                if field_perc > 100.0:
                    field_perc = 100.0
                elif field_perc < 0.0:
                    field_perc = 0.0
            nums.append(field_perc)
        return _psplatform.scputimes(*nums)

    # system-wide usage
    if not percpu:
        if blocking:
            t1 = cpu_times()
            time.sleep(interval)
        else:
            t1 = _last_cpu_times_2
        _last_cpu_times_2 = cpu_times()
        return calculate(t1, _last_cpu_times_2)
    # per-cpu usage
    else:
        ret = []
        if blocking:
            tot1 = cpu_times(percpu=True)
            time.sleep(interval)
        else:
            tot1 = _last_per_cpu_times_2
        _last_per_cpu_times_2 = cpu_times(percpu=True)
        for t1, t2 in zip(tot1, _last_per_cpu_times_2):
            ret.append(calculate(t1, t2))
        return ret


# =====================================================================
# --- system memory related functions
# =====================================================================


def virtual_memory():
    """Return statistics about system memory usage as a namedtuple
    including the following fields, expressed in bytes:

     - total: total physical memory available.
     - available: the actual amount of available memory that can be
       given instantly to processes that request more memory in bytes;
       this is calculated by summing different memory values depending
       on the platform (e.g. free + buffers + cached on Linux) and it
       is supposed to be used to monitor actual memory usage in a
       cross platform fashion.

     - percent: the percentage usage calculated as
       (total - available) / total * 100

     - used: memory used, calculated differently depending on the
       platform and designed for informational purposes only:
        OSX: active + inactive + wired
        BSD: active + wired + cached
        LINUX: total - free

     - free: memory not being used at all (zeroed) that is readily
       available; note that this doesn't reflect the actual memory
       available (use 'available' instead)

     Platform-specific fields:

     - active (UNIX): memory currently in use or very recently used,
       and so it is in RAM.

     - inactive (UNIX): memory that is marked as not used.

     - buffers (BSD, Linux): cache for things like file system
       metadata.

     - cached (BSD, OSX): cache for various things.

     - wired (OSX, BSD): memory that is marked to always stay in RAM.
       It is never moved to disk.

     - shared (BSD): memory that may be simultaneously accessed by
       multiple processes.

    The sum of 'used' and 'available' does not necessarily equal total.
    On Windows 'available' and 'free' are the same.
    """
    global _TOTAL_PHYMEM
    ret = _psplatform.virtual_memory()
    # cached for later use in Process.memory_percent()
    _TOTAL_PHYMEM = ret.total
    return ret


def swap_memory():
    """Return system swap memory statistics as a namedtuple including
    the following fields:

     - total:   total swap memory in bytes
     - used:    used swap memory in bytes
     - free:    free swap memory in bytes
     - percent: the percentage usage
     - sin:     no. of bytes the system has swapped in from disk
                (cumulative)
     - sout:    no. of bytes the system has swapped out from disk
                (cumulative)

    'sin' and 'sout' on Windows are meaningless and always set to 0.
    """
    return _psplatform.swap_memory()


# =====================================================================
# --- disks/paritions related functions
# =====================================================================


def disk_usage(path):
    """Return disk usage statistics about the given path as a
    namedtuple including total, used and free space expressed in
    bytes plus the percentage usage.
    """
    return _psplatform.disk_usage(path)


def disk_partitions(all=False):
    """Return mounted partitions as a list of
    (device, mountpoint, fstype, opts) namedtuple.
    'opts' field is a raw string separated by commas indicating mount
    options which may vary depending on the platform.

    If "all" parameter is False return physical devices only and
    ignore all others.
    """
    return _psplatform.disk_partitions(all)


def disk_io_counters(perdisk=False):
    """Return system disk I/O statistics as a namedtuple including
    the following fields:

     - read_count:  number of reads
     - write_count: number of writes
     - read_bytes:  number of bytes read
     - write_bytes: number of bytes written
     - read_time:   time spent reading from disk (in milliseconds)
     - write_time:  time spent writing to disk (in milliseconds)

    If perdisk is True return the same information for every
    physical disk installed on the system as a dictionary
    with partition names as the keys and the namedutuple
    described above as the values.

    On recent Windows versions 'diskperf -y' command may need to be
    executed first otherwise this function won't find any disk.

    Raise RuntimeError when no physical disk can be found at all.
    """
    rawdict = _psplatform.disk_io_counters()
    if not rawdict:
        raise RuntimeError("couldn't find any physical disk")
    if perdisk:
        for disk, fields in rawdict.items():
            rawdict[disk] = _nt_sys_diskio(*fields)
        return rawdict
    else:
        # aggregate every disk's counters field by field
        return _nt_sys_diskio(*[sum(x) for x in zip(*rawdict.values())])


# =====================================================================
# --- network related functions
# =====================================================================


def net_io_counters(pernic=False):
    """Return network I/O statistics as a namedtuple including
    the following fields:

     - bytes_sent:   number of bytes sent
     - bytes_recv:   number of bytes received
     - packets_sent: number of packets sent
     - packets_recv: number of packets received
     - errin:        total number of errors while receiving
     - errout:       total number of errors while sending
     - dropin:       total number of incoming packets which were
                     dropped
     - dropout:      total number of outgoing packets which were
                     dropped (always 0 on OSX and BSD)

    If pernic is True return the same information for every
    network interface installed on the system as a dictionary
    with network interface names as the keys and the namedtuple
    described above as the values.

    Raise RuntimeError when no network interface can be found.
    """
    rawdict = _psplatform.net_io_counters()
    if not rawdict:
        raise RuntimeError("couldn't find any network interface")
    if pernic:
        for nic, fields in rawdict.items():
            rawdict[nic] = _nt_sys_netio(*fields)
        return rawdict
    else:
        # aggregate every NIC's counters field by field
        return _nt_sys_netio(*[sum(x) for x in zip(*rawdict.values())])


def net_connections(kind='inet'):
    """Return system-wide connections as a list of
    (fd, family, type, laddr, raddr, status, pid) namedtuples.
    In case of limited privileges 'fd' and 'pid' may be set to -1
    and None respectively.
    The 'kind' parameter filters for connections that fit the
    following criteria:

    Kind Value      Connections using
    inet            IPv4 and IPv6
    inet4           IPv4
    inet6           IPv6
    tcp             TCP
    tcp4            TCP over IPv4
    tcp6            TCP over IPv6
    udp             UDP
    udp4            UDP over IPv4
    udp6            UDP over IPv6
    unix            UNIX socket (both UDP and TCP protocols)
    all             the sum of all the possible families and protocols
    """
    return _psplatform.net_connections(kind)


# =====================================================================
# --- other system related functions
# =====================================================================


def boot_time():
    """Return the system boot time expressed in seconds since the
    epoch. This is also available as psutil.BOOT_TIME.
    """
    # Note: we are not caching this because it is subject to
    # system clock updates.
    return _psplatform.boot_time()


def users():
    """Return users currently connected on the system as a list of
    namedtuples including the following fields.

     - user: the name of the user
     - terminal: the tty or pseudo-tty associated with the user, if
       any.
     - host: the host name associated with the entry, if any.
     - started: the creation time as a floating point number expressed
       in seconds since the epoch.
    """
    return _psplatform.users()


# =====================================================================
# --- deprecated functions
# =====================================================================


@_deprecated(replacement="psutil.pids()")
def get_pid_list():
    return pids()


@_deprecated(replacement="list(process_iter())")
def get_process_list():
    return list(process_iter())


@_deprecated(replacement="psutil.users()")
def get_users():
    return users()


@_deprecated(replacement="psutil.virtual_memory()")
def phymem_usage():
    """Return the amount of total, used and free physical memory
    on the system in bytes plus the percentage usage.
    Deprecated; use psutil.virtual_memory() instead.
    """
    return virtual_memory()


@_deprecated(replacement="psutil.swap_memory()")
def virtmem_usage():
    return swap_memory()


@_deprecated(replacement="psutil.phymem_usage().free")
def avail_phymem():
    return phymem_usage().free


@_deprecated(replacement="psutil.phymem_usage().used")
def used_phymem():
    return phymem_usage().used


@_deprecated(replacement="psutil.virtmem_usage().total")
def total_virtmem():
    return virtmem_usage().total


@_deprecated(replacement="psutil.virtmem_usage().used")
def used_virtmem():
    return virtmem_usage().used


@_deprecated(replacement="psutil.virtmem_usage().free")
def avail_virtmem():
    return virtmem_usage().free


@_deprecated(replacement="psutil.net_io_counters()")
def network_io_counters(pernic=False):
    return net_io_counters(pernic)


def test():
    """List info of all currently running processes emulating ps aux
    output.
    """
    import datetime

    today_day = datetime.date.today()
    templ = "%-10s %5s %4s %4s %7s %7s %-13s %5s %7s %s"
    attrs = ['pid', 'cpu_percent', 'memory_percent', 'name',
             'cpu_times', 'create_time', 'memory_info']
    if _POSIX:
        attrs.append('uids')
        attrs.append('terminal')
    print(templ % ("USER", "PID", "%CPU", "%MEM", "VSZ", "RSS", "TTY",
                   "START", "TIME", "COMMAND"))
    for p in process_iter():
        try:
            pinfo = p.as_dict(attrs, ad_value='')
        except NoSuchProcess:
            pass
        else:
            if pinfo['create_time']:
                ctime = datetime.datetime.fromtimestamp(
                    pinfo['create_time'])
                if ctime.date() == today_day:
                    ctime = ctime.strftime("%H:%M")
                else:
                    ctime = ctime.strftime("%b%d")
            else:
                ctime = ''
            cputime = time.strftime(
                "%M:%S", time.localtime(sum(pinfo['cpu_times'])))
            try:
                user = p.username()
            except KeyError:
                # no entry in the password database for this uid
                if _POSIX:
                    if pinfo['uids']:
                        user = str(pinfo['uids'].real)
                    else:
                        user = ''
                else:
                    raise
            except Error:
                user = ''
            if _WINDOWS and '\\' in user:
                user = user.split('\\')[1]
            vms = pinfo['memory_info'] and \
                int(pinfo['memory_info'].vms / 1024) or '?'
            rss = pinfo['memory_info'] and \
                int(pinfo['memory_info'].rss / 1024) or '?'
            memp = pinfo['memory_percent'] and \
                round(pinfo['memory_percent'], 1) or '?'
            print(templ % (
                user[:10],
                pinfo['pid'],
                pinfo['cpu_percent'],
                memp,
                vms,
                rss,
                pinfo.get('terminal', '') or '?',
                ctime,
                cputime,
                pinfo['name'].strip() or '?'))


def _replace_module():
    """Dirty hack to replace the module object in order to access
    deprecated module constants, see:
    http://www.dr-josiah.com/2013/12/properties-on-python-modules.html
    """
    class ModuleWrapper(object):

        def __repr__(self):
            return repr(self._module)
        __str__ = __repr__

        @property
        def NUM_CPUS(self):
            msg = "NUM_CPUS constant is deprecated; use cpu_count() instead"
            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
            return cpu_count()

        @property
        def BOOT_TIME(self):
            msg = "BOOT_TIME constant is deprecated; use boot_time() instead"
            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
            return boot_time()

        @property
        def TOTAL_PHYMEM(self):
            msg = "TOTAL_PHYMEM constant is deprecated; " \
                  "use virtual_memory().total instead"
            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
            return virtual_memory().total

    mod = ModuleWrapper()
    mod.__dict__ = globals()
    mod._module = sys.modules[__name__]
    sys.modules[__name__] = mod


_replace_module()
# clean up temporary module-level names
del memoize, division, _replace_module
if sys.version_info < (3, 0):
    del num

if __name__ == "__main__":
    test()



================================================
FILE: Common/libpsutil/py2.6-glibc-2.12-pre/psutil/_common.py
================================================
# /usr/bin/env python

# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Common objects shared by all _ps* modules.""" from __future__ import division import errno import functools import os import socket import stat import warnings try: import threading except ImportError: import dummy_threading as threading from collections import namedtuple from socket import AF_INET, SOCK_STREAM, SOCK_DGRAM # --- constants AF_INET6 = getattr(socket, 'AF_INET6', None) AF_UNIX = getattr(socket, 'AF_UNIX', None) STATUS_RUNNING = "running" STATUS_SLEEPING = "sleeping" STATUS_DISK_SLEEP = "disk-sleep" STATUS_STOPPED = "stopped" STATUS_TRACING_STOP = "tracing-stop" STATUS_ZOMBIE = "zombie" STATUS_DEAD = "dead" STATUS_WAKE_KILL = "wake-kill" STATUS_WAKING = "waking" STATUS_IDLE = "idle" # BSD STATUS_LOCKED = "locked" # BSD STATUS_WAITING = "waiting" # BSD CONN_ESTABLISHED = "ESTABLISHED" CONN_SYN_SENT = "SYN_SENT" CONN_SYN_RECV = "SYN_RECV" CONN_FIN_WAIT1 = "FIN_WAIT1" CONN_FIN_WAIT2 = "FIN_WAIT2" CONN_TIME_WAIT = "TIME_WAIT" CONN_CLOSE = "CLOSE" CONN_CLOSE_WAIT = "CLOSE_WAIT" CONN_LAST_ACK = "LAST_ACK" CONN_LISTEN = "LISTEN" CONN_CLOSING = "CLOSING" CONN_NONE = "NONE" # --- functions def usage_percent(used, total, _round=None): """Calculate percentage usage of 'used' against 'total'.""" try: ret = (used / total) * 100 except ZeroDivisionError: ret = 0 if _round is not None: return round(ret, _round) else: return ret def memoize(fun): """A simple memoize decorator for functions supporting (hashable) positional arguments. It also provides a cache_clear() function for clearing the cache: >>> @memoize ... def foo() ... return 1 ... 
>>> foo() 1 >>> foo.cache_clear() >>> """ @functools.wraps(fun) def wrapper(*args, **kwargs): key = (args, frozenset(sorted(kwargs.items()))) lock.acquire() try: try: return cache[key] except KeyError: ret = cache[key] = fun(*args, **kwargs) finally: lock.release() return ret def cache_clear(): """Clear cache.""" lock.acquire() try: cache.clear() finally: lock.release() lock = threading.RLock() cache = {} wrapper.cache_clear = cache_clear return wrapper # http://code.activestate.com/recipes/577819-deprecated-decorator/ def deprecated(replacement=None): """A decorator which can be used to mark functions as deprecated.""" def outer(fun): msg = "psutil.%s is deprecated" % fun.__name__ if replacement is not None: msg += "; use %s instead" % replacement if fun.__doc__ is None: fun.__doc__ = msg @functools.wraps(fun) def inner(*args, **kwargs): warnings.warn(msg, category=DeprecationWarning, stacklevel=2) return fun(*args, **kwargs) return inner return outer def deprecated_method(replacement): """A decorator which can be used to mark a method as deprecated 'replcement' is the method name which will be called instead. 
""" def outer(fun): msg = "%s() is deprecated; use %s() instead" % ( fun.__name__, replacement) if fun.__doc__ is None: fun.__doc__ = msg @functools.wraps(fun) def inner(self, *args, **kwargs): warnings.warn(msg, category=DeprecationWarning, stacklevel=2) return getattr(self, replacement)(*args, **kwargs) return inner return outer def isfile_strict(path): """Same as os.path.isfile() but does not swallow EACCES / EPERM exceptions, see: http://mail.python.org/pipermail/python-dev/2012-June/120787.html """ try: st = os.stat(path) except OSError as err: if err.errno in (errno.EPERM, errno.EACCES): raise return False else: return stat.S_ISREG(st.st_mode) # --- Process.connections() 'kind' parameter mapping conn_tmap = { "all": ([AF_INET, AF_INET6, AF_UNIX], [SOCK_STREAM, SOCK_DGRAM]), "tcp": ([AF_INET, AF_INET6], [SOCK_STREAM]), "tcp4": ([AF_INET], [SOCK_STREAM]), "udp": ([AF_INET, AF_INET6], [SOCK_DGRAM]), "udp4": ([AF_INET], [SOCK_DGRAM]), "inet": ([AF_INET, AF_INET6], [SOCK_STREAM, SOCK_DGRAM]), "inet4": ([AF_INET], [SOCK_STREAM, SOCK_DGRAM]), "inet6": ([AF_INET6], [SOCK_STREAM, SOCK_DGRAM]), } if AF_INET6 is not None: conn_tmap.update({ "tcp6": ([AF_INET6], [SOCK_STREAM]), "udp6": ([AF_INET6], [SOCK_DGRAM]), }) if AF_UNIX is not None: conn_tmap.update({ "unix": ([AF_UNIX], [SOCK_STREAM, SOCK_DGRAM]), }) del AF_INET, AF_INET6, AF_UNIX, SOCK_STREAM, SOCK_DGRAM, socket # --- namedtuples for psutil.* system-related functions # psutil.swap_memory() sswap = namedtuple('sswap', ['total', 'used', 'free', 'percent', 'sin', 'sout']) # psutil.disk_usage() sdiskusage = namedtuple('sdiskusage', ['total', 'used', 'free', 'percent']) # psutil.disk_io_counters() sdiskio = namedtuple('sdiskio', ['read_count', 'write_count', 'read_bytes', 'write_bytes', 'read_time', 'write_time']) # psutil.disk_partitions() sdiskpart = namedtuple('sdiskpart', ['device', 'mountpoint', 'fstype', 'opts']) # psutil.net_io_counters() snetio = namedtuple('snetio', ['bytes_sent', 'bytes_recv', 
'packets_sent', 'packets_recv', 'errin', 'errout', 'dropin', 'dropout']) # psutil.users() suser = namedtuple('suser', ['name', 'terminal', 'host', 'started']) # psutil.net_connections() sconn = namedtuple('sconn', ['fd', 'family', 'type', 'laddr', 'raddr', 'status', 'pid']) # --- namedtuples for psutil.Process methods # psutil.Process.memory_info() pmem = namedtuple('pmem', ['rss', 'vms']) # psutil.Process.cpu_times() pcputimes = namedtuple('pcputimes', ['user', 'system']) # psutil.Process.open_files() popenfile = namedtuple('popenfile', ['path', 'fd']) # psutil.Process.threads() pthread = namedtuple('pthread', ['id', 'user_time', 'system_time']) # psutil.Process.uids() puids = namedtuple('puids', ['real', 'effective', 'saved']) # psutil.Process.gids() pgids = namedtuple('pgids', ['real', 'effective', 'saved']) # psutil.Process.io_counters() pio = namedtuple('pio', ['read_count', 'write_count', 'read_bytes', 'write_bytes']) # psutil.Process.ionice() pionice = namedtuple('pionice', ['ioclass', 'value']) # psutil.Process.ctx_switches() pctxsw = namedtuple('pctxsw', ['voluntary', 'involuntary']) # --- misc # backward compatibility layer for Process.connections() ntuple class pconn( namedtuple('pconn', ['fd', 'family', 'type', 'laddr', 'raddr', 'status'])): __slots__ = () @property def local_address(self): warnings.warn("'local_address' field is deprecated; use 'laddr'" "instead", category=DeprecationWarning, stacklevel=2) return self.laddr @property def remote_address(self): warnings.warn("'remote_address' field is deprecated; use 'raddr'" "instead", category=DeprecationWarning, stacklevel=2) return self.raddr ================================================ FILE: Common/libpsutil/py2.6-glibc-2.12-pre/psutil/_compat.py ================================================ #!/usr/bin/env python # Copyright (c) 2009, Giampaolo Rodola'. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. 
"""Module which provides compatibility with older Python versions.""" __all__ = ["PY3", "int", "long", "xrange", "exec_", "callable", "lru_cache"] import collections import functools import sys try: import __builtin__ except ImportError: import builtins as __builtin__ # py3 PY3 = sys.version_info[0] == 3 if PY3: int = int long = int xrange = range unicode = str basestring = str exec_ = getattr(__builtin__, "exec") else: int = int long = long xrange = xrange unicode = unicode basestring = basestring def exec_(code, globs=None, locs=None): if globs is None: frame = sys._getframe(1) globs = frame.f_globals if locs is None: locs = frame.f_locals del frame elif locs is None: locs = globs exec("""exec code in globs, locs""") # removed in 3.0, reintroduced in 3.2 try: callable = callable except NameError: def callable(obj): return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) # --- stdlib additions # py 3.2 functools.lru_cache # Taken from: http://code.activestate.com/recipes/578078 # Credit: Raymond Hettinger try: from functools import lru_cache except ImportError: try: from threading import RLock except ImportError: from dummy_threading import RLock _CacheInfo = collections.namedtuple( "CacheInfo", ["hits", "misses", "maxsize", "currsize"]) class _HashedSeq(list): __slots__ = 'hashvalue' def __init__(self, tup, hash=hash): self[:] = tup self.hashvalue = hash(tup) def __hash__(self): return self.hashvalue def _make_key(args, kwds, typed, kwd_mark=(object(), ), fasttypes=set((int, str, frozenset, type(None))), sorted=sorted, tuple=tuple, type=type, len=len): key = args if kwds: sorted_items = sorted(kwds.items()) key += kwd_mark for item in sorted_items: key += item if typed: key += tuple(type(v) for v in args) if kwds: key += tuple(type(v) for k, v in sorted_items) elif len(key) == 1 and type(key[0]) in fasttypes: return key[0] return _HashedSeq(key) def lru_cache(maxsize=100, typed=False): """Least-recently-used cache decorator, see: 
http://docs.python.org/3/library/functools.html#functools.lru_cache """ def decorating_function(user_function): cache = dict() stats = [0, 0] HITS, MISSES = 0, 1 make_key = _make_key cache_get = cache.get _len = len lock = RLock() root = [] root[:] = [root, root, None, None] nonlocal_root = [root] PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 if maxsize == 0: def wrapper(*args, **kwds): result = user_function(*args, **kwds) stats[MISSES] += 1 return result elif maxsize is None: def wrapper(*args, **kwds): key = make_key(args, kwds, typed) result = cache_get(key, root) if result is not root: stats[HITS] += 1 return result result = user_function(*args, **kwds) cache[key] = result stats[MISSES] += 1 return result else: def wrapper(*args, **kwds): if kwds or typed: key = make_key(args, kwds, typed) else: key = args lock.acquire() try: link = cache_get(key) if link is not None: root, = nonlocal_root link_prev, link_next, key, result = link link_prev[NEXT] = link_next link_next[PREV] = link_prev last = root[PREV] last[NEXT] = root[PREV] = link link[PREV] = last link[NEXT] = root stats[HITS] += 1 return result finally: lock.release() result = user_function(*args, **kwds) lock.acquire() try: root, = nonlocal_root if key in cache: pass elif _len(cache) >= maxsize: oldroot = root oldroot[KEY] = key oldroot[RESULT] = result root = nonlocal_root[0] = oldroot[NEXT] oldkey = root[KEY] root[KEY] = root[RESULT] = None del cache[oldkey] cache[key] = oldroot else: last = root[PREV] link = [last, root, key, result] last[NEXT] = root[PREV] = cache[key] = link stats[MISSES] += 1 finally: lock.release() return result def cache_info(): """Report cache statistics""" lock.acquire() try: return _CacheInfo(stats[HITS], stats[MISSES], maxsize, len(cache)) finally: lock.release() def cache_clear(): """Clear the cache and cache statistics""" lock.acquire() try: cache.clear() root = nonlocal_root[0] root[:] = [root, root, None, None] stats[:] = [0, 0] finally: lock.release() wrapper.__wrapped__ = 
user_function wrapper.cache_info = cache_info wrapper.cache_clear = cache_clear return functools.update_wrapper(wrapper, user_function) return decorating_function ================================================ FILE: Common/libpsutil/py2.6-glibc-2.12-pre/psutil/_psbsd.py ================================================ #!/usr/bin/env python # Copyright (c) 2009, Giampaolo Rodola'. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. """FreeBSD platform implementation.""" import errno import functools import os import sys from collections import namedtuple from psutil import _common from psutil import _psposix from psutil._common import conn_tmap, usage_percent import _psutil_bsd as cext import _psutil_posix __extra__all__ = [] # --- constants PROC_STATUSES = { cext.SSTOP: _common.STATUS_STOPPED, cext.SSLEEP: _common.STATUS_SLEEPING, cext.SRUN: _common.STATUS_RUNNING, cext.SIDL: _common.STATUS_IDLE, cext.SWAIT: _common.STATUS_WAITING, cext.SLOCK: _common.STATUS_LOCKED, cext.SZOMB: _common.STATUS_ZOMBIE, } TCP_STATUSES = { cext.TCPS_ESTABLISHED: _common.CONN_ESTABLISHED, cext.TCPS_SYN_SENT: _common.CONN_SYN_SENT, cext.TCPS_SYN_RECEIVED: _common.CONN_SYN_RECV, cext.TCPS_FIN_WAIT_1: _common.CONN_FIN_WAIT1, cext.TCPS_FIN_WAIT_2: _common.CONN_FIN_WAIT2, cext.TCPS_TIME_WAIT: _common.CONN_TIME_WAIT, cext.TCPS_CLOSED: _common.CONN_CLOSE, cext.TCPS_CLOSE_WAIT: _common.CONN_CLOSE_WAIT, cext.TCPS_LAST_ACK: _common.CONN_LAST_ACK, cext.TCPS_LISTEN: _common.CONN_LISTEN, cext.TCPS_CLOSING: _common.CONN_CLOSING, cext.PSUTIL_CONN_NONE: _common.CONN_NONE, } PAGESIZE = os.sysconf("SC_PAGE_SIZE") # extend base mem ntuple with BSD-specific memory metrics svmem = namedtuple( 'svmem', ['total', 'available', 'percent', 'used', 'free', 'active', 'inactive', 'buffers', 'cached', 'shared', 'wired']) scputimes = namedtuple( 'scputimes', ['user', 'nice', 'system', 'idle', 'irq']) pextmem = namedtuple('pextmem', ['rss', 
'vms', 'text', 'data', 'stack']) pmmap_grouped = namedtuple( 'pmmap_grouped', 'path rss, private, ref_count, shadow_count') pmmap_ext = namedtuple( 'pmmap_ext', 'addr, perms path rss, private, ref_count, shadow_count') # set later from __init__.py NoSuchProcess = None AccessDenied = None TimeoutExpired = None def virtual_memory(): """System virtual memory as a namedtuple.""" mem = cext.virtual_mem() total, free, active, inactive, wired, cached, buffers, shared = mem avail = inactive + cached + free used = active + wired + cached percent = usage_percent((total - avail), total, _round=1) return svmem(total, avail, percent, used, free, active, inactive, buffers, cached, shared, wired) def swap_memory(): """System swap memory as (total, used, free, sin, sout) namedtuple.""" total, used, free, sin, sout = [x * PAGESIZE for x in cext.swap_mem()] percent = usage_percent(used, total, _round=1) return _common.sswap(total, used, free, percent, sin, sout) def cpu_times(): """Return system per-CPU times as a namedtuple""" user, nice, system, idle, irq = cext.cpu_times() return scputimes(user, nice, system, idle, irq) if hasattr(cext, "per_cpu_times"): def per_cpu_times(): """Return system CPU times as a namedtuple""" ret = [] for cpu_t in cext.per_cpu_times(): user, nice, system, idle, irq = cpu_t item = scputimes(user, nice, system, idle, irq) ret.append(item) return ret else: # XXX # Ok, this is very dirty. # On FreeBSD < 8 we cannot gather per-cpu information, see: # https://github.com/giampaolo/psutil/issues/226 # If num cpus > 1, on first call we return single cpu times to avoid a # crash at psutil import time. 
# Next calls will fail with NotImplementedError def per_cpu_times(): if cpu_count_logical() == 1: return [cpu_times()] if per_cpu_times.__called__: raise NotImplementedError("supported only starting from FreeBSD 8") per_cpu_times.__called__ = True return [cpu_times()] per_cpu_times.__called__ = False def cpu_count_logical(): """Return the number of logical CPUs in the system.""" return cext.cpu_count_logical() def cpu_count_physical(): """Return the number of physical CPUs in the system.""" # From the C module we'll get an XML string similar to this: # http://manpages.ubuntu.com/manpages/precise/man4/smp.4freebsd.html # We may get None in case "sysctl kern.sched.topology_spec" # is not supported on this BSD version, in which case we'll mimic # os.cpu_count() and return None. s = cext.cpu_count_phys() if s is not None: # get rid of padding chars appended at the end of the string index = s.rfind("") if index != -1: s = s[:index + 9] if sys.version_info >= (2, 5): import xml.etree.ElementTree as ET root = ET.fromstring(s) return len(root.findall('group/children/group/cpu')) or None else: s = s[s.find(''):] return s.count("> if err.errno in (errno.EINVAL, errno.EDEADLK): allcpus = tuple(range(len(per_cpu_times()))) for cpu in cpus: if cpu not in allcpus: raise ValueError("invalid CPU #%i (choose between %s)" % (cpu, allcpus)) raise ================================================ FILE: Common/libpsutil/py2.6-glibc-2.12-pre/psutil/_pslinux.py ================================================ #!/usr/bin/env python # Copyright (c) 2009, Giampaolo Rodola'. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. 
"""Linux platform implementation.""" from __future__ import division import base64 import errno import functools import os import re import socket import struct import sys import warnings from collections import namedtuple, defaultdict from psutil import _common from psutil import _psposix from psutil._common import (isfile_strict, usage_percent, deprecated) from psutil._compat import PY3 import _psutil_linux as cext import _psutil_posix __extra__all__ = [ # io prio constants "IOPRIO_CLASS_NONE", "IOPRIO_CLASS_RT", "IOPRIO_CLASS_BE", "IOPRIO_CLASS_IDLE", # connection status constants "CONN_ESTABLISHED", "CONN_SYN_SENT", "CONN_SYN_RECV", "CONN_FIN_WAIT1", "CONN_FIN_WAIT2", "CONN_TIME_WAIT", "CONN_CLOSE", "CONN_CLOSE_WAIT", "CONN_LAST_ACK", "CONN_LISTEN", "CONN_CLOSING", # other "phymem_buffers", "cached_phymem"] # --- constants HAS_PRLIMIT = hasattr(cext, "linux_prlimit") # RLIMIT_* constants, not guaranteed to be present on all kernels if HAS_PRLIMIT: for name in dir(cext): if name.startswith('RLIM'): __extra__all__.append(name) # Number of clock ticks per second CLOCK_TICKS = os.sysconf("SC_CLK_TCK") PAGESIZE = os.sysconf("SC_PAGE_SIZE") BOOT_TIME = None # set later DEFAULT_ENCODING = sys.getdefaultencoding() # ioprio_* constants http://linux.die.net/man/2/ioprio_get IOPRIO_CLASS_NONE = 0 IOPRIO_CLASS_RT = 1 IOPRIO_CLASS_BE = 2 IOPRIO_CLASS_IDLE = 3 # taken from /fs/proc/array.c PROC_STATUSES = { "R": _common.STATUS_RUNNING, "S": _common.STATUS_SLEEPING, "D": _common.STATUS_DISK_SLEEP, "T": _common.STATUS_STOPPED, "t": _common.STATUS_TRACING_STOP, "Z": _common.STATUS_ZOMBIE, "X": _common.STATUS_DEAD, "x": _common.STATUS_DEAD, "K": _common.STATUS_WAKE_KILL, "W": _common.STATUS_WAKING } # http://students.mimuw.edu.pl/lxr/source/include/net/tcp_states.h TCP_STATUSES = { "01": _common.CONN_ESTABLISHED, "02": _common.CONN_SYN_SENT, "03": _common.CONN_SYN_RECV, "04": _common.CONN_FIN_WAIT1, "05": _common.CONN_FIN_WAIT2, "06": _common.CONN_TIME_WAIT, "07": 
_common.CONN_CLOSE, "08": _common.CONN_CLOSE_WAIT, "09": _common.CONN_LAST_ACK, "0A": _common.CONN_LISTEN, "0B": _common.CONN_CLOSING } # set later from __init__.py NoSuchProcess = None AccessDenied = None TimeoutExpired = None # --- named tuples def _get_cputimes_fields(): """Return a namedtuple of variable fields depending on the CPU times available on this Linux kernel version which may be: (user, nice, system, idle, iowait, irq, softirq, [steal, [guest, [guest_nice]]]) """ with open('/proc/stat', 'rb') as f: values = f.readline().split()[1:] fields = ['user', 'nice', 'system', 'idle', 'iowait', 'irq', 'softirq'] vlen = len(values) if vlen >= 8: # Linux >= 2.6.11 fields.append('steal') if vlen >= 9: # Linux >= 2.6.24 fields.append('guest') if vlen >= 10: # Linux >= 3.2.0 fields.append('guest_nice') return fields scputimes = namedtuple('scputimes', _get_cputimes_fields()) svmem = namedtuple( 'svmem', ['total', 'available', 'percent', 'used', 'free', 'active', 'inactive', 'buffers', 'cached']) pextmem = namedtuple('pextmem', 'rss vms shared text lib data dirty') pmmap_grouped = namedtuple( 'pmmap_grouped', ['path', 'rss', 'size', 'pss', 'shared_clean', 'shared_dirty', 'private_clean', 'private_dirty', 'referenced', 'anonymous', 'swap']) pmmap_ext = namedtuple( 'pmmap_ext', 'addr perms ' + ' '.join(pmmap_grouped._fields)) # --- system memory def virtual_memory(): total, free, buffers, shared, _, _ = cext.linux_sysinfo() cached = active = inactive = None with open('/proc/meminfo', 'rb') as f: for line in f: if line.startswith(b"Cached:"): cached = int(line.split()[1]) * 1024 elif line.startswith(b"Active:"): active = int(line.split()[1]) * 1024 elif line.startswith(b"Inactive:"): inactive = int(line.split()[1]) * 1024 if (cached is not None and active is not None and inactive is not None): break else: # we might get here when dealing with exotic Linux flavors, see: # https://github.com/giampaolo/psutil/issues/313 msg = "'cached', 'active' and 'inactive' memory stats 
couldn't " \ "be determined and were set to 0" warnings.warn(msg, RuntimeWarning) cached = active = inactive = 0 avail = free + buffers + cached used = total - free percent = usage_percent((total - avail), total, _round=1) return svmem(total, avail, percent, used, free, active, inactive, buffers, cached) def swap_memory(): _, _, _, _, total, free = cext.linux_sysinfo() used = total - free percent = usage_percent(used, total, _round=1) # get pgin/pgouts with open("/proc/vmstat", "rb") as f: sin = sout = None for line in f: # values are expressed in 4 kilo bytes, we want bytes instead if line.startswith(b'pswpin'): sin = int(line.split(b' ')[1]) * 4 * 1024 elif line.startswith(b'pswpout'): sout = int(line.split(b' ')[1]) * 4 * 1024 if sin is not None and sout is not None: break else: # we might get here when dealing with exotic Linux flavors, see: # https://github.com/giampaolo/psutil/issues/313 msg = "'sin' and 'sout' swap memory stats couldn't " \ "be determined and were set to 0" warnings.warn(msg, RuntimeWarning) sin = sout = 0 return _common.sswap(total, used, free, percent, sin, sout) @deprecated(replacement='psutil.virtual_memory().cached') def cached_phymem(): return virtual_memory().cached @deprecated(replacement='psutil.virtual_memory().buffers') def phymem_buffers(): return virtual_memory().buffers # --- CPUs def cpu_times(): """Return a named tuple representing the following system-wide CPU times: (user, nice, system, idle, iowait, irq, softirq [steal, [guest, [guest_nice]]]) Last 3 fields may not be available on all Linux kernel versions. """ with open('/proc/stat', 'rb') as f: values = f.readline().split() fields = values[1:len(scputimes._fields) + 1] fields = [float(x) / CLOCK_TICKS for x in fields] return scputimes(*fields) def per_cpu_times(): """Return a list of namedtuple representing the CPU times for every CPU available on the system. 
""" cpus = [] with open('/proc/stat', 'rb') as f: # get rid of the first line which refers to system wide CPU stats f.readline() for line in f: if line.startswith(b'cpu'): values = line.split() fields = values[1:len(scputimes._fields) + 1] fields = [float(x) / CLOCK_TICKS for x in fields] entry = scputimes(*fields) cpus.append(entry) return cpus def cpu_count_logical(): """Return the number of logical CPUs in the system.""" try: return os.sysconf("SC_NPROCESSORS_ONLN") except ValueError: # as a second fallback we try to parse /proc/cpuinfo num = 0 with open('/proc/cpuinfo', 'rb') as f: for line in f: if line.lower().startswith(b'processor'): num += 1 # unknown format (e.g. amrel/sparc architectures), see: # https://github.com/giampaolo/psutil/issues/200 # try to parse /proc/stat as a last resort if num == 0: search = re.compile('cpu\d') with open('/proc/stat', 'rt') as f: for line in f: line = line.split(' ')[0] if search.match(line): num += 1 if num == 0: # mimic os.cpu_count() return None return num def cpu_count_physical(): """Return the number of physical CPUs in the system.""" with open('/proc/cpuinfo', 'rb') as f: found = set() for line in f: if line.lower().startswith(b'physical id'): found.add(line.strip()) # mimic os.cpu_count() return len(found) if found else None # --- other system functions def users(): """Return currently connected users as a list of namedtuples.""" retlist = [] rawlist = cext.users() for item in rawlist: user, tty, hostname, tstamp, user_process = item # note: the underlying C function includes entries about # system boot, run level and others. We might want # to use them in the future. 
if not user_process: continue if hostname == ':0.0': hostname = 'localhost' nt = _common.suser(user, tty or None, hostname, tstamp) retlist.append(nt) return retlist def boot_time(): """Return the system boot time expressed in seconds since the epoch.""" global BOOT_TIME with open('/proc/stat', 'rb') as f: for line in f: if line.startswith(b'btime'): ret = float(line.strip().split()[1]) BOOT_TIME = ret return ret raise RuntimeError("line 'btime' not found") # --- processes def pids(): """Returns a list of PIDs currently running on the system.""" return [int(x) for x in os.listdir(b'/proc') if x.isdigit()] def pid_exists(pid): """Check For the existence of a unix pid.""" return _psposix.pid_exists(pid) # --- network class Connections: """A wrapper on top of /proc/net/* files, retrieving per-process and system-wide open connections (TCP, UDP, UNIX) similarly to "netstat -an". Note: in case of UNIX sockets we're only able to determine the local endpoint/path, not the one it's connected to. According to [1] it would be possible but not easily. 
[1] http://serverfault.com/a/417946 """ def __init__(self): tcp4 = ("tcp", socket.AF_INET, socket.SOCK_STREAM) tcp6 = ("tcp6", socket.AF_INET6, socket.SOCK_STREAM) udp4 = ("udp", socket.AF_INET, socket.SOCK_DGRAM) udp6 = ("udp6", socket.AF_INET6, socket.SOCK_DGRAM) unix = ("unix", socket.AF_UNIX, None) self.tmap = { "all": (tcp4, tcp6, udp4, udp6, unix), "tcp": (tcp4, tcp6), "tcp4": (tcp4,), "tcp6": (tcp6,), "udp": (udp4, udp6), "udp4": (udp4,), "udp6": (udp6,), "unix": (unix,), "inet": (tcp4, tcp6, udp4, udp6), "inet4": (tcp4, udp4), "inet6": (tcp6, udp6), } def get_proc_inodes(self, pid): inodes = defaultdict(list) for fd in os.listdir("/proc/%s/fd" % pid): try: inode = os.readlink("/proc/%s/fd/%s" % (pid, fd)) except OSError: # TODO: need comment here continue else: if inode.startswith('socket:['): # the process is using a socket inode = inode[8:][:-1] inodes[inode].append((pid, int(fd))) return inodes def get_all_inodes(self): inodes = {} for pid in pids(): try: inodes.update(self.get_proc_inodes(pid)) except OSError as err: # os.listdir() is gonna raise a lot of access denied # exceptions in case of unprivileged user; that's fine # as we'll just end up returning a connection with PID # and fd set to None anyway. # Both netstat -an and lsof does the same so it's # unlikely we can do any better. # ENOENT just means a PID disappeared on us. if err.errno not in ( errno.ENOENT, errno.ESRCH, errno.EPERM, errno.EACCES): raise return inodes def decode_address(self, addr, family): """Accept an "ip:port" address as displayed in /proc/net/* and convert it into a human readable form, like: "0500000A:0016" -> ("10.0.0.5", 22) "0000000000000000FFFF00000100007F:9E49" -> ("::ffff:127.0.0.1", 40521) The IP address portion is a little or big endian four-byte hexadecimal number; that is, the least significant byte is listed first, so we need to reverse the order of the bytes to convert it to an IP address. The port is represented as a two-byte hexadecimal number. 
Reference: http://linuxdevcenter.com/pub/a/linux/2000/11/16/LinuxAdmin.html """ ip, port = addr.split(':') port = int(port, 16) # this usually refers to a local socket in listen mode with # no end-points connected if not port: return () if PY3: ip = ip.encode('ascii') if family == socket.AF_INET: # see: https://github.com/giampaolo/psutil/issues/201 if sys.byteorder == 'little': ip = socket.inet_ntop(family, base64.b16decode(ip)[::-1]) else: ip = socket.inet_ntop(family, base64.b16decode(ip)) else: # IPv6 # old version - let's keep it, just in case... # ip = ip.decode('hex') # return socket.inet_ntop(socket.AF_INET6, # ''.join(ip[i:i+4][::-1] for i in xrange(0, 16, 4))) ip = base64.b16decode(ip) # see: https://github.com/giampaolo/psutil/issues/201 if sys.byteorder == 'little': ip = socket.inet_ntop( socket.AF_INET6, struct.pack('>4I', *struct.unpack('<4I', ip))) else: ip = socket.inet_ntop( socket.AF_INET6, struct.pack('<4I', *struct.unpack('<4I', ip))) return (ip, port) def process_inet(self, file, family, type_, inodes, filter_pid=None): """Parse /proc/net/tcp* and /proc/net/udp* files.""" if file.endswith('6') and not os.path.exists(file): # IPv6 not supported return with open(file, 'rt') as f: f.readline() # skip the first line for line in f: _, laddr, raddr, status, _, _, _, _, _, inode = \ line.split()[:10] if inode in inodes: # # We assume inet sockets are unique, so we error # # out if there are multiple references to the # # same inode. We won't do this for UNIX sockets. 
# if len(inodes[inode]) > 1 and family != socket.AF_UNIX: # raise ValueError("ambiguos inode with multiple " # "PIDs references") pid, fd = inodes[inode][0] else: pid, fd = None, -1 if filter_pid is not None and filter_pid != pid: continue else: if type_ == socket.SOCK_STREAM: status = TCP_STATUSES[status] else: status = _common.CONN_NONE laddr = self.decode_address(laddr, family) raddr = self.decode_address(raddr, family) yield (fd, family, type_, laddr, raddr, status, pid) def process_unix(self, file, family, inodes, filter_pid=None): """Parse /proc/net/unix files.""" with open(file, 'rt') as f: f.readline() # skip the first line for line in f: tokens = line.split() _, _, _, _, type_, _, inode = tokens[0:7] if inode in inodes: # With UNIX sockets we can have a single inode # referencing many file descriptors. pairs = inodes[inode] else: pairs = [(None, -1)] for pid, fd in pairs: if filter_pid is not None and filter_pid != pid: continue else: if len(tokens) == 8: path = tokens[-1] else: path = "" type_ = int(type_) raddr = None status = _common.CONN_NONE yield (fd, family, type_, path, raddr, status, pid) def retrieve(self, kind, pid=None): if kind not in self.tmap: raise ValueError("invalid %r kind argument; choose between %s" % (kind, ', '.join([repr(x) for x in self.tmap]))) if pid is not None: inodes = self.get_proc_inodes(pid) if not inodes: # no connections for this process return [] else: inodes = self.get_all_inodes() ret = [] for f, family, type_ in self.tmap[kind]: if family in (socket.AF_INET, socket.AF_INET6): ls = self.process_inet( "/proc/net/%s" % f, family, type_, inodes, filter_pid=pid) else: ls = self.process_unix( "/proc/net/%s" % f, family, inodes, filter_pid=pid) for fd, family, type_, laddr, raddr, status, bound_pid in ls: if pid: conn = _common.pconn(fd, family, type_, laddr, raddr, status) else: conn = _common.sconn(fd, family, type_, laddr, raddr, status, bound_pid) ret.append(conn) return ret _connections = Connections() def 
net_connections(kind='inet'): """Return system-wide open connections.""" return _connections.retrieve(kind) def net_io_counters(): """Return network I/O statistics for every network interface installed on the system as a dict of raw tuples. """ with open("/proc/net/dev", "rt") as f: lines = f.readlines() retdict = {} for line in lines[2:]: colon = line.rfind(':') assert colon > 0, repr(line) name = line[:colon].strip() fields = line[colon + 1:].strip().split() bytes_recv = int(fields[0]) packets_recv = int(fields[1]) errin = int(fields[2]) dropin = int(fields[3]) bytes_sent = int(fields[8]) packets_sent = int(fields[9]) errout = int(fields[10]) dropout = int(fields[11]) retdict[name] = (bytes_sent, bytes_recv, packets_sent, packets_recv, errin, errout, dropin, dropout) return retdict # --- disks def disk_io_counters(): """Return disk I/O statistics for every disk installed on the system as a dict of raw tuples. """ # man iostat states that sectors are equivalent with blocks and # have a size of 512 bytes since 2.4 kernels. This value is # needed to calculate the amount of disk I/O in bytes. SECTOR_SIZE = 512 # determine partitions we want to look for partitions = [] with open("/proc/partitions", "rt") as f: lines = f.readlines()[2:] for line in reversed(lines): _, _, _, name = line.split() if name[-1].isdigit(): # we're dealing with a partition (e.g. 'sda1'); 'sda' will # also be around but we want to omit it partitions.append(name) else: if not partitions or not partitions[-1].startswith(name): # we're dealing with a disk entity for which no # partitions have been defined (e.g. 
            # 'sda' but 'sda1' was not around), see:
            # https://github.com/giampaolo/psutil/issues/338
            partitions.append(name)
    #
    # map disk name -> (reads, writes, rbytes, wbytes, rtime, wtime)
    retdict = {}
    with open("/proc/diskstats", "rt") as f:
        lines = f.readlines()
    for line in lines:
        # http://www.mjmwired.net/kernel/Documentation/iostats.txt
        fields = line.split()
        if len(fields) > 7:
            _, _, name, reads, _, rbytes, rtime, writes, _, wbytes, wtime = \
                fields[:11]
        else:
            # from kernel 2.6.0 to 2.6.25
            _, _, name, reads, rbytes, writes, wbytes = fields
            rtime, wtime = 0, 0
        if name in partitions:
            rbytes = int(rbytes) * SECTOR_SIZE
            wbytes = int(wbytes) * SECTOR_SIZE
            reads = int(reads)
            writes = int(writes)
            rtime = int(rtime)
            wtime = int(wtime)
            retdict[name] = (reads, writes, rbytes, wbytes, rtime, wtime)
    return retdict


def disk_partitions(all=False):
    """Return mounted disk partitions as a list of nameduples"""
    # physical filesystems are the /proc/filesystems entries NOT
    # marked "nodev"
    phydevs = []
    with open("/proc/filesystems", "r") as f:
        for line in f:
            if not line.startswith("nodev"):
                phydevs.append(line.strip())

    retlist = []
    partitions = cext.disk_partitions()
    for partition in partitions:
        device, mountpoint, fstype, opts = partition
        if device == 'none':
            device = ''
        if not all:
            if device == '' or fstype not in phydevs:
                continue
        ntuple = _common.sdiskpart(device, mountpoint, fstype, opts)
        retlist.append(ntuple)
    return retlist


disk_usage = _psposix.disk_usage


# --- decorators


def wrap_exceptions(fun):
    """Decorator which translates bare OSError and IOError exceptions
    into NoSuchProcess and AccessDenied.
    """
    @functools.wraps(fun)
    def wrapper(self, *args, **kwargs):
        try:
            return fun(self, *args, **kwargs)
        except EnvironmentError as err:
            # support for private module import
            if NoSuchProcess is None or AccessDenied is None:
                raise
            # ENOENT (no such file or directory) gets raised on open().
            # ESRCH (no such process) can get raised on read() if
            # process is gone in meantime.
            if err.errno in (errno.ENOENT, errno.ESRCH):
                raise NoSuchProcess(self.pid, self._name)
            if err.errno in (errno.EPERM, errno.EACCES):
                raise AccessDenied(self.pid, self._name)
            raise
    return wrapper


class Process(object):
    """Linux process implementation."""

    __slots__ = ["pid", "_name"]

    def __init__(self, pid):
        self.pid = pid
        self._name = None

    @wrap_exceptions
    def name(self):
        """Process name as reported by /proc/<pid>/stat (second field)."""
        fname = "/proc/%s/stat" % self.pid
        kw = dict(encoding=DEFAULT_ENCODING) if PY3 else dict()
        with open(fname, "rt", **kw) as f:
            # XXX - gets changed later and probably needs refactoring
            return f.read().split(' ')[1].replace('(', '').replace(')', '')

    def exe(self):
        # NOTE: deliberately NOT decorated with @wrap_exceptions;
        # the errno handling below is done by hand because the
        # ENOENT case is ambiguous (see comment).
        try:
            exe = os.readlink("/proc/%s/exe" % self.pid)
        except (OSError, IOError) as err:
            if err.errno in (errno.ENOENT, errno.ESRCH):
                # no such file error; might be raised also if the
                # path actually exists for system processes with
                # low pids (about 0-20)
                if os.path.lexists("/proc/%s" % self.pid):
                    return ""
                else:
                    # ok, it is a process which has gone away
                    raise NoSuchProcess(self.pid, self._name)
            if err.errno in (errno.EPERM, errno.EACCES):
                raise AccessDenied(self.pid, self._name)
            raise

        # readlink() might return paths containing null bytes ('\x00').
        # Certain names have ' (deleted)' appended. Usually this is
        # bogus as the file actually exists. Either way that's not
        # important as we don't want to discriminate executables which
        # have been deleted.
        exe = exe.split('\x00')[0]
        if exe.endswith(' (deleted)') and not os.path.exists(exe):
            exe = exe[:-10]
        return exe

    @wrap_exceptions
    def cmdline(self):
        """Command line as a list of strings (\\x00-separated on disk)."""
        fname = "/proc/%s/cmdline" % self.pid
        kw = dict(encoding=DEFAULT_ENCODING) if PY3 else dict()
        with open(fname, "rt", **kw) as f:
            return [x for x in f.read().split('\x00') if x]

    @wrap_exceptions
    def terminal(self):
        """Terminal device associated with the process, or None."""
        tmap = _psposix._get_terminal_map()
        with open("/proc/%s/stat" % self.pid, 'rb') as f:
            tty_nr = int(f.read().split(b' ')[6])
        try:
            return tmap[tty_nr]
        except KeyError:
            return None

    # /proc/<pid>/io requires kernel >= 2.6.20 with CONFIG_TASK_IO_ACCOUNTING
    if os.path.exists('/proc/%s/io' % os.getpid()):
        @wrap_exceptions
        def io_counters(self):
            fname = "/proc/%s/io" % self.pid
            with open(fname, 'rb') as f:
                rcount = wcount = rbytes = wbytes = None
                for line in f:
                    if rcount is None and line.startswith(b"syscr"):
                        rcount = int(line.split()[1])
                    elif wcount is None and line.startswith(b"syscw"):
                        wcount = int(line.split()[1])
                    elif rbytes is None and line.startswith(b"read_bytes"):
                        rbytes = int(line.split()[1])
                    elif wbytes is None and line.startswith(b"write_bytes"):
                        wbytes = int(line.split()[1])
                for x in (rcount, wcount, rbytes, wbytes):
                    if x is None:
                        raise NotImplementedError(
                            "couldn't read all necessary info from %r" % fname)
                return _common.pio(rcount, wcount, rbytes, wbytes)
    else:
        def io_counters(self):
            raise NotImplementedError("couldn't find /proc/%s/io (kernel "
                                      "too old?)" % self.pid)

    @wrap_exceptions
    def cpu_times(self):
        with open("/proc/%s/stat" % self.pid, 'rb') as f:
            st = f.read().strip()
        # ignore the first two values ("pid (exe)")
        st = st[st.find(b')') + 2:]
        values = st.split(b' ')
        utime = float(values[11]) / CLOCK_TICKS
        stime = float(values[12]) / CLOCK_TICKS
        return _common.pcputimes(utime, stime)

    @wrap_exceptions
    def wait(self, timeout=None):
        try:
            return _psposix.wait_pid(self.pid, timeout)
        except _psposix.TimeoutExpired:
            # support for private module import
            if TimeoutExpired is None:
                raise
            raise TimeoutExpired(timeout, self.pid, self._name)

    @wrap_exceptions
    def create_time(self):
        with open("/proc/%s/stat" % self.pid, 'rb') as f:
            st = f.read().strip()
        # ignore the first two values ("pid (exe)")
        st = st[st.rfind(b')') + 2:]
        values = st.split(b' ')
        # According to documentation, starttime is in field 21 and the
        # unit is jiffies (clock ticks).
        # We first divide it for clock ticks and then add uptime returning
        # seconds since the epoch, in UTC.
        # Also use cached value if available.
        bt = BOOT_TIME or boot_time()
        return (float(values[19]) / CLOCK_TICKS) + bt

    @wrap_exceptions
    def memory_info(self):
        # /proc/<pid>/statm reports sizes in pages
        with open("/proc/%s/statm" % self.pid, 'rb') as f:
            vms, rss = f.readline().split()[:2]
            return _common.pmem(int(rss) * PAGESIZE,
                                int(vms) * PAGESIZE)

    @wrap_exceptions
    def memory_info_ex(self):
        #  ============================================================
        # | FIELD  | DESCRIPTION                         | AKA  | TOP  |
        #  ============================================================
        # | rss    | resident set size                   |      | RES  |
        # | vms    | total program size                  | size | VIRT |
        # | shared | shared pages (from shared mappings) |      | SHR  |
        # | text   | text ('code')                       | trs  | CODE |
        # | lib    | library (unused in Linux 2.6)       | lrs  |      |
        # | data   | data + stack                        | drs  | DATA |
        # | dirty  | dirty pages (unused in Linux 2.6)   | dt   |      |
        #  ============================================================
        with open("/proc/%s/statm" % self.pid, "rb") as f:
            vms, rss, shared, text, lib, data, dirty = \
                [int(x) * PAGESIZE for x in f.readline().split()[:7]]
        return pextmem(rss, vms, shared, text, lib, data, dirty)

    # /proc/<pid>/smaps requires kernel >= 2.6.14 and CONFIG_MMU
    if os.path.exists('/proc/%s/smaps' % os.getpid()):
        @wrap_exceptions
        def memory_maps(self):
            """Return process's mapped memory regions as a list of
            nameduples.  Fields are explained in 'man proc'; here is
            an updated (Apr 2012) version: http://goo.gl/fmebo
            """
            with open("/proc/%s/smaps" % self.pid, "rt") as f:
                first_line = f.readline()
                current_block = [first_line]

                def get_blocks():
                    data = {}
                    for line in f:
                        fields = line.split(None, 5)
                        if not fields[0].endswith(':'):
                            # new block section
                            yield (current_block.pop(), data)
                            current_block.append(line)
                        else:
                            try:
                                data[fields[0]] = int(fields[1]) * 1024
                            except ValueError:
                                if fields[0].startswith('VmFlags:'):
                                    # see issue #369
                                    continue
                                else:
                                    raise ValueError("don't know how to inte"
                                                     "rpret line %r" % line)
                    yield (current_block.pop(), data)

                ls = []
                if first_line:  # smaps file can be empty
                    for header, data in get_blocks():
                        hfields = header.split(None, 5)
                        try:
                            addr, perms, offset, dev, inode, path = hfields
                        except ValueError:
                            addr, perms, offset, dev, inode, path = \
                                hfields + ['']
                        if not path:
                            path = '[anon]'
                        else:
                            path = path.strip()
                        ls.append((
                            addr, perms, path,
                            data['Rss:'],
                            data.get('Size:', 0),
                            data.get('Pss:', 0),
                            data.get('Shared_Clean:', 0),
                            data.get('Shared_Dirty:', 0),
                            data.get('Private_Clean:', 0),
                            data.get('Private_Dirty:', 0),
                            data.get('Referenced:', 0),
                            data.get('Anonymous:', 0),
                            data.get('Swap:', 0)
                        ))
                return ls
    else:
        def memory_maps(self):
            msg = "couldn't find /proc/%s/smaps; kernel < 2.6.14 or " \
                  "CONFIG_MMU kernel configuration option is not enabled" \
                  % self.pid
            raise NotImplementedError(msg)

    @wrap_exceptions
    def cwd(self):
        # readlink() might return paths containing null bytes causing
        # problems when used with other fs-related functions (os.*,
        # open(), ...)
        path = os.readlink("/proc/%s/cwd" % self.pid)
        return path.replace('\x00', '')

    @wrap_exceptions
    def num_ctx_switches(self):
        vol = unvol = None
        with open("/proc/%s/status" % self.pid, "rb") as f:
            for line in f:
                if line.startswith(b"voluntary_ctxt_switches"):
                    vol = int(line.split()[1])
                elif line.startswith(b"nonvoluntary_ctxt_switches"):
                    unvol = int(line.split()[1])
                if vol is not None and unvol is not None:
                    return _common.pctxsw(vol, unvol)
            # NOTE(review): implicit string concatenation below is
            # missing a separating space between "...switches'" and
            # "fields" in the resulting message — cosmetic only.
            raise NotImplementedError(
                "'voluntary_ctxt_switches' and 'nonvoluntary_ctxt_switches'"
                "fields were not found in /proc/%s/status; the kernel is "
                "probably older than 2.6.23" % self.pid)

    @wrap_exceptions
    def num_threads(self):
        with open("/proc/%s/status" % self.pid, "rb") as f:
            for line in f:
                if line.startswith(b"Threads:"):
                    return int(line.split()[1])
            raise NotImplementedError("line not found")

    @wrap_exceptions
    def threads(self):
        thread_ids = os.listdir("/proc/%s/task" % self.pid)
        thread_ids.sort()
        retlist = []
        hit_enoent = False
        for thread_id in thread_ids:
            fname = "/proc/%s/task/%s/stat" % (self.pid, thread_id)
            try:
                with open(fname, 'rb') as f:
                    st = f.read().strip()
            except EnvironmentError as err:
                if err.errno == errno.ENOENT:
                    # no such file or directory; it means thread
                    # disappeared on us
                    hit_enoent = True
                    continue
                raise
            # ignore the first two values ("pid (exe)")
            st = st[st.find(b')') + 2:]
            values = st.split(b' ')
            utime = float(values[11]) / CLOCK_TICKS
            stime = float(values[12]) / CLOCK_TICKS
            ntuple = _common.pthread(int(thread_id), utime, stime)
            retlist.append(ntuple)
        if hit_enoent:
            # raise NSP if the process disappeared on us
            os.stat('/proc/%s' % self.pid)
        return retlist

    @wrap_exceptions
    def nice_get(self):
        # with open('/proc/%s/stat' % self.pid, 'r') as f:
        #     data = f.read()
        #     return int(data.split()[18])
        # Use C implementation
        return _psutil_posix.getpriority(self.pid)

    @wrap_exceptions
    def nice_set(self, value):
        return _psutil_posix.setpriority(self.pid, value)

    @wrap_exceptions
    def cpu_affinity_get(self):
        return cext.proc_cpu_affinity_get(self.pid)

    @wrap_exceptions
    def cpu_affinity_set(self, cpus):
        try:
            cext.proc_cpu_affinity_set(self.pid, cpus)
        except OSError as err:
            if err.errno == errno.EINVAL:
                # translate EINVAL into a friendlier ValueError
                # naming the offending CPU number
                allcpus = tuple(range(len(per_cpu_times())))
                for cpu in cpus:
                    if cpu not in allcpus:
                        raise ValueError("invalid CPU #%i (choose between %s)"
                                         % (cpu, allcpus))
            raise

    # only starting from kernel 2.6.13
    if hasattr(cext, "proc_ioprio_get"):
        @wrap_exceptions
        def ionice_get(self):
            ioclass, value = cext.proc_ioprio_get(self.pid)
            return _common.pionice(ioclass, value)

        @wrap_exceptions
        def ionice_set(self, ioclass, value):
            if ioclass in (IOPRIO_CLASS_NONE, None):
                if value:
                    msg = "can't specify value with IOPRIO_CLASS_NONE"
                    raise ValueError(msg)
                ioclass = IOPRIO_CLASS_NONE
                value = 0
            if ioclass in (IOPRIO_CLASS_RT, IOPRIO_CLASS_BE):
                if value is None:
                    value = 4
            elif ioclass == IOPRIO_CLASS_IDLE:
                if value:
                    msg = "can't specify value with IOPRIO_CLASS_IDLE"
                    raise ValueError(msg)
                value = 0
            else:
                value = 0
            if not 0 <= value <= 8:
                raise ValueError(
                    "value argument range expected is between 0 and 8")
            return cext.proc_ioprio_set(self.pid, ioclass, value)

    if HAS_PRLIMIT:
        @wrap_exceptions
        def rlimit(self, resource, limits=None):
            # if pid is 0 prlimit() applies to the calling process and
            # we don't want that
            if self.pid == 0:
                raise ValueError("can't use prlimit() against PID 0 process")
            if limits is None:
                # get
                return cext.linux_prlimit(self.pid, resource)
            else:
                # set
                if len(limits) != 2:
                    raise ValueError(
                        "second argument must be a (soft, hard) tuple")
                soft, hard = limits
                cext.linux_prlimit(self.pid, resource, soft, hard)

    @wrap_exceptions
    def status(self):
        with open("/proc/%s/status" % self.pid, 'rb') as f:
            for line in f:
                if line.startswith(b"State:"):
                    letter = line.split()[1]
                    if PY3:
                        letter = letter.decode()
                    # XXX is '?' legit? (we're not supposed to return
                    # it anyway)
                    return PROC_STATUSES.get(letter, '?')

    @wrap_exceptions
    def open_files(self):
        retlist = []
        files = os.listdir("/proc/%s/fd" % self.pid)
        hit_enoent = False
        for fd in files:
            file = "/proc/%s/fd/%s" % (self.pid, fd)
            try:
                file = os.readlink(file)
            except OSError as err:
                # ENOENT == file which is gone in the meantime
                if err.errno in (errno.ENOENT, errno.ESRCH):
                    hit_enoent = True
                    continue
                elif err.errno == errno.EINVAL:
                    # not a link
                    continue
                else:
                    raise
            else:
                # If file is not an absolute path there's no way
                # to tell whether it's a regular file or not,
                # so we skip it. A regular file is always supposed
                # to be absolutized though.
                if file.startswith('/') and isfile_strict(file):
                    ntuple = _common.popenfile(file, int(fd))
                    retlist.append(ntuple)
        if hit_enoent:
            # raise NSP if the process disappeared on us
            os.stat('/proc/%s' % self.pid)
        return retlist

    @wrap_exceptions
    def connections(self, kind='inet'):
        ret = _connections.retrieve(kind, self.pid)
        # raise NSP if the process disappeared on us
        os.stat('/proc/%s' % self.pid)
        return ret

    @wrap_exceptions
    def num_fds(self):
        return len(os.listdir("/proc/%s/fd" % self.pid))

    @wrap_exceptions
    def ppid(self):
        with open("/proc/%s/status" % self.pid, 'rb') as f:
            for line in f:
                if line.startswith(b"PPid:"):
                    # PPid: nnnn
                    return int(line.split()[1])
            raise NotImplementedError("line not found")

    @wrap_exceptions
    def uids(self):
        with open("/proc/%s/status" % self.pid, 'rb') as f:
            for line in f:
                if line.startswith(b'Uid:'):
                    _, real, effective, saved, fs = line.split()
                    return _common.puids(int(real), int(effective),
                                         int(saved))
            raise NotImplementedError("line not found")

    @wrap_exceptions
    def gids(self):
        with open("/proc/%s/status" % self.pid, 'rb') as f:
            for line in f:
                if line.startswith(b'Gid:'):
                    _, real, effective, saved, fs = line.split()
                    return _common.pgids(int(real), int(effective),
                                         int(saved))
            raise NotImplementedError("line not found")


================================================ FILE:
Common/libpsutil/py2.6-glibc-2.12-pre/psutil/_psosx.py
================================================
#!/usr/bin/env python

# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""OSX platform implementation."""

import errno
import functools
import os
from collections import namedtuple

from psutil import _common
from psutil import _psposix
from psutil._common import conn_tmap, usage_percent, isfile_strict
import _psutil_osx as cext
import _psutil_posix

__extra__all__ = []

# --- constants

PAGESIZE = os.sysconf("SC_PAGE_SIZE")

# http://students.mimuw.edu.pl/lxr/source/include/net/tcp_states.h
TCP_STATUSES = {
    cext.TCPS_ESTABLISHED: _common.CONN_ESTABLISHED,
    cext.TCPS_SYN_SENT: _common.CONN_SYN_SENT,
    cext.TCPS_SYN_RECEIVED: _common.CONN_SYN_RECV,
    cext.TCPS_FIN_WAIT_1: _common.CONN_FIN_WAIT1,
    cext.TCPS_FIN_WAIT_2: _common.CONN_FIN_WAIT2,
    cext.TCPS_TIME_WAIT: _common.CONN_TIME_WAIT,
    cext.TCPS_CLOSED: _common.CONN_CLOSE,
    cext.TCPS_CLOSE_WAIT: _common.CONN_CLOSE_WAIT,
    cext.TCPS_LAST_ACK: _common.CONN_LAST_ACK,
    cext.TCPS_LISTEN: _common.CONN_LISTEN,
    cext.TCPS_CLOSING: _common.CONN_CLOSING,
    cext.PSUTIL_CONN_NONE: _common.CONN_NONE,
}

PROC_STATUSES = {
    cext.SIDL: _common.STATUS_IDLE,
    cext.SRUN: _common.STATUS_RUNNING,
    cext.SSLEEP: _common.STATUS_SLEEPING,
    cext.SSTOP: _common.STATUS_STOPPED,
    cext.SZOMB: _common.STATUS_ZOMBIE,
}

scputimes = namedtuple('scputimes', ['user', 'nice', 'system', 'idle'])

svmem = namedtuple(
    'svmem', ['total', 'available', 'percent', 'used', 'free',
              'active', 'inactive', 'wired'])

pextmem = namedtuple('pextmem', ['rss', 'vms', 'pfaults', 'pageins'])

pmmap_grouped = namedtuple(
    'pmmap_grouped',
    'path rss private swapped dirtied ref_count shadow_depth')

pmmap_ext = namedtuple(
    'pmmap_ext', 'addr perms ' + ' '.join(pmmap_grouped._fields))

# set later from __init__.py
NoSuchProcess = None
AccessDenied = None
TimeoutExpired = None


# --- functions


def virtual_memory():
    """System virtual memory as a namedtuple."""
    total, active, inactive, wired, free = cext.virtual_mem()
    avail = inactive + free
    used = active + inactive + wired
    percent = usage_percent((total - avail), total, _round=1)
    return svmem(total, avail, percent, used, free,
                 active, inactive, wired)


def swap_memory():
    """Swap system memory as a (total, used, free, sin, sout) tuple."""
    total, used, free, sin, sout = cext.swap_mem()
    percent = usage_percent(used, total, _round=1)
    return _common.sswap(total, used, free, percent, sin, sout)


def cpu_times():
    """Return system CPU times as a namedtuple."""
    user, nice, system, idle = cext.cpu_times()
    return scputimes(user, nice, system, idle)


def per_cpu_times():
    """Return system CPU times as a named tuple"""
    ret = []
    for cpu_t in cext.per_cpu_times():
        user, nice, system, idle = cpu_t
        item = scputimes(user, nice, system, idle)
        ret.append(item)
    return ret


def cpu_count_logical():
    """Return the number of logical CPUs in the system."""
    return cext.cpu_count_logical()


def cpu_count_physical():
    """Return the number of physical CPUs in the system."""
    return cext.cpu_count_phys()


def boot_time():
    """The system boot time expressed in seconds since the epoch."""
    return cext.boot_time()


def disk_partitions(all=False):
    """Return mounted disk partitions; with all=False keep only
    entries whose device is an absolute, existing path."""
    retlist = []
    partitions = cext.disk_partitions()
    for partition in partitions:
        device, mountpoint, fstype, opts = partition
        if device == 'none':
            device = ''
        if not all:
            if not os.path.isabs(device) or not os.path.exists(device):
                continue
        ntuple = _common.sdiskpart(device, mountpoint, fstype, opts)
        retlist.append(ntuple)
    return retlist


def users():
    """Return currently connected users, skipping reboot/shutdown
    pseudo-entries and entries without a timestamp."""
    retlist = []
    rawlist = cext.users()
    for item in rawlist:
        user, tty, hostname, tstamp = item
        if tty == '~':
            continue  # reboot or shutdown
        if not tstamp:
            continue
        nt = _common.suser(user, tty or None, hostname or None, tstamp)
        retlist.append(nt)
    return retlist


def net_connections(kind='inet'):
    # Note: on OSX this will fail with AccessDenied unless
    # the process is owned by root.
    ret = []
    for pid in pids():
        try:
            cons = Process(pid).connections(kind)
        except NoSuchProcess:
            continue
        else:
            if cons:
                for c in cons:
                    c = list(c) + [pid]
                    ret.append(_common.sconn(*c))
    return ret


pids = cext.pids
pid_exists = _psposix.pid_exists
disk_usage = _psposix.disk_usage
net_io_counters = cext.net_io_counters
disk_io_counters = cext.disk_io_counters


def wrap_exceptions(fun):
    """Decorator which translates bare OSError exceptions into
    NoSuchProcess and AccessDenied.
    """
    @functools.wraps(fun)
    def wrapper(self, *args, **kwargs):
        try:
            return fun(self, *args, **kwargs)
        except OSError as err:
            # support for private module import
            if NoSuchProcess is None or AccessDenied is None:
                raise
            if err.errno == errno.ESRCH:
                raise NoSuchProcess(self.pid, self._name)
            if err.errno in (errno.EPERM, errno.EACCES):
                raise AccessDenied(self.pid, self._name)
            raise
    return wrapper


class Process(object):
    """Wrapper class around underlying C implementation."""

    __slots__ = ["pid", "_name"]

    def __init__(self, pid):
        self.pid = pid
        self._name = None

    @wrap_exceptions
    def name(self):
        return cext.proc_name(self.pid)

    @wrap_exceptions
    def exe(self):
        return cext.proc_exe(self.pid)

    @wrap_exceptions
    def cmdline(self):
        if not pid_exists(self.pid):
            raise NoSuchProcess(self.pid, self._name)
        return cext.proc_cmdline(self.pid)

    @wrap_exceptions
    def ppid(self):
        return cext.proc_ppid(self.pid)

    @wrap_exceptions
    def cwd(self):
        return cext.proc_cwd(self.pid)

    @wrap_exceptions
    def uids(self):
        real, effective, saved = cext.proc_uids(self.pid)
        return _common.puids(real, effective, saved)

    @wrap_exceptions
    def gids(self):
        real, effective, saved = cext.proc_gids(self.pid)
        return _common.pgids(real, effective, saved)

    @wrap_exceptions
    def terminal(self):
        tty_nr = cext.proc_tty_nr(self.pid)
        tmap = _psposix._get_terminal_map()
        try:
            return tmap[tty_nr]
        except KeyError:
            return None

    @wrap_exceptions
    def memory_info(self):
        rss, vms = cext.proc_memory_info(self.pid)[:2]
        return _common.pmem(rss, vms)

    @wrap_exceptions
    def memory_info_ex(self):
        rss, vms, pfaults, pageins = cext.proc_memory_info(self.pid)
        return pextmem(rss, vms, pfaults * PAGESIZE, pageins * PAGESIZE)

    @wrap_exceptions
    def cpu_times(self):
        user, system = cext.proc_cpu_times(self.pid)
        return _common.pcputimes(user, system)

    @wrap_exceptions
    def create_time(self):
        return cext.proc_create_time(self.pid)

    @wrap_exceptions
    def num_ctx_switches(self):
        return _common.pctxsw(*cext.proc_num_ctx_switches(self.pid))

    @wrap_exceptions
    def num_threads(self):
        return cext.proc_num_threads(self.pid)

    @wrap_exceptions
    def open_files(self):
        if self.pid == 0:
            return []
        files = []
        rawlist = cext.proc_open_files(self.pid)
        for path, fd in rawlist:
            if isfile_strict(path):
                ntuple = _common.popenfile(path, fd)
                files.append(ntuple)
        return files

    @wrap_exceptions
    def connections(self, kind='inet'):
        if kind not in conn_tmap:
            raise ValueError("invalid %r kind argument; choose between %s"
                             % (kind, ', '.join([repr(x) for x in conn_tmap])))
        families, types = conn_tmap[kind]
        rawlist = cext.proc_connections(self.pid, families, types)
        ret = []
        for item in rawlist:
            fd, fam, type, laddr, raddr, status = item
            status = TCP_STATUSES[status]
            nt = _common.pconn(fd, fam, type, laddr, raddr, status)
            ret.append(nt)
        return ret

    @wrap_exceptions
    def num_fds(self):
        if self.pid == 0:
            return 0
        return cext.proc_num_fds(self.pid)

    @wrap_exceptions
    def wait(self, timeout=None):
        try:
            return _psposix.wait_pid(self.pid, timeout)
        except _psposix.TimeoutExpired:
            # support for private module import
            if TimeoutExpired is None:
                raise
            raise TimeoutExpired(timeout, self.pid, self._name)

    @wrap_exceptions
    def nice_get(self):
        return _psutil_posix.getpriority(self.pid)

    @wrap_exceptions
    def nice_set(self, value):
        return _psutil_posix.setpriority(self.pid, value)

    @wrap_exceptions
    def status(self):
        code = cext.proc_status(self.pid)
        # XXX is '?' legit? (we're not supposed to return it anyway)
        return PROC_STATUSES.get(code, '?')

    @wrap_exceptions
    def threads(self):
        rawlist = cext.proc_threads(self.pid)
        retlist = []
        for thread_id, utime, stime in rawlist:
            ntuple = _common.pthread(thread_id, utime, stime)
            retlist.append(ntuple)
        return retlist

    @wrap_exceptions
    def memory_maps(self):
        return cext.proc_memory_maps(self.pid)


================================================
FILE: Common/libpsutil/py2.6-glibc-2.12-pre/psutil/_psposix.py
================================================
#!/usr/bin/env python

# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Routines common to all posix systems."""

import errno
import glob
import os
import sys
import time

from psutil._common import sdiskusage, usage_percent, memoize
from psutil._compat import PY3, unicode


class TimeoutExpired(Exception):
    pass


def pid_exists(pid):
    """Check whether pid exists in the current process table."""
    if pid == 0:
        # According to "man 2 kill" PID 0 has a special meaning:
        # it refers to <<every process in the process group of the
        # calling process>> so we don't want to go any further.
        # If we get here it means this UNIX platform *does* have
        # a process with id 0.
        return True
    try:
        os.kill(pid, 0)
    except OSError as err:
        if err.errno == errno.ESRCH:
            # ESRCH == No such process
            return False
        elif err.errno == errno.EPERM:
            # EPERM clearly means there's a process to deny access to
            return True
        else:
            # According to "man 2 kill" possible error values are
            # (EINVAL, EPERM, ESRCH) therefore we should never get
            # here. If we do let's be explicit in considering this
            # an error.
            raise err
    else:
        return True


def wait_pid(pid, timeout=None):
    """Wait for process with pid 'pid' to terminate and return its
    exit status code as an integer.

    If pid is not a children of os.getpid() (current process) just
    waits until the process disappears and return None.

    If pid does not exist at all return None immediately.
    Raise TimeoutExpired on timeout expired.
    """
    def check_timeout(delay):
        # helper: raise on deadline, otherwise sleep with
        # exponential backoff capped at 0.04s
        if timeout is not None:
            if timer() >= stop_at:
                raise TimeoutExpired()
        time.sleep(delay)
        return min(delay * 2, 0.04)

    timer = getattr(time, 'monotonic', time.time)
    if timeout is not None:
        waitcall = lambda: os.waitpid(pid, os.WNOHANG)
        stop_at = timer() + timeout
    else:
        waitcall = lambda: os.waitpid(pid, 0)

    delay = 0.0001
    while True:
        try:
            retpid, status = waitcall()
        except OSError as err:
            if err.errno == errno.EINTR:
                delay = check_timeout(delay)
                continue
            elif err.errno == errno.ECHILD:
                # This has two meanings:
                # - pid is not a child of os.getpid() in which case
                #   we keep polling until it's gone
                # - pid never existed in the first place
                # In both cases we'll eventually return None as we
                # can't determine its exit status code.
                while True:
                    if pid_exists(pid):
                        delay = check_timeout(delay)
                    else:
                        return
            else:
                raise
        else:
            if retpid == 0:
                # WNOHANG was used, pid is still running
                delay = check_timeout(delay)
                continue
            # process exited due to a signal; return the integer of
            # that signal
            if os.WIFSIGNALED(status):
                return os.WTERMSIG(status)
            # process exited using exit(2) system call; return the
            # integer exit(2) system call has been called with
            elif os.WIFEXITED(status):
                return os.WEXITSTATUS(status)
            else:
                # should never happen
                raise RuntimeError("unknown process exit status")


def disk_usage(path):
    """Return disk usage associated with path."""
    try:
        st = os.statvfs(path)
    except UnicodeEncodeError:
        if not PY3 and isinstance(path, unicode):
            # this is a bug with os.statvfs() and unicode on
            # Python 2, see:
            # - https://github.com/giampaolo/psutil/issues/416
            # - http://bugs.python.org/issue18695
            try:
                path = path.encode(sys.getfilesystemencoding())
            except UnicodeEncodeError:
                pass
            st = os.statvfs(path)
        else:
            raise
    free = (st.f_bavail * st.f_frsize)
    total = (st.f_blocks * st.f_frsize)
    used = (st.f_blocks - st.f_bfree) * st.f_frsize
    percent = usage_percent(used, total, _round=1)
    # NB: the percentage is -5% than what shown by df due to
    # reserved blocks that we are currently not considering:
    # http://goo.gl/sWGbH
    return sdiskusage(total, used, free, percent)


@memoize
def _get_terminal_map():
    # map st_rdev device numbers -> terminal device paths,
    # computed once (memoized)
    ret = {}
    ls = glob.glob('/dev/tty*') + glob.glob('/dev/pts/*')
    for name in ls:
        assert name not in ret
        try:
            ret[os.stat(name).st_rdev] = name
        except OSError as err:
            if err.errno != errno.ENOENT:
                raise
    return ret


================================================
FILE: Common/libpsutil/py2.6-glibc-2.12-pre/psutil/_pssunos.py
================================================
#!/usr/bin/env python

# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Sun OS Solaris platform implementation."""

import errno
import os
import socket
import subprocess
import sys
from collections import namedtuple

from psutil import _common
from psutil import _psposix
from psutil._common import usage_percent, isfile_strict
from psutil._compat import PY3
import _psutil_posix
import _psutil_sunos as cext

__extra__all__ = ["CONN_IDLE", "CONN_BOUND"]

PAGE_SIZE = os.sysconf('SC_PAGE_SIZE')

CONN_IDLE = "IDLE"
CONN_BOUND = "BOUND"

PROC_STATUSES = {
    cext.SSLEEP: _common.STATUS_SLEEPING,
    cext.SRUN: _common.STATUS_RUNNING,
    cext.SZOMB: _common.STATUS_ZOMBIE,
    cext.SSTOP: _common.STATUS_STOPPED,
    cext.SIDL: _common.STATUS_IDLE,
    cext.SONPROC: _common.STATUS_RUNNING,  # same as run
    cext.SWAIT: _common.STATUS_WAITING,
}

TCP_STATUSES = {
    cext.TCPS_ESTABLISHED: _common.CONN_ESTABLISHED,
    cext.TCPS_SYN_SENT: _common.CONN_SYN_SENT,
    cext.TCPS_SYN_RCVD: _common.CONN_SYN_RECV,
    cext.TCPS_FIN_WAIT_1: _common.CONN_FIN_WAIT1,
    cext.TCPS_FIN_WAIT_2: _common.CONN_FIN_WAIT2,
    cext.TCPS_TIME_WAIT: _common.CONN_TIME_WAIT,
    cext.TCPS_CLOSED: _common.CONN_CLOSE,
    cext.TCPS_CLOSE_WAIT: _common.CONN_CLOSE_WAIT,
    cext.TCPS_LAST_ACK: _common.CONN_LAST_ACK,
    cext.TCPS_LISTEN: _common.CONN_LISTEN,
    cext.TCPS_CLOSING: _common.CONN_CLOSING,
    cext.PSUTIL_CONN_NONE: _common.CONN_NONE,
    cext.TCPS_IDLE: CONN_IDLE,  # sunos specific
    cext.TCPS_BOUND: CONN_BOUND,  # sunos specific
}

scputimes = namedtuple('scputimes', ['user', 'system', 'idle', 'iowait'])
svmem = namedtuple('svmem', ['total', 'available', 'percent', 'used', 'free'])
pextmem = namedtuple('pextmem', ['rss', 'vms'])
pmmap_grouped = namedtuple('pmmap_grouped', ['path', 'rss', 'anon', 'locked'])
pmmap_ext = namedtuple(
    'pmmap_ext', 'addr perms ' + ' '.join(pmmap_grouped._fields))

# set later from __init__.py
NoSuchProcess = None
AccessDenied = None
TimeoutExpired = None

# --- functions

disk_io_counters = cext.disk_io_counters
net_io_counters = cext.net_io_counters
disk_usage = _psposix.disk_usage


def virtual_memory():
    # we could have done this with kstat, but imho this is good enough
    total = os.sysconf('SC_PHYS_PAGES') * PAGE_SIZE
    # note: there's no difference on Solaris
    free = avail = os.sysconf('SC_AVPHYS_PAGES') * PAGE_SIZE
    used = total - free
    percent = usage_percent(used, total, _round=1)
    return svmem(total, avail, percent, used, free)


def swap_memory():
    sin, sout = cext.swap_mem()
    # XXX
    # we are supposed to get total/free by doing so:
    # http://cvs.opensolaris.org/source/xref/onnv/onnv-gate/
    #     usr/src/cmd/swap/swap.c
    # ...nevertheless I can't manage to obtain the same numbers as 'swap'
    # cmdline utility, so let's parse its output (sigh!)
    p = subprocess.Popen(['swap', '-l', '-k'], stdout=subprocess.PIPE)
    stdout, stderr = p.communicate()
    if PY3:
        stdout = stdout.decode(sys.stdout.encoding)
    if p.returncode != 0:
        raise RuntimeError("'swap -l -k' failed (retcode=%s)" % p.returncode)

    lines = stdout.strip().split('\n')[1:]
    if not lines:
        raise RuntimeError('no swap device(s) configured')
    total = free = 0
    for line in lines:
        line = line.split()
        t, f = line[-2:]
        t = t.replace('K', '')
        f = f.replace('K', '')
        total += int(int(t) * 1024)
        free += int(int(f) * 1024)
    used = total - free
    percent = usage_percent(used, total, _round=1)
    return _common.sswap(total, used, free, percent,
                         sin * PAGE_SIZE, sout * PAGE_SIZE)


def pids():
    """Returns a list of PIDs currently running on the system."""
    return [int(x) for x in os.listdir('/proc') if x.isdigit()]


def pid_exists(pid):
    """Check for the existence of a unix pid."""
    return _psposix.pid_exists(pid)


def cpu_times():
    """Return system-wide CPU times as a named tuple"""
    ret = cext.per_cpu_times()
    return scputimes(*[sum(x) for x in zip(*ret)])


def per_cpu_times():
    """Return system per-CPU times as a list of named tuples"""
    ret = cext.per_cpu_times()
    return [scputimes(*x) for x in ret]


def cpu_count_logical():
    """Return the number of logical CPUs in the system."""
    try:
        return os.sysconf("SC_NPROCESSORS_ONLN")
    except ValueError:
        # mimic os.cpu_count() behavior
        return None


def cpu_count_physical():
    """Return the number of physical CPUs in the system."""
    return cext.cpu_count_phys()


def boot_time():
    """The system boot time expressed in seconds since the epoch."""
    return cext.boot_time()


def users():
    """Return currently connected users as a list of namedtuples."""
    retlist = []
    rawlist = cext.users()
    localhost = (':0.0', ':0')
    for item in rawlist:
        user, tty, hostname, tstamp, user_process = item
        # note: the underlying C function includes entries about
        # system boot, run level and others.  We might want
        # to use them in the future.
        if not user_process:
            continue
        if hostname in localhost:
            hostname = 'localhost'
        nt = _common.suser(user, tty, hostname, tstamp)
        retlist.append(nt)
    return retlist


def disk_partitions(all=False):
    """Return system disk partitions."""
    # TODO - the filtering logic should be better checked so that
    # it tries to reflect 'df' as much as possible
    retlist = []
    partitions = cext.disk_partitions()
    for partition in partitions:
        device, mountpoint, fstype, opts = partition
        if device == 'none':
            device = ''
        if not all:
            # Differently from, say, Linux, we don't have a list of
            # common fs types so the best we can do, AFAIK, is to
            # filter by filesystem having a total size > 0.
            if not disk_usage(mountpoint).total:
                continue
        ntuple = _common.sdiskpart(device, mountpoint, fstype, opts)
        retlist.append(ntuple)
    return retlist


def net_connections(kind, _pid=-1):
    """Return socket connections.  If pid == -1 return system-wide
    connections (as opposed to connections opened by one process only).
    Only INET sockets are returned (UNIX are not).
    """
    cmap = _common.conn_tmap.copy()
    if _pid == -1:
        cmap.pop('unix', 0)
    if kind not in cmap:
        raise ValueError("invalid %r kind argument; choose between %s"
                         % (kind, ', '.join([repr(x) for x in cmap])))
    families, types = _common.conn_tmap[kind]
    rawlist = cext.net_connections(_pid, families, types)
    ret = []
    for item in rawlist:
        fd, fam, type_, laddr, raddr, status, pid = item
        if fam not in families:
            continue
        if type_ not in types:
            continue
        status = TCP_STATUSES[status]
        if _pid == -1:
            nt = _common.sconn(fd, fam, type_, laddr, raddr, status, pid)
        else:
            nt = _common.pconn(fd, fam, type_, laddr, raddr, status)
        ret.append(nt)
    return ret


def wrap_exceptions(fun):
    """Call callable into a try/except clause and translate ENOENT,
    EACCES and EPERM in NoSuchProcess or AccessDenied exceptions.
    """
    # NOTE(review): unlike the Linux/OSX variants this decorator does
    # not use functools.wraps, so the wrapped method loses its name.
    def wrapper(self, *args, **kwargs):
        try:
            return fun(self, *args, **kwargs)
        except EnvironmentError as err:
            # support for private module import
            if NoSuchProcess is None or AccessDenied is None:
                raise
            # ENOENT (no such file or directory) gets raised on open().
            # ESRCH (no such process) can get raised on read() if
            # process is gone in meantime.
            if err.errno in (errno.ENOENT, errno.ESRCH):
                raise NoSuchProcess(self.pid, self._name)
            if err.errno in (errno.EPERM, errno.EACCES):
                raise AccessDenied(self.pid, self._name)
            raise
    return wrapper


class Process(object):
    """Wrapper class around underlying C implementation."""

    __slots__ = ["pid", "_name"]

    def __init__(self, pid):
        self.pid = pid
        self._name = None

    @wrap_exceptions
    def name(self):
        # note: max len == 15
        return cext.proc_name_and_args(self.pid)[0]

    @wrap_exceptions
    def exe(self):
        # Will be guess later from cmdline but we want to explicitly
        # invoke cmdline here in order to get an AccessDenied
        # exception if the user has not enough privileges.
        self.cmdline()
        return ""

    @wrap_exceptions
    def cmdline(self):
        return cext.proc_name_and_args(self.pid)[1].split(' ')

    @wrap_exceptions
    def create_time(self):
        return cext.proc_basic_info(self.pid)[3]

    @wrap_exceptions
    def num_threads(self):
        return cext.proc_basic_info(self.pid)[5]

    @wrap_exceptions
    def nice_get(self):
        # For some reason getpriority(3) return ESRCH (no such process)
        # for certain low-pid processes, no matter what (even as root).
        # The process actually exists though, as it has a name,
        # creation time, etc.
        # The best thing we can do here appears to be raising AD.
        # Note: tested on Solaris 11; on Open Solaris 5 everything is
        # fine.
        try:
            return _psutil_posix.getpriority(self.pid)
        except EnvironmentError as err:
            if err.errno in (errno.ENOENT, errno.ESRCH):
                if pid_exists(self.pid):
                    raise AccessDenied(self.pid, self._name)
            raise

    @wrap_exceptions
    def nice_set(self, value):
        if self.pid in (2, 3):
            # Special case PIDs: internally setpriority(3) return ESRCH
            # (no such process), no matter what.
            # The process actually exists though, as it has a name,
            # creation time, etc.
            raise AccessDenied(self.pid, self._name)
        return _psutil_posix.setpriority(self.pid, value)

    @wrap_exceptions
    def ppid(self):
        return cext.proc_basic_info(self.pid)[0]

    @wrap_exceptions
    def uids(self):
        real, effective, saved, _, _, _ = cext.proc_cred(self.pid)
        return _common.puids(real, effective, saved)

    @wrap_exceptions
    def gids(self):
        # NOTE(review): returns a puids namedtuple rather than pgids;
        # field values are correct but the tuple type name is not.
        _, _, _, real, effective, saved = cext.proc_cred(self.pid)
        return _common.puids(real, effective, saved)

    @wrap_exceptions
    def cpu_times(self):
        user, system = cext.proc_cpu_times(self.pid)
        return _common.pcputimes(user, system)

    @wrap_exceptions
    def terminal(self):
        hit_enoent = False
        # NOTE(review): 'tty' is assigned the result of calling the
        # wrap_exceptions decorator on an int, i.e. a function object,
        # so the != PRNODEV comparison is always true; the readlink
        # loop below is what actually resolves the terminal.
        tty = wrap_exceptions(
            cext.proc_basic_info(self.pid)[0])
        if tty != cext.PRNODEV:
            for x in (0, 1, 2, 255):
                try:
                    return os.readlink('/proc/%d/path/%d' % (self.pid, x))
                except OSError as err:
                    if err.errno == errno.ENOENT:
                        hit_enoent = True
                        continue
                    raise
        if hit_enoent:
            # raise NSP if the process disappeared on us
            os.stat('/proc/%s' % self.pid)

    @wrap_exceptions
    def cwd(self):
        # /proc/PID/path/cwd may not be resolved by readlink() even if
        # it exists (ls shows it). If that's the case and the process
        # is still alive return None (we can return None also on BSD).
        # Reference: http://goo.gl/55XgO
        try:
            return os.readlink("/proc/%s/path/cwd" % self.pid)
        except OSError as err:
            if err.errno == errno.ENOENT:
                os.stat("/proc/%s" % self.pid)
                return None
            raise

    @wrap_exceptions
    def memory_info(self):
        ret = cext.proc_basic_info(self.pid)
        rss, vms = ret[1] * 1024, ret[2] * 1024
        return _common.pmem(rss, vms)

    # it seems Solaris uses rss and vms only
    memory_info_ex = memory_info

    @wrap_exceptions
    def status(self):
        code = cext.proc_basic_info(self.pid)[6]
        # XXX is '?' legit? (we're not supposed to return it anyway)
        return PROC_STATUSES.get(code, '?')

    @wrap_exceptions
    def threads(self):
        ret = []
        tids = os.listdir('/proc/%d/lwp' % self.pid)
        hit_enoent = False
        for tid in tids:
            tid = int(tid)
            try:
                utime, stime = cext.query_process_thread(
                    self.pid, tid)
            except EnvironmentError as err:
                # ENOENT == thread gone in meantime
                if err.errno == errno.ENOENT:
                    hit_enoent = True
                    continue
                raise
            else:
                nt = _common.pthread(tid, utime, stime)
                ret.append(nt)
        if hit_enoent:
            # raise NSP if the process disappeared on us
            os.stat('/proc/%s' % self.pid)
        return ret

    @wrap_exceptions
    def open_files(self):
        retlist = []
        hit_enoent = False
        pathdir = '/proc/%d/path' % self.pid
        for fd in os.listdir('/proc/%d/fd' % self.pid):
            path = os.path.join(pathdir, fd)
            if os.path.islink(path):
                try:
                    file = os.readlink(path)
                except OSError as err:
                    # ENOENT == file which is gone in the meantime
                    if err.errno == errno.ENOENT:
                        hit_enoent = True
                        continue
                    raise
                else:
                    if isfile_strict(file):
                        retlist.append(_common.popenfile(file, int(fd)))
        if hit_enoent:
            # raise NSP if the process disappeared on us
            os.stat('/proc/%s' % self.pid)
        return retlist

    def _get_unix_sockets(self, pid):
        """Get UNIX sockets used by process by parsing 'pfiles' output."""
        # TODO: rewrite this in C (...but the damn netstat source code
        # does not include this part! Argh!!)
        cmd = "pfiles %s" % pid
        p = subprocess.Popen(cmd, shell=True,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, stderr = p.communicate()
        if PY3:
            stdout, stderr = [x.decode(sys.stdout.encoding)
                              for x in (stdout, stderr)]
        if p.returncode != 0:
            if 'permission denied' in stderr.lower():
                raise AccessDenied(self.pid, self._name)
            if 'no such process' in stderr.lower():
                raise NoSuchProcess(self.pid, self._name)
            raise RuntimeError("%r command error\n%s" % (cmd, stderr))

        lines = stdout.split('\n')[2:]
        for i, line in enumerate(lines):
            line = line.lstrip()
            if line.startswith('sockname: AF_UNIX'):
                path = line.split(' ', 2)[2]
                type = lines[i - 2].strip()
                if type == 'SOCK_STREAM':
                    type = socket.SOCK_STREAM
                elif type == 'SOCK_DGRAM':
                    type = socket.SOCK_DGRAM
                else:
                    type = -1
                yield (-1, socket.AF_UNIX, type, path, "", _common.CONN_NONE)

    @wrap_exceptions
    def connections(self, kind='inet'):
        ret = net_connections(kind, _pid=self.pid)
        # The underlying C implementation retrieves all OS connections
        # and filters them by PID. At this point we can't tell whether
        # an empty list means there were no connections for process or
        # process is no longer active so we force NSP in case the PID
        # is no longer there.
if not ret: os.stat('/proc/%s' % self.pid) # will raise NSP if process is gone # UNIX sockets if kind in ('all', 'unix'): ret.extend([_common.pconn(*conn) for conn in self._get_unix_sockets(self.pid)]) return ret nt_mmap_grouped = namedtuple('mmap', 'path rss anon locked') nt_mmap_ext = namedtuple('mmap', 'addr perms path rss anon locked') @wrap_exceptions def memory_maps(self): def toaddr(start, end): return '%s-%s' % (hex(start)[2:].strip('L'), hex(end)[2:].strip('L')) retlist = [] rawlist = cext.proc_memory_maps(self.pid) hit_enoent = False for item in rawlist: addr, addrsize, perm, name, rss, anon, locked = item addr = toaddr(addr, addrsize) if not name.startswith('['): try: name = os.readlink('/proc/%s/path/%s' % (self.pid, name)) except OSError as err: if err.errno == errno.ENOENT: # sometimes the link may not be resolved by # readlink() even if it exists (ls shows it). # If that's the case we just return the # unresolved link path. # This seems an incosistency with /proc similar # to: http://goo.gl/55XgO name = '/proc/%s/path/%s' % (self.pid, name) hit_enoent = True else: raise retlist.append((addr, perm, name, rss, anon, locked)) if hit_enoent: # raise NSP if the process disappeared on us os.stat('/proc/%s' % self.pid) return retlist @wrap_exceptions def num_fds(self): return len(os.listdir("/proc/%s/fd" % self.pid)) @wrap_exceptions def num_ctx_switches(self): return _common.pctxsw(*cext.proc_num_ctx_switches(self.pid)) @wrap_exceptions def wait(self, timeout=None): try: return _psposix.wait_pid(self.pid, timeout) except _psposix.TimeoutExpired: # support for private module import if TimeoutExpired is None: raise raise TimeoutExpired(timeout, self.pid, self._name) ================================================ FILE: Common/libpsutil/py2.6-glibc-2.12-pre/psutil/_pswindows.py ================================================ #!/usr/bin/env python # Copyright (c) 2009, Giampaolo Rodola'. All rights reserved. 
# Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. """Windows platform implementation.""" import errno import functools import os from collections import namedtuple from psutil import _common from psutil._common import conn_tmap, usage_percent, isfile_strict from psutil._compat import PY3, xrange, lru_cache import _psutil_windows as cext # process priority constants, import from __init__.py: # http://msdn.microsoft.com/en-us/library/ms686219(v=vs.85).aspx __extra__all__ = ["ABOVE_NORMAL_PRIORITY_CLASS", "BELOW_NORMAL_PRIORITY_CLASS", "HIGH_PRIORITY_CLASS", "IDLE_PRIORITY_CLASS", "NORMAL_PRIORITY_CLASS", "REALTIME_PRIORITY_CLASS", # "CONN_DELETE_TCB", ] # --- module level constants (gets pushed up to psutil module) CONN_DELETE_TCB = "DELETE_TCB" WAIT_TIMEOUT = 0x00000102 # 258 in decimal ACCESS_DENIED_SET = frozenset([errno.EPERM, errno.EACCES, cext.ERROR_ACCESS_DENIED]) TCP_STATUSES = { cext.MIB_TCP_STATE_ESTAB: _common.CONN_ESTABLISHED, cext.MIB_TCP_STATE_SYN_SENT: _common.CONN_SYN_SENT, cext.MIB_TCP_STATE_SYN_RCVD: _common.CONN_SYN_RECV, cext.MIB_TCP_STATE_FIN_WAIT1: _common.CONN_FIN_WAIT1, cext.MIB_TCP_STATE_FIN_WAIT2: _common.CONN_FIN_WAIT2, cext.MIB_TCP_STATE_TIME_WAIT: _common.CONN_TIME_WAIT, cext.MIB_TCP_STATE_CLOSED: _common.CONN_CLOSE, cext.MIB_TCP_STATE_CLOSE_WAIT: _common.CONN_CLOSE_WAIT, cext.MIB_TCP_STATE_LAST_ACK: _common.CONN_LAST_ACK, cext.MIB_TCP_STATE_LISTEN: _common.CONN_LISTEN, cext.MIB_TCP_STATE_CLOSING: _common.CONN_CLOSING, cext.MIB_TCP_STATE_DELETE_TCB: CONN_DELETE_TCB, cext.PSUTIL_CONN_NONE: _common.CONN_NONE, } scputimes = namedtuple('scputimes', ['user', 'system', 'idle']) svmem = namedtuple('svmem', ['total', 'available', 'percent', 'used', 'free']) pextmem = namedtuple( 'pextmem', ['num_page_faults', 'peak_wset', 'wset', 'peak_paged_pool', 'paged_pool', 'peak_nonpaged_pool', 'nonpaged_pool', 'pagefile', 'peak_pagefile', 'private']) pmmap_grouped = namedtuple('pmmap_grouped', ['path', 
                                             'rss'])
pmmap_ext = namedtuple(
    'pmmap_ext', 'addr perms ' + ' '.join(pmmap_grouped._fields))

# set later from __init__.py
NoSuchProcess = None
AccessDenied = None
TimeoutExpired = None


@lru_cache(maxsize=512)
def _win32_QueryDosDevice(s):
    """Cached wrapper around the native QueryDosDevice() call."""
    return cext.win32_QueryDosDevice(s)


def _convert_raw_path(s):
    """Translate a native NT device path into a drive-letter path."""
    # convert paths using native DOS format like:
    # "\Device\HarddiskVolume1\Windows\systemew\file.txt"
    # into: "C:\Windows\systemew\file.txt"
    if PY3 and not isinstance(s, str):
        s = s.decode('utf8')
    rawdrive = '\\'.join(s.split('\\')[:3])
    driveletter = _win32_QueryDosDevice(rawdrive)
    return os.path.join(driveletter, s[len(rawdrive):])


# --- public functions


def virtual_memory():
    """System virtual memory as a namedtuple."""
    mem = cext.virtual_mem()
    totphys, availphys, totpagef, availpagef, totvirt, freevirt = mem
    #
    total = totphys
    avail = availphys
    # on Windows "free" is reported as identical to "available"
    free = availphys
    used = total - avail
    percent = usage_percent((total - avail), total, _round=1)
    return svmem(total, avail, percent, used, free)


def swap_memory():
    """Swap system memory as a (total, used, free, sin, sout) tuple."""
    mem = cext.virtual_mem()
    # indexes 2 and 3 are total/available page file sizes (see the
    # unpack order in virtual_memory() above)
    total = mem[2]
    free = mem[3]
    used = total - free
    percent = usage_percent(used, total, _round=1)
    # sin/sout (swap in/out) are not available on Windows; report 0
    return _common.sswap(total, used, free, percent, 0, 0)


def disk_usage(path):
    """Return disk usage associated with path."""
    try:
        total, free = cext.disk_usage(path)
    except WindowsError:
        # give a clearer error for a nonexistent path; re-raise anything
        # else (e.g. access problems) unchanged
        if not os.path.exists(path):
            msg = "No such file or directory: '%s'" % path
            raise OSError(errno.ENOENT, msg)
        raise
    used = total - free
    percent = usage_percent(used, total, _round=1)
    return _common.sdiskusage(total, used, free, percent)


def disk_partitions(all):
    """Return disk partitions."""
    # NOTE(review): parameter name shadows the builtin `all`; part of the
    # public platform interface, left unchanged.
    rawlist = cext.disk_partitions(all)
    return [_common.sdiskpart(*x) for x in rawlist]


def cpu_times():
    """Return system CPU times as a named tuple."""
    user, system, idle = cext.cpu_times()
    return scputimes(user, system, idle)


def per_cpu_times():
    """Return system per-CPU times as a list of named tuples."""
    ret = []
    for cpu_t in cext.per_cpu_times():
        user, system, idle = cpu_t
        item = scputimes(user, system, idle)
        ret.append(item)
    return ret


def cpu_count_logical():
    """Return the number of logical CPUs in the system."""
    return cext.cpu_count_logical()


def cpu_count_physical():
    """Return the number of physical CPUs in the system."""
    return cext.cpu_count_phys()


def boot_time():
    """The system boot time expressed in seconds since the epoch."""
    return cext.boot_time()


def net_connections(kind, _pid=-1):
    """Return socket connections.  If pid == -1 return system-wide
    connections (as opposed to connections opened by one process only).
    """
    if kind not in conn_tmap:
        raise ValueError("invalid %r kind argument; choose between %s"
                         % (kind, ', '.join([repr(x) for x in conn_tmap])))
    families, types = conn_tmap[kind]
    rawlist = cext.net_connections(_pid, families, types)
    ret = []
    for item in rawlist:
        fd, fam, type, laddr, raddr, status, pid = item
        status = TCP_STATUSES[status]
        # system-wide results include the owning PID; per-process
        # results drop it (it is implied by the caller)
        if _pid == -1:
            nt = _common.sconn(fd, fam, type, laddr, raddr, status, pid)
        else:
            nt = _common.pconn(fd, fam, type, laddr, raddr, status)
        ret.append(nt)
    return ret


def users():
    """Return currently connected users as a list of namedtuples."""
    retlist = []
    rawlist = cext.users()
    for item in rawlist:
        user, hostname, tstamp = item
        # terminal is not available on Windows, hence the None
        nt = _common.suser(user, None, hostname, tstamp)
        retlist.append(nt)
    return retlist


# thin aliases over the C extension; exposed as module-level functions
pids = cext.pids
pid_exists = cext.pid_exists
net_io_counters = cext.net_io_counters
disk_io_counters = cext.disk_io_counters
ppid_map = cext.ppid_map  # not meant to be public


def wrap_exceptions(fun):
    """Decorator which translates bare OSError and WindowsError
    exceptions into NoSuchProcess and AccessDenied.
""" @functools.wraps(fun) def wrapper(self, *args, **kwargs): try: return fun(self, *args, **kwargs) except OSError as err: # support for private module import if NoSuchProcess is None or AccessDenied is None: raise if err.errno in ACCESS_DENIED_SET: raise AccessDenied(self.pid, self._name) if err.errno == errno.ESRCH: raise NoSuchProcess(self.pid, self._name) raise return wrapper class Process(object): """Wrapper class around underlying C implementation.""" __slots__ = ["pid", "_name"] def __init__(self, pid): self.pid = pid self._name = None @wrap_exceptions def name(self): """Return process name, which on Windows is always the final part of the executable. """ # This is how PIDs 0 and 4 are always represented in taskmgr # and process-hacker. if self.pid == 0: return "System Idle Process" elif self.pid == 4: return "System" else: return os.path.basename(self.exe()) @wrap_exceptions def exe(self): # Note: os.path.exists(path) may return False even if the file # is there, see: # http://stackoverflow.com/questions/3112546/os-path-exists-lies # see https://github.com/giampaolo/psutil/issues/414 # see https://github.com/giampaolo/psutil/issues/528 if self.pid in (0, 4): raise AccessDenied(self.pid, self._name) return _convert_raw_path(cext.proc_exe(self.pid)) @wrap_exceptions def cmdline(self): return cext.proc_cmdline(self.pid) def ppid(self): try: return ppid_map()[self.pid] except KeyError: raise NoSuchProcess(self.pid, self._name) def _get_raw_meminfo(self): try: return cext.proc_memory_info(self.pid) except OSError as err: if err.errno in ACCESS_DENIED_SET: return cext.proc_memory_info_2(self.pid) raise @wrap_exceptions def memory_info(self): # on Windows RSS == WorkingSetSize and VSM == PagefileUsage # fields of PROCESS_MEMORY_COUNTERS struct: # http://msdn.microsoft.com/en-us/library/windows/desktop/ # ms684877(v=vs.85).aspx t = self._get_raw_meminfo() return _common.pmem(t[2], t[7]) @wrap_exceptions def memory_info_ex(self): return 
pextmem(*self._get_raw_meminfo()) def memory_maps(self): try: raw = cext.proc_memory_maps(self.pid) except OSError as err: # XXX - can't use wrap_exceptions decorator as we're # returning a generator; probably needs refactoring. if err.errno in ACCESS_DENIED_SET: raise AccessDenied(self.pid, self._name) if err.errno == errno.ESRCH: raise NoSuchProcess(self.pid, self._name) raise else: for addr, perm, path, rss in raw: path = _convert_raw_path(path) addr = hex(addr) yield (addr, perm, path, rss) @wrap_exceptions def kill(self): return cext.proc_kill(self.pid) @wrap_exceptions def wait(self, timeout=None): if timeout is None: timeout = cext.INFINITE else: # WaitForSingleObject() expects time in milliseconds timeout = int(timeout * 1000) ret = cext.proc_wait(self.pid, timeout) if ret == WAIT_TIMEOUT: # support for private module import if TimeoutExpired is None: raise RuntimeError("timeout expired") raise TimeoutExpired(timeout, self.pid, self._name) return ret @wrap_exceptions def username(self): if self.pid in (0, 4): return 'NT AUTHORITY\\SYSTEM' return cext.proc_username(self.pid) @wrap_exceptions def create_time(self): # special case for kernel process PIDs; return system boot time if self.pid in (0, 4): return boot_time() try: return cext.proc_create_time(self.pid) except OSError as err: if err.errno in ACCESS_DENIED_SET: return cext.proc_create_time_2(self.pid) raise @wrap_exceptions def num_threads(self): return cext.proc_num_threads(self.pid) @wrap_exceptions def threads(self): rawlist = cext.proc_threads(self.pid) retlist = [] for thread_id, utime, stime in rawlist: ntuple = _common.pthread(thread_id, utime, stime) retlist.append(ntuple) return retlist @wrap_exceptions def cpu_times(self): try: ret = cext.proc_cpu_times(self.pid) except OSError as err: if err.errno in ACCESS_DENIED_SET: ret = cext.proc_cpu_times_2(self.pid) else: raise return _common.pcputimes(*ret) @wrap_exceptions def suspend(self): return cext.proc_suspend(self.pid) @wrap_exceptions def 
resume(self): return cext.proc_resume(self.pid) @wrap_exceptions def cwd(self): if self.pid in (0, 4): raise AccessDenied(self.pid, self._name) # return a normalized pathname since the native C function appends # "\\" at the and of the path path = cext.proc_cwd(self.pid) return os.path.normpath(path) @wrap_exceptions def open_files(self): if self.pid in (0, 4): return [] retlist = [] # Filenames come in in native format like: # "\Device\HarddiskVolume1\Windows\systemew\file.txt" # Convert the first part in the corresponding drive letter # (e.g. "C:\") by using Windows's QueryDosDevice() raw_file_names = cext.proc_open_files(self.pid) for file in raw_file_names: file = _convert_raw_path(file) if isfile_strict(file) and file not in retlist: ntuple = _common.popenfile(file, -1) retlist.append(ntuple) return retlist @wrap_exceptions def connections(self, kind='inet'): return net_connections(kind, _pid=self.pid) @wrap_exceptions def nice_get(self): return cext.proc_priority_get(self.pid) @wrap_exceptions def nice_set(self, value): return cext.proc_priority_set(self.pid, value) # available on Windows >= Vista if hasattr(cext, "proc_io_priority_get"): @wrap_exceptions def ionice_get(self): return cext.proc_io_priority_get(self.pid) @wrap_exceptions def ionice_set(self, value, _): if _: raise TypeError("set_proc_ionice() on Windows takes only " "1 argument (2 given)") if value not in (2, 1, 0): raise ValueError("value must be 2 (normal), 1 (low) or 0 " "(very low); got %r" % value) return cext.proc_io_priority_set(self.pid, value) @wrap_exceptions def io_counters(self): try: ret = cext.proc_io_counters(self.pid) except OSError as err: if err.errno in ACCESS_DENIED_SET: ret = cext.proc_io_counters_2(self.pid) else: raise return _common.pio(*ret) @wrap_exceptions def status(self): suspended = cext.proc_is_suspended(self.pid) if suspended: return _common.STATUS_STOPPED else: return _common.STATUS_RUNNING @wrap_exceptions def cpu_affinity_get(self): from_bitmask = lambda x: [i 
for i in xrange(64) if (1 << i) & x] bitmask = cext.proc_cpu_affinity_get(self.pid) return from_bitmask(bitmask) @wrap_exceptions def cpu_affinity_set(self, value): def to_bitmask(l): if not l: raise ValueError("invalid argument %r" % l) out = 0 for b in l: out |= 2 ** b return out # SetProcessAffinityMask() states that ERROR_INVALID_PARAMETER # is returned for an invalid CPU but this seems not to be true, # therefore we check CPUs validy beforehand. allcpus = list(range(len(per_cpu_times()))) for cpu in value: if cpu not in allcpus: raise ValueError("invalid CPU %r" % cpu) bitmask = to_bitmask(value) cext.proc_cpu_affinity_set(self.pid, bitmask) @wrap_exceptions def num_handles(self): try: return cext.proc_num_handles(self.pid) except OSError as err: if err.errno in ACCESS_DENIED_SET: return cext.proc_num_handles_2(self.pid) raise @wrap_exceptions def num_ctx_switches(self): tupl = cext.proc_num_ctx_switches(self.pid) return _common.pctxsw(*tupl) ================================================ FILE: Common/libpsutil/py2.7-glibc-2.12+/psutil/__init__.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # Copyright (c) 2009, Giampaolo Rodola'. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. """psutil is a cross-platform library for retrieving information on running processes and system utilization (CPU, memory, disks, network) in Python. 
""" from __future__ import division __author__ = "Giampaolo Rodola'" __version__ = "2.2.0" version_info = tuple([int(num) for num in __version__.split('.')]) __all__ = [ # exceptions "Error", "NoSuchProcess", "AccessDenied", "TimeoutExpired", # constants "version_info", "__version__", "STATUS_RUNNING", "STATUS_IDLE", "STATUS_SLEEPING", "STATUS_DISK_SLEEP", "STATUS_STOPPED", "STATUS_TRACING_STOP", "STATUS_ZOMBIE", "STATUS_DEAD", "STATUS_WAKING", "STATUS_LOCKED", "STATUS_WAITING", "STATUS_LOCKED", "CONN_ESTABLISHED", "CONN_SYN_SENT", "CONN_SYN_RECV", "CONN_FIN_WAIT1", "CONN_FIN_WAIT2", "CONN_TIME_WAIT", "CONN_CLOSE", "CONN_CLOSE_WAIT", "CONN_LAST_ACK", "CONN_LISTEN", "CONN_CLOSING", "CONN_NONE", # classes "Process", "Popen", # functions "pid_exists", "pids", "process_iter", "wait_procs", # proc "virtual_memory", "swap_memory", # memory "cpu_times", "cpu_percent", "cpu_times_percent", "cpu_count", # cpu "net_io_counters", "net_connections", # network "disk_io_counters", "disk_partitions", "disk_usage", # disk "users", "boot_time", # others ] import collections import errno import functools import os import signal import subprocess import sys import time import warnings try: import pwd except ImportError: pwd = None from psutil._common import memoize from psutil._compat import callable, long from psutil._compat import PY3 as _PY3 from psutil._common import (deprecated_method as _deprecated_method, deprecated as _deprecated, sdiskio as _nt_sys_diskio, snetio as _nt_sys_netio) from psutil._common import (STATUS_RUNNING, # NOQA STATUS_SLEEPING, STATUS_DISK_SLEEP, STATUS_STOPPED, STATUS_TRACING_STOP, STATUS_ZOMBIE, STATUS_DEAD, STATUS_WAKING, STATUS_LOCKED, STATUS_IDLE, # bsd STATUS_WAITING, # bsd STATUS_LOCKED) # bsd from psutil._common import (CONN_ESTABLISHED, CONN_SYN_SENT, CONN_SYN_RECV, CONN_FIN_WAIT1, CONN_FIN_WAIT2, CONN_TIME_WAIT, CONN_CLOSE, CONN_CLOSE_WAIT, CONN_LAST_ACK, CONN_LISTEN, CONN_CLOSING, CONN_NONE) if sys.platform.startswith("linux"): import 
psutil._pslinux as _psplatform from psutil._pslinux import (phymem_buffers, # NOQA cached_phymem) from psutil._pslinux import (IOPRIO_CLASS_NONE, # NOQA IOPRIO_CLASS_RT, IOPRIO_CLASS_BE, IOPRIO_CLASS_IDLE) # Linux >= 2.6.36 if _psplatform.HAS_PRLIMIT: from _psutil_linux import (RLIM_INFINITY, # NOQA RLIMIT_AS, RLIMIT_CORE, RLIMIT_CPU, RLIMIT_DATA, RLIMIT_FSIZE, RLIMIT_LOCKS, RLIMIT_MEMLOCK, RLIMIT_NOFILE, RLIMIT_NPROC, RLIMIT_RSS, RLIMIT_STACK) # Kinda ugly but considerably faster than using hasattr() and # setattr() against the module object (we are at import time: # speed matters). import _psutil_linux try: RLIMIT_MSGQUEUE = _psutil_linux.RLIMIT_MSGQUEUE except AttributeError: pass try: RLIMIT_NICE = _psutil_linux.RLIMIT_NICE except AttributeError: pass try: RLIMIT_RTPRIO = _psutil_linux.RLIMIT_RTPRIO except AttributeError: pass try: RLIMIT_RTTIME = _psutil_linux.RLIMIT_RTTIME except AttributeError: pass try: RLIMIT_SIGPENDING = _psutil_linux.RLIMIT_SIGPENDING except AttributeError: pass del _psutil_linux elif sys.platform.startswith("win32"): import psutil._pswindows as _psplatform from _psutil_windows import (ABOVE_NORMAL_PRIORITY_CLASS, # NOQA BELOW_NORMAL_PRIORITY_CLASS, HIGH_PRIORITY_CLASS, IDLE_PRIORITY_CLASS, NORMAL_PRIORITY_CLASS, REALTIME_PRIORITY_CLASS) from psutil._pswindows import CONN_DELETE_TCB # NOQA elif sys.platform.startswith("darwin"): import psutil._psosx as _psplatform elif sys.platform.startswith("freebsd"): import psutil._psbsd as _psplatform elif sys.platform.startswith("sunos"): import psutil._pssunos as _psplatform from psutil._pssunos import (CONN_IDLE, # NOQA CONN_BOUND) else: raise NotImplementedError('platform %s is not supported' % sys.platform) __all__.extend(_psplatform.__extra__all__) _TOTAL_PHYMEM = None _POSIX = os.name == 'posix' _WINDOWS = os.name == 'nt' _timer = getattr(time, 'monotonic', time.time) # Sanity check in case the user messed up with psutil installation # or did something weird with sys.path. 
# In this case we might end
# up importing a python module using a C extension module which
# was compiled for a different version of psutil.
# We want to prevent that by failing sooner rather than later.
# See: https://github.com/giampaolo/psutil/issues/564
if (int(__version__.replace('.', '')) !=
        getattr(_psplatform.cext, 'version', None)):
    msg = "version conflict: %r C extension module was built for another " \
          "version of psutil (different than %s)" % (
              _psplatform.cext.__file__, __version__)
    raise ImportError(msg)


# =====================================================================
# --- exceptions
# =====================================================================


class Error(Exception):
    """Base exception class. All other psutil exceptions inherit
    from this one.
    """


class NoSuchProcess(Error):
    """Exception raised when a process with a certain PID doesn't
    or no longer exists (zombie).
    """

    def __init__(self, pid, name=None, msg=None):
        Error.__init__(self)
        self.pid = pid
        self.name = name
        self.msg = msg
        # build a default message unless the caller supplied one
        if msg is None:
            if name:
                details = "(pid=%s, name=%s)" % (self.pid, repr(self.name))
            else:
                details = "(pid=%s)" % self.pid
            self.msg = "process no longer exists " + details

    def __str__(self):
        return self.msg


class AccessDenied(Error):
    """Exception raised when permission to perform an action is denied."""

    def __init__(self, pid=None, name=None, msg=None):
        Error.__init__(self)
        self.pid = pid
        self.name = name
        self.msg = msg
        # build a default message unless the caller supplied one
        if msg is None:
            if (pid is not None) and (name is not None):
                self.msg = "(pid=%s, name=%s)" % (pid, repr(name))
            elif (pid is not None):
                self.msg = "(pid=%s)" % self.pid
            else:
                self.msg = ""

    def __str__(self):
        return self.msg


class TimeoutExpired(Error):
    """Raised on Process.wait(timeout) if timeout expires and process
    is still alive.
    """

    def __init__(self, seconds, pid=None, name=None):
        Error.__init__(self)
        self.seconds = seconds
        self.pid = pid
        self.name = name
        self.msg = "timeout after %s seconds" % seconds
        # append pid/name details when available
        if (pid is not None) and (name is not None):
            self.msg += " (pid=%s, name=%s)" % (pid, repr(name))
        elif (pid is not None):
            self.msg += " (pid=%s)" % self.pid

    def __str__(self):
        return self.msg


# push exception classes into platform specific module namespace
_psplatform.NoSuchProcess = NoSuchProcess
_psplatform.AccessDenied = AccessDenied
_psplatform.TimeoutExpired = TimeoutExpired


# =====================================================================
# --- Process class
# =====================================================================


def _assert_pid_not_reused(fun):
    """Decorator which raises NoSuchProcess in case a process is no
    longer running or its PID has been reused.
    """
    @functools.wraps(fun)
    def wrapper(self, *args, **kwargs):
        if not self.is_running():
            raise NoSuchProcess(self.pid, self._name)
        return fun(self, *args, **kwargs)
    return wrapper


class Process(object):
    """Represents an OS process with the given PID.
    If PID is omitted current process PID (os.getpid()) is used.
    Raise NoSuchProcess if PID does not exist.

    Note that most of the methods of this class do not make sure
    the PID of the process being queried has been reused over time.
    That means you might end up retrieving an information referring
    to another process in case the original one this instance
    refers to is gone in the meantime.
The only exceptions for which process identity is pre-emptively checked and guaranteed are: - parent() - children() - nice() (set) - ionice() (set) - rlimit() (set) - cpu_affinity (set) - suspend() - resume() - send_signal() - terminate() - kill() To prevent this problem for all other methods you can: - use is_running() before querying the process - if you're continuously iterating over a set of Process instances use process_iter() which pre-emptively checks process identity for every yielded instance """ def __init__(self, pid=None): self._init(pid) def _init(self, pid, _ignore_nsp=False): if pid is None: pid = os.getpid() else: if not _PY3 and not isinstance(pid, (int, long)): raise TypeError('pid must be an integer (got %r)' % pid) if pid < 0: raise ValueError('pid must be a positive integer (got %s)' % pid) self._pid = pid self._name = None self._exe = None self._create_time = None self._gone = False self._hash = None # used for caching on Windows only (on POSIX ppid may change) self._ppid = None # platform-specific modules define an _psplatform.Process # implementation class self._proc = _psplatform.Process(pid) self._last_sys_cpu_times = None self._last_proc_cpu_times = None # cache creation time for later use in is_running() method try: self.create_time() except AccessDenied: # we should never get here as AFAIK we're able to get # process creation time on all platforms even as a # limited user pass except NoSuchProcess: if not _ignore_nsp: msg = 'no process found with pid %s' % pid raise NoSuchProcess(pid, None, msg) else: self._gone = True # This pair is supposed to indentify a Process instance # univocally over time (the PID alone is not enough as # it might refer to a process whose PID has been reused). # This will be used later in __eq__() and is_running(). 
self._ident = (self.pid, self._create_time) def __str__(self): try: pid = self.pid name = repr(self.name()) except NoSuchProcess: details = "(pid=%s (terminated))" % self.pid except AccessDenied: details = "(pid=%s)" % (self.pid) else: details = "(pid=%s, name=%s)" % (pid, name) return "%s.%s%s" % (self.__class__.__module__, self.__class__.__name__, details) def __repr__(self): return "<%s at %s>" % (self.__str__(), id(self)) def __eq__(self, other): # Test for equality with another Process object based # on PID and creation time. if not isinstance(other, Process): return NotImplemented return self._ident == other._ident def __ne__(self, other): return not self == other def __hash__(self): if self._hash is None: self._hash = hash(self._ident) return self._hash # --- utility methods def as_dict(self, attrs=None, ad_value=None): """Utility method returning process information as a hashable dictionary. If 'attrs' is specified it must be a list of strings reflecting available Process class' attribute names (e.g. ['cpu_times', 'name']) else all public (read only) attributes are assumed. 'ad_value' is the value which gets assigned in case AccessDenied exception is raised when retrieving that particular process information. 
""" excluded_names = set( ['send_signal', 'suspend', 'resume', 'terminate', 'kill', 'wait', 'is_running', 'as_dict', 'parent', 'children', 'rlimit']) retdict = dict() ls = set(attrs or [x for x in dir(self) if not x.startswith('get')]) for name in ls: if name.startswith('_'): continue if name.startswith('set_'): continue if name.startswith('get_'): msg = "%s() is deprecated; use %s() instead" % (name, name[4:]) warnings.warn(msg, category=DeprecationWarning, stacklevel=2) name = name[4:] if name in ls: continue if name == 'getcwd': msg = "getcwd() is deprecated; use cwd() instead" warnings.warn(msg, category=DeprecationWarning, stacklevel=2) name = 'cwd' if name in ls: continue if name in excluded_names: continue try: attr = getattr(self, name) if callable(attr): ret = attr() else: ret = attr except AccessDenied: ret = ad_value except NotImplementedError: # in case of not implemented functionality (may happen # on old or exotic systems) we want to crash only if # the user explicitly asked for that particular attr if attrs: raise continue retdict[name] = ret return retdict def parent(self): """Return the parent process as a Process object pre-emptively checking whether PID has been reused. If no parent is known return None. """ ppid = self.ppid() if ppid is not None: try: parent = Process(ppid) if parent.create_time() <= self.create_time(): return parent # ...else ppid has been reused by another process except NoSuchProcess: pass def is_running(self): """Return whether this process is running. It also checks if PID has been reused by another process in which case return False. """ if self._gone: return False try: # Checking if PID is alive is not enough as the PID might # have been reused by another process: we also want to # check process identity. # Process identity / uniqueness over time is greanted by # (PID + creation time) and that is verified in __eq__. 
return self == Process(self.pid) except NoSuchProcess: self._gone = True return False # --- actual API @property def pid(self): """The process PID.""" return self._pid def ppid(self): """The process parent PID. On Windows the return value is cached after first call. """ # On POSIX we don't want to cache the ppid as it may unexpectedly # change to 1 (init) in case this process turns into a zombie: # https://github.com/giampaolo/psutil/issues/321 # http://stackoverflow.com/questions/356722/ # XXX should we check creation time here rather than in # Process.parent()? if _POSIX: return self._proc.ppid() else: if self._ppid is None: self._ppid = self._proc.ppid() return self._ppid def name(self): """The process name. The return value is cached after first call.""" if self._name is None: name = self._proc.name() if _POSIX and len(name) >= 15: # On UNIX the name gets truncated to the first 15 characters. # If it matches the first part of the cmdline we return that # one instead because it's usually more explicative. # Examples are "gnome-keyring-d" vs. "gnome-keyring-daemon". try: cmdline = self.cmdline() except AccessDenied: pass else: if cmdline: extended_name = os.path.basename(cmdline[0]) if extended_name.startswith(name): name = extended_name self._proc._name = name self._name = name return self._name def exe(self): """The process executable as an absolute path. May also be an empty string. The return value is cached after first call. """ def guess_it(fallback): # try to guess exe from cmdline[0] in absence of a native # exe representation cmdline = self.cmdline() if cmdline and hasattr(os, 'access') and hasattr(os, 'X_OK'): exe = cmdline[0] # the possible exe # Attempt to guess only in case of an absolute path. # It is not safe otherwise as the process might have # changed cwd. 
if (os.path.isabs(exe) and os.path.isfile(exe) and os.access(exe, os.X_OK)): return exe if isinstance(fallback, AccessDenied): raise fallback return fallback if self._exe is None: try: exe = self._proc.exe() except AccessDenied as err: return guess_it(fallback=err) else: if not exe: # underlying implementation can legitimately return an # empty string; if that's the case we don't want to # raise AD while guessing from the cmdline try: exe = guess_it(fallback=exe) except AccessDenied: pass self._exe = exe return self._exe def cmdline(self): """The command line this process has been called with.""" return self._proc.cmdline() def status(self): """The process current status as a STATUS_* constant.""" return self._proc.status() def username(self): """The name of the user that owns the process. On UNIX this is calculated by using *real* process uid. """ if _POSIX: if pwd is None: # might happen if python was installed from sources raise ImportError( "requires pwd module shipped with standard python") real_uid = self.uids().real try: return pwd.getpwuid(real_uid).pw_name except KeyError: # the uid can't be resolved by the system return str(real_uid) else: return self._proc.username() def create_time(self): """The process creation time as a floating point number expressed in seconds since the epoch, in UTC. The return value is cached after first call. """ if self._create_time is None: self._create_time = self._proc.create_time() return self._create_time def cwd(self): """Process current working directory as an absolute path.""" return self._proc.cwd() def nice(self, value=None): """Get or set process niceness (priority).""" if value is None: return self._proc.nice_get() else: if not self.is_running(): raise NoSuchProcess(self.pid, self._name) self._proc.nice_set(value) if _POSIX: def uids(self): """Return process UIDs as a (real, effective, saved) namedtuple. """ return self._proc.uids() def gids(self): """Return process GIDs as a (real, effective, saved) namedtuple. 
""" return self._proc.gids() def terminal(self): """The terminal associated with this process, if any, else None. """ return self._proc.terminal() def num_fds(self): """Return the number of file descriptors opened by this process (POSIX only). """ return self._proc.num_fds() # Linux, BSD and Windows only if hasattr(_psplatform.Process, "io_counters"): def io_counters(self): """Return process I/O statistics as a (read_count, write_count, read_bytes, write_bytes) namedtuple. Those are the number of read/write calls performed and the amount of bytes read and written by the process. """ return self._proc.io_counters() # Linux and Windows >= Vista only if hasattr(_psplatform.Process, "ionice_get"): def ionice(self, ioclass=None, value=None): """Get or set process I/O niceness (priority). On Linux 'ioclass' is one of the IOPRIO_CLASS_* constants. 'value' is a number which goes from 0 to 7. The higher the value, the lower the I/O priority of the process. On Windows only 'ioclass' is used and it can be set to 2 (normal), 1 (low) or 0 (very low). Available on Linux and Windows > Vista only. """ if ioclass is None: if value is not None: raise ValueError("'ioclass' must be specified") return self._proc.ionice_get() else: return self._proc.ionice_set(ioclass, value) # Linux only if hasattr(_psplatform.Process, "rlimit"): def rlimit(self, resource, limits=None): """Get or set process resource limits as a (soft, hard) tuple. 'resource' is one of the RLIMIT_* constants. 'limits' is supposed to be a (soft, hard) tuple. See "man prlimit" for further info. Available on Linux only. """ if limits is None: return self._proc.rlimit(resource) else: return self._proc.rlimit(resource, limits) # Windows, Linux and BSD only if hasattr(_psplatform.Process, "cpu_affinity_get"): def cpu_affinity(self, cpus=None): """Get or set process CPU affinity. If specified 'cpus' must be a list of CPUs for which you want to set the affinity (e.g. [0, 1]). (Windows, Linux and BSD only). 
""" if cpus is None: return self._proc.cpu_affinity_get() else: self._proc.cpu_affinity_set(cpus) if _WINDOWS: def num_handles(self): """Return the number of handles opened by this process (Windows only). """ return self._proc.num_handles() def num_ctx_switches(self): """Return the number of voluntary and involuntary context switches performed by this process. """ return self._proc.num_ctx_switches() def num_threads(self): """Return the number of threads used by this process.""" return self._proc.num_threads() def threads(self): """Return threads opened by process as a list of (id, user_time, system_time) namedtuples representing thread id and thread CPU times (user/system). """ return self._proc.threads() @_assert_pid_not_reused def children(self, recursive=False): """Return the children of this process as a list of Process instances, pre-emptively checking whether PID has been reused. If recursive is True return all the parent descendants. Example (A == this process): A ─┐ │ ├─ B (child) ─┐ │ └─ X (grandchild) ─┐ │ └─ Y (great grandchild) ├─ C (child) └─ D (child) >>> import psutil >>> p = psutil.Process() >>> p.children() B, C, D >>> p.children(recursive=True) B, X, Y, C, D Note that in the example above if process X disappears process Y won't be listed as the reference to process A is lost. """ if hasattr(_psplatform, 'ppid_map'): # Windows only: obtain a {pid:ppid, ...} dict for all running # processes in one shot (faster). 
ppid_map = _psplatform.ppid_map() else: ppid_map = None ret = [] if not recursive: if ppid_map is None: # 'slow' version, common to all platforms except Windows for p in process_iter(): try: if p.ppid() == self.pid: # if child happens to be older than its parent # (self) it means child's PID has been reused if self.create_time() <= p.create_time(): ret.append(p) except NoSuchProcess: pass else: # Windows only (faster) for pid, ppid in ppid_map.items(): if ppid == self.pid: try: child = Process(pid) # if child happens to be older than its parent # (self) it means child's PID has been reused if self.create_time() <= child.create_time(): ret.append(child) except NoSuchProcess: pass else: # construct a dict where 'values' are all the processes # having 'key' as their parent table = collections.defaultdict(list) if ppid_map is None: for p in process_iter(): try: table[p.ppid()].append(p) except NoSuchProcess: pass else: for pid, ppid in ppid_map.items(): try: p = Process(pid) table[ppid].append(p) except NoSuchProcess: pass # At this point we have a mapping table where table[self.pid] # are the current process' children. # Below, we look for all descendants recursively, similarly # to a recursive function call. checkpids = [self.pid] for pid in checkpids: for child in table[pid]: try: # if child happens to be older than its parent # (self) it means child's PID has been reused intime = self.create_time() <= child.create_time() except NoSuchProcess: pass else: if intime: ret.append(child) if child.pid not in checkpids: checkpids.append(child.pid) return ret def cpu_percent(self, interval=None): """Return a float representing the current process CPU utilization as a percentage. When interval is 0.0 or None (default) compares process times to system CPU times elapsed since last call, returning immediately (non-blocking). That means that the first time this is called it will return a meaningful 0.0 value. 
When interval is > 0.0 compares process times to system CPU times elapsed before and after the interval (blocking). In this case is recommended for accuracy that this function be called with at least 0.1 seconds between calls. Examples: >>> import psutil >>> p = psutil.Process(os.getpid()) >>> # blocking >>> p.cpu_percent(interval=1) 2.0 >>> # non-blocking (percentage since last call) >>> p.cpu_percent(interval=None) 2.9 >>> """ blocking = interval is not None and interval > 0.0 num_cpus = cpu_count() if _POSIX: timer = lambda: _timer() * num_cpus else: timer = lambda: sum(cpu_times()) if blocking: st1 = timer() pt1 = self._proc.cpu_times() time.sleep(interval) st2 = timer() pt2 = self._proc.cpu_times() else: st1 = self._last_sys_cpu_times pt1 = self._last_proc_cpu_times st2 = timer() pt2 = self._proc.cpu_times() if st1 is None or pt1 is None: self._last_sys_cpu_times = st2 self._last_proc_cpu_times = pt2 return 0.0 delta_proc = (pt2.user - pt1.user) + (pt2.system - pt1.system) delta_time = st2 - st1 # reset values for next call in case of interval == None self._last_sys_cpu_times = st2 self._last_proc_cpu_times = pt2 try: # The utilization split between all CPUs. # Note: a percentage > 100 is legitimate as it can result # from a process with multiple threads running on different # CPU cores, see: # http://stackoverflow.com/questions/1032357 # https://github.com/giampaolo/psutil/issues/474 overall_percent = ((delta_proc / delta_time) * 100) * num_cpus except ZeroDivisionError: # interval was too low return 0.0 else: return round(overall_percent, 1) def cpu_times(self): """Return a (user, system) namedtuple representing the accumulated process time, in seconds. This is the same as os.times() but per-process. """ return self._proc.cpu_times() def memory_info(self): """Return a tuple representing RSS (Resident Set Size) and VMS (Virtual Memory Size) in bytes. On UNIX RSS and VMS are the same values shown by 'ps'. 
On Windows RSS and VMS refer to "Mem Usage" and "VM Size" columns of taskmgr.exe. """ return self._proc.memory_info() def memory_info_ex(self): """Return a namedtuple with variable fields depending on the platform representing extended memory information about this process. All numbers are expressed in bytes. """ return self._proc.memory_info_ex() def memory_percent(self): """Compare physical system memory to process resident memory (RSS) and calculate process memory utilization as a percentage. """ rss = self._proc.memory_info()[0] # use cached value if available total_phymem = _TOTAL_PHYMEM or virtual_memory().total try: return (rss / float(total_phymem)) * 100 except ZeroDivisionError: return 0.0 def memory_maps(self, grouped=True): """Return process' mapped memory regions as a list of nameduples whose fields are variable depending on the platform. If 'grouped' is True the mapped regions with the same 'path' are grouped together and the different memory fields are summed. If 'grouped' is False every mapped region is shown as a single entity and the namedtuple will also include the mapped region's address space ('addr') and permission set ('perms'). """ it = self._proc.memory_maps() if grouped: d = {} for tupl in it: path = tupl[2] nums = tupl[3:] try: d[path] = map(lambda x, y: x + y, d[path], nums) except KeyError: d[path] = nums nt = _psplatform.pmmap_grouped return [nt(path, *d[path]) for path in d] # NOQA else: nt = _psplatform.pmmap_ext return [nt(*x) for x in it] def open_files(self): """Return files opened by process as a list of (path, fd) namedtuples including the absolute file name and file descriptor number. """ return self._proc.open_files() def connections(self, kind='inet'): """Return connections opened by process as a list of (fd, family, type, laddr, raddr, status) namedtuples. 
The 'kind' parameter filters for connections that match the following criteria: Kind Value Connections using inet IPv4 and IPv6 inet4 IPv4 inet6 IPv6 tcp TCP tcp4 TCP over IPv4 tcp6 TCP over IPv6 udp UDP udp4 UDP over IPv4 udp6 UDP over IPv6 unix UNIX socket (both UDP and TCP protocols) all the sum of all the possible families and protocols """ return self._proc.connections(kind) if _POSIX: def _send_signal(self, sig): # XXX: according to "man 2 kill" PID 0 has a special # meaning as it refers to <>, so should we prevent # it here? try: os.kill(self.pid, sig) except OSError as err: if err.errno == errno.ESRCH: self._gone = True raise NoSuchProcess(self.pid, self._name) if err.errno == errno.EPERM: raise AccessDenied(self.pid, self._name) raise @_assert_pid_not_reused def send_signal(self, sig): """Send a signal to process pre-emptively checking whether PID has been reused (see signal module constants) . On Windows only SIGTERM is valid and is treated as an alias for kill(). """ if _POSIX: self._send_signal(sig) else: if sig == signal.SIGTERM: self._proc.kill() else: raise ValueError("only SIGTERM is supported on Windows") @_assert_pid_not_reused def suspend(self): """Suspend process execution with SIGSTOP pre-emptively checking whether PID has been reused. On Windows this has the effect ot suspending all process threads. """ if _POSIX: self._send_signal(signal.SIGSTOP) else: self._proc.suspend() @_assert_pid_not_reused def resume(self): """Resume process execution with SIGCONT pre-emptively checking whether PID has been reused. On Windows this has the effect of resuming all process threads. """ if _POSIX: self._send_signal(signal.SIGCONT) else: self._proc.resume() @_assert_pid_not_reused def terminate(self): """Terminate the process with SIGTERM pre-emptively checking whether PID has been reused. On Windows this is an alias for kill(). 
""" if _POSIX: self._send_signal(signal.SIGTERM) else: self._proc.kill() @_assert_pid_not_reused def kill(self): """Kill the current process with SIGKILL pre-emptively checking whether PID has been reused. """ if _POSIX: self._send_signal(signal.SIGKILL) else: self._proc.kill() def wait(self, timeout=None): """Wait for process to terminate and, if process is a children of os.getpid(), also return its exit code, else None. If the process is already terminated immediately return None instead of raising NoSuchProcess. If timeout (in seconds) is specified and process is still alive raise TimeoutExpired. To wait for multiple Process(es) use psutil.wait_procs(). """ if timeout is not None and not timeout >= 0: raise ValueError("timeout must be a positive integer") return self._proc.wait(timeout) # --- deprecated APIs _locals = set(locals()) @_deprecated_method(replacement='children') def get_children(self): pass @_deprecated_method(replacement='connections') def get_connections(self): pass if "cpu_affinity" in _locals: @_deprecated_method(replacement='cpu_affinity') def get_cpu_affinity(self): pass @_deprecated_method(replacement='cpu_affinity') def set_cpu_affinity(self, cpus): pass @_deprecated_method(replacement='cpu_percent') def get_cpu_percent(self): pass @_deprecated_method(replacement='cpu_times') def get_cpu_times(self): pass @_deprecated_method(replacement='cwd') def getcwd(self): pass @_deprecated_method(replacement='memory_info_ex') def get_ext_memory_info(self): pass if "io_counters" in _locals: @_deprecated_method(replacement='io_counters') def get_io_counters(self): pass if "ionice" in _locals: @_deprecated_method(replacement='ionice') def get_ionice(self): pass @_deprecated_method(replacement='ionice') def set_ionice(self, ioclass, value=None): pass @_deprecated_method(replacement='memory_info') def get_memory_info(self): pass @_deprecated_method(replacement='memory_maps') def get_memory_maps(self): pass @_deprecated_method(replacement='memory_percent') 
def get_memory_percent(self): pass @_deprecated_method(replacement='nice') def get_nice(self): pass @_deprecated_method(replacement='num_ctx_switches') def get_num_ctx_switches(self): pass if 'num_fds' in _locals: @_deprecated_method(replacement='num_fds') def get_num_fds(self): pass if 'num_handles' in _locals: @_deprecated_method(replacement='num_handles') def get_num_handles(self): pass @_deprecated_method(replacement='num_threads') def get_num_threads(self): pass @_deprecated_method(replacement='open_files') def get_open_files(self): pass if "rlimit" in _locals: @_deprecated_method(replacement='rlimit') def get_rlimit(self): pass @_deprecated_method(replacement='rlimit') def set_rlimit(self, resource, limits): pass @_deprecated_method(replacement='threads') def get_threads(self): pass @_deprecated_method(replacement='nice') def set_nice(self, value): pass del _locals # ===================================================================== # --- Popen class # ===================================================================== class Popen(Process): """A more convenient interface to stdlib subprocess module. It starts a sub process and deals with it exactly as when using subprocess.Popen class but in addition also provides all the properties and methods of psutil.Process class as a unified interface: >>> import psutil >>> from subprocess import PIPE >>> p = psutil.Popen(["python", "-c", "print 'hi'"], stdout=PIPE) >>> p.name() 'python' >>> p.uids() user(real=1000, effective=1000, saved=1000) >>> p.username() 'giampaolo' >>> p.communicate() ('hi\n', None) >>> p.terminate() >>> p.wait(timeout=2) 0 >>> For method names common to both classes such as kill(), terminate() and wait(), psutil.Process implementation takes precedence. Unlike subprocess.Popen this class pre-emptively checks wheter PID has been reused on send_signal(), terminate() and kill() so that you don't accidentally terminate another process, fixing http://bugs.python.org/issue6973. 
For a complete documentation refer to: http://docs.python.org/library/subprocess.html """ def __init__(self, *args, **kwargs): # Explicitly avoid to raise NoSuchProcess in case the process # spawned by subprocess.Popen terminates too quickly, see: # https://github.com/giampaolo/psutil/issues/193 self.__subproc = subprocess.Popen(*args, **kwargs) self._init(self.__subproc.pid, _ignore_nsp=True) def __dir__(self): return sorted(set(dir(Popen) + dir(subprocess.Popen))) def __getattribute__(self, name): try: return object.__getattribute__(self, name) except AttributeError: try: return object.__getattribute__(self.__subproc, name) except AttributeError: raise AttributeError("%s instance has no attribute '%s'" % (self.__class__.__name__, name)) def wait(self, timeout=None): if self.__subproc.returncode is not None: return self.__subproc.returncode ret = super(Popen, self).wait(timeout) self.__subproc.returncode = ret return ret # ===================================================================== # --- system processes related functions # ===================================================================== def pids(): """Return a list of current running PIDs.""" return _psplatform.pids() def pid_exists(pid): """Return True if given PID exists in the current process list. This is faster than doing "pid in psutil.pids()" and should be preferred. """ if pid < 0: return False elif pid == 0 and _POSIX: # On POSIX we use os.kill() to determine PID existence. # According to "man 2 kill" PID 0 has a special meaning # though: it refers to <> and that is not we want # to do here. return pid in pids() else: return _psplatform.pid_exists(pid) _pmap = {} def process_iter(): """Return a generator yielding a Process instance for all running processes. Every new Process instance is only created once and then cached into an internal table which is updated every time this is used. 
Cached Process instances are checked for identity so that you're safe in case a PID has been reused by another process, in which case the cached instance is updated. The sorting order in which processes are yielded is based on their PIDs. """ def add(pid): proc = Process(pid) _pmap[proc.pid] = proc return proc def remove(pid): _pmap.pop(pid, None) a = set(pids()) b = set(_pmap.keys()) new_pids = a - b gone_pids = b - a for pid in gone_pids: remove(pid) for pid, proc in sorted(list(_pmap.items()) + list(dict.fromkeys(new_pids).items())): try: if proc is None: # new process yield add(pid) else: # use is_running() to check whether PID has been reused by # another process in which case yield a new Process instance if proc.is_running(): yield proc else: yield add(pid) except NoSuchProcess: remove(pid) except AccessDenied: # Process creation time can't be determined hence there's # no way to tell whether the pid of the cached process # has been reused. Just return the cached version. yield proc def wait_procs(procs, timeout=None, callback=None): """Convenience function which waits for a list of processes to terminate. Return a (gone, alive) tuple indicating which processes are gone and which ones are still alive. The gone ones will have a new 'returncode' attribute indicating process exit status (may be None). 'callback' is a function which gets called every time a process terminates (a Process instance is passed as callback argument). Function will return as soon as all processes terminate or when timeout occurs. Typical use case is: - send SIGTERM to a list of processes - give them some time to terminate - send SIGKILL to those ones which are still alive Example: >>> def on_terminate(proc): ... print("process {} terminated".format(proc)) ... >>> for p in procs: ... p.terminate() ... >>> gone, alive = wait_procs(procs, timeout=3, callback=on_terminate) >>> for p in alive: ... 
p.kill() """ def check_gone(proc, timeout): try: returncode = proc.wait(timeout=timeout) except TimeoutExpired: pass else: if returncode is not None or not proc.is_running(): proc.returncode = returncode gone.add(proc) if callback is not None: callback(proc) if timeout is not None and not timeout >= 0: msg = "timeout must be a positive integer, got %s" % timeout raise ValueError(msg) gone = set() alive = set(procs) if callback is not None and not callable(callback): raise TypeError("callback %r is not a callable" % callable) if timeout is not None: deadline = _timer() + timeout while alive: if timeout is not None and timeout <= 0: break for proc in alive: # Make sure that every complete iteration (all processes) # will last max 1 sec. # We do this because we don't want to wait too long on a # single process: in case it terminates too late other # processes may disappear in the meantime and their PID # reused. max_timeout = 1.0 / len(alive) if timeout is not None: timeout = min((deadline - _timer()), max_timeout) if timeout <= 0: break check_gone(proc, timeout) else: check_gone(proc, max_timeout) alive = alive - gone if alive: # Last attempt over processes survived so far. # timeout == 0 won't make this function wait any further. for proc in alive: check_gone(proc, 0) alive = alive - gone return (list(gone), list(alive)) # ===================================================================== # --- CPU related functions # ===================================================================== @memoize def cpu_count(logical=True): """Return the number of logical CPUs in the system (same as os.cpu_count() in Python 3.4). If logical is False return the number of physical cores only (hyper thread CPUs are excluded). Return None if undetermined. The return value is cached after first call. 
If desired cache can be cleared like this: >>> psutil.cpu_count.cache_clear() """ if logical: return _psplatform.cpu_count_logical() else: return _psplatform.cpu_count_physical() def cpu_times(percpu=False): """Return system-wide CPU times as a namedtuple. Every CPU time represents the seconds the CPU has spent in the given mode. The namedtuple's fields availability varies depending on the platform: - user - system - idle - nice (UNIX) - iowait (Linux) - irq (Linux, FreeBSD) - softirq (Linux) - steal (Linux >= 2.6.11) - guest (Linux >= 2.6.24) - guest_nice (Linux >= 3.2.0) When percpu is True return a list of nameduples for each CPU. First element of the list refers to first CPU, second element to second CPU and so on. The order of the list is consistent across calls. """ if not percpu: return _psplatform.cpu_times() else: return _psplatform.per_cpu_times() _last_cpu_times = cpu_times() _last_per_cpu_times = cpu_times(percpu=True) def cpu_percent(interval=None, percpu=False): """Return a float representing the current system-wide CPU utilization as a percentage. When interval is > 0.0 compares system CPU times elapsed before and after the interval (blocking). When interval is 0.0 or None compares system CPU times elapsed since last call or module import, returning immediately (non blocking). That means the first time this is called it will return a meaningless 0.0 value which you should ignore. In this case is recommended for accuracy that this function be called with at least 0.1 seconds between calls. When percpu is True returns a list of floats representing the utilization as a percentage for each CPU. First element of the list refers to first CPU, second element to second CPU and so on. The order of the list is consistent across calls. 
Examples: >>> # blocking, system-wide >>> psutil.cpu_percent(interval=1) 2.0 >>> >>> # blocking, per-cpu >>> psutil.cpu_percent(interval=1, percpu=True) [2.0, 1.0] >>> >>> # non-blocking (percentage since last call) >>> psutil.cpu_percent(interval=None) 2.9 >>> """ global _last_cpu_times global _last_per_cpu_times blocking = interval is not None and interval > 0.0 def calculate(t1, t2): t1_all = sum(t1) t1_busy = t1_all - t1.idle t2_all = sum(t2) t2_busy = t2_all - t2.idle # this usually indicates a float precision issue if t2_busy <= t1_busy: return 0.0 busy_delta = t2_busy - t1_busy all_delta = t2_all - t1_all busy_perc = (busy_delta / all_delta) * 100 return round(busy_perc, 1) # system-wide usage if not percpu: if blocking: t1 = cpu_times() time.sleep(interval) else: t1 = _last_cpu_times _last_cpu_times = cpu_times() return calculate(t1, _last_cpu_times) # per-cpu usage else: ret = [] if blocking: tot1 = cpu_times(percpu=True) time.sleep(interval) else: tot1 = _last_per_cpu_times _last_per_cpu_times = cpu_times(percpu=True) for t1, t2 in zip(tot1, _last_per_cpu_times): ret.append(calculate(t1, t2)) return ret # Use separate global vars for cpu_times_percent() so that it's # independent from cpu_percent() and they can both be used within # the same program. _last_cpu_times_2 = _last_cpu_times _last_per_cpu_times_2 = _last_per_cpu_times def cpu_times_percent(interval=None, percpu=False): """Same as cpu_percent() but provides utilization percentages for each specific CPU time as is returned by cpu_times(). For instance, on Linux we'll get: >>> cpu_times_percent() cpupercent(user=4.8, nice=0.0, system=4.8, idle=90.5, iowait=0.0, irq=0.0, softirq=0.0, steal=0.0, guest=0.0, guest_nice=0.0) >>> interval and percpu arguments have the same meaning as in cpu_percent(). 
""" global _last_cpu_times_2 global _last_per_cpu_times_2 blocking = interval is not None and interval > 0.0 def calculate(t1, t2): nums = [] all_delta = sum(t2) - sum(t1) for field in t1._fields: field_delta = getattr(t2, field) - getattr(t1, field) try: field_perc = (100 * field_delta) / all_delta except ZeroDivisionError: field_perc = 0.0 field_perc = round(field_perc, 1) if _WINDOWS: # XXX # Work around: # https://github.com/giampaolo/psutil/issues/392 # CPU times are always supposed to increase over time # or at least remain the same and that's because time # cannot go backwards. # Surprisingly sometimes this might not be the case on # Windows where 'system' CPU time can be smaller # compared to the previous call, resulting in corrupted # percentages (< 0 or > 100). # I really don't know what to do about that except # forcing the value to 0 or 100. if field_perc > 100.0: field_perc = 100.0 elif field_perc < 0.0: field_perc = 0.0 nums.append(field_perc) return _psplatform.scputimes(*nums) # system-wide usage if not percpu: if blocking: t1 = cpu_times() time.sleep(interval) else: t1 = _last_cpu_times_2 _last_cpu_times_2 = cpu_times() return calculate(t1, _last_cpu_times_2) # per-cpu usage else: ret = [] if blocking: tot1 = cpu_times(percpu=True) time.sleep(interval) else: tot1 = _last_per_cpu_times_2 _last_per_cpu_times_2 = cpu_times(percpu=True) for t1, t2 in zip(tot1, _last_per_cpu_times_2): ret.append(calculate(t1, t2)) return ret # ===================================================================== # --- system memory related functions # ===================================================================== def virtual_memory(): """Return statistics about system memory usage as a namedtuple including the following fields, expressed in bytes: - total: total physical memory available. 
- available: the actual amount of available memory that can be given instantly to processes that request more memory in bytes; this is calculated by summing different memory values depending on the platform (e.g. free + buffers + cached on Linux) and it is supposed to be used to monitor actual memory usage in a cross platform fashion. - percent: the percentage usage calculated as (total - available) / total * 100 - used: memory used, calculated differently depending on the platform and designed for informational purposes only: OSX: active + inactive + wired BSD: active + wired + cached LINUX: total - free - free: memory not being used at all (zeroed) that is readily available; note that this doesn't reflect the actual memory available (use 'available' instead) Platform-specific fields: - active (UNIX): memory currently in use or very recently used, and so it is in RAM. - inactive (UNIX): memory that is marked as not used. - buffers (BSD, Linux): cache for things like file system metadata. - cached (BSD, OSX): cache for various things. - wired (OSX, BSD): memory that is marked to always stay in RAM. It is never moved to disk. - shared (BSD): memory that may be simultaneously accessed by multiple processes. The sum of 'used' and 'available' does not necessarily equal total. On Windows 'available' and 'free' are the same. """ global _TOTAL_PHYMEM ret = _psplatform.virtual_memory() # cached for later use in Process.memory_percent() _TOTAL_PHYMEM = ret.total return ret def swap_memory(): """Return system swap memory statistics as a namedtuple including the following fields: - total: total swap memory in bytes - used: used swap memory in bytes - free: free swap memory in bytes - percent: the percentage usage - sin: no. of bytes the system has swapped in from disk (cumulative) - sout: no. of bytes the system has swapped out from disk (cumulative) 'sin' and 'sout' on Windows are meaningless and always set to 0. 
""" return _psplatform.swap_memory() # ===================================================================== # --- disks/paritions related functions # ===================================================================== def disk_usage(path): """Return disk usage statistics about the given path as a namedtuple including total, used and free space expressed in bytes plus the percentage usage. """ return _psplatform.disk_usage(path) def disk_partitions(all=False): """Return mounted partitions as a list of (device, mountpoint, fstype, opts) namedtuple. 'opts' field is a raw string separated by commas indicating mount options which may vary depending on the platform. If "all" parameter is False return physical devices only and ignore all others. """ return _psplatform.disk_partitions(all) def disk_io_counters(perdisk=False): """Return system disk I/O statistics as a namedtuple including the following fields: - read_count: number of reads - write_count: number of writes - read_bytes: number of bytes read - write_bytes: number of bytes written - read_time: time spent reading from disk (in milliseconds) - write_time: time spent writing to disk (in milliseconds) If perdisk is True return the same information for every physical disk installed on the system as a dictionary with partition names as the keys and the namedutuple described above as the values. On recent Windows versions 'diskperf -y' command may need to be executed first otherwise this function won't find any disk. 
""" rawdict = _psplatform.disk_io_counters() if not rawdict: raise RuntimeError("couldn't find any physical disk") if perdisk: for disk, fields in rawdict.items(): rawdict[disk] = _nt_sys_diskio(*fields) return rawdict else: return _nt_sys_diskio(*[sum(x) for x in zip(*rawdict.values())]) # ===================================================================== # --- network related functions # ===================================================================== def net_io_counters(pernic=False): """Return network I/O statistics as a namedtuple including the following fields: - bytes_sent: number of bytes sent - bytes_recv: number of bytes received - packets_sent: number of packets sent - packets_recv: number of packets received - errin: total number of errors while receiving - errout: total number of errors while sending - dropin: total number of incoming packets which were dropped - dropout: total number of outgoing packets which were dropped (always 0 on OSX and BSD) If pernic is True return the same information for every network interface installed on the system as a dictionary with network interface names as the keys and the namedtuple described above as the values. """ rawdict = _psplatform.net_io_counters() if not rawdict: raise RuntimeError("couldn't find any network interface") if pernic: for nic, fields in rawdict.items(): rawdict[nic] = _nt_sys_netio(*fields) return rawdict else: return _nt_sys_netio(*[sum(x) for x in zip(*rawdict.values())]) def net_connections(kind='inet'): """Return system-wide connections as a list of (fd, family, type, laddr, raddr, status, pid) namedtuples. In case of limited privileges 'fd' and 'pid' may be set to -1 and None respectively. 
The 'kind' parameter filters for connections that fit the following criteria: Kind Value Connections using inet IPv4 and IPv6 inet4 IPv4 inet6 IPv6 tcp TCP tcp4 TCP over IPv4 tcp6 TCP over IPv6 udp UDP udp4 UDP over IPv4 udp6 UDP over IPv6 unix UNIX socket (both UDP and TCP protocols) all the sum of all the possible families and protocols """ return _psplatform.net_connections(kind) # ===================================================================== # --- other system related functions # ===================================================================== def boot_time(): """Return the system boot time expressed in seconds since the epoch. This is also available as psutil.BOOT_TIME. """ # Note: we are not caching this because it is subject to # system clock updates. return _psplatform.boot_time() def users(): """Return users currently connected on the system as a list of namedtuples including the following fields. - user: the name of the user - terminal: the tty or pseudo-tty associated with the user, if any. - host: the host name associated with the entry, if any. - started: the creation time as a floating point number expressed in seconds since the epoch. """ return _psplatform.users() # ===================================================================== # --- deprecated functions # ===================================================================== @_deprecated(replacement="psutil.pids()") def get_pid_list(): return pids() @_deprecated(replacement="list(process_iter())") def get_process_list(): return list(process_iter()) @_deprecated(replacement="psutil.users()") def get_users(): return users() @_deprecated(replacement="psutil.virtual_memory()") def phymem_usage(): """Return the amount of total, used and free physical memory on the system in bytes plus the percentage usage. Deprecated; use psutil.virtual_memory() instead. 
""" return virtual_memory() @_deprecated(replacement="psutil.swap_memory()") def virtmem_usage(): return swap_memory() @_deprecated(replacement="psutil.phymem_usage().free") def avail_phymem(): return phymem_usage().free @_deprecated(replacement="psutil.phymem_usage().used") def used_phymem(): return phymem_usage().used @_deprecated(replacement="psutil.virtmem_usage().total") def total_virtmem(): return virtmem_usage().total @_deprecated(replacement="psutil.virtmem_usage().used") def used_virtmem(): return virtmem_usage().used @_deprecated(replacement="psutil.virtmem_usage().free") def avail_virtmem(): return virtmem_usage().free @_deprecated(replacement="psutil.net_io_counters()") def network_io_counters(pernic=False): return net_io_counters(pernic) def test(): """List info of all currently running processes emulating ps aux output. """ import datetime today_day = datetime.date.today() templ = "%-10s %5s %4s %4s %7s %7s %-13s %5s %7s %s" attrs = ['pid', 'cpu_percent', 'memory_percent', 'name', 'cpu_times', 'create_time', 'memory_info'] if _POSIX: attrs.append('uids') attrs.append('terminal') print(templ % ("USER", "PID", "%CPU", "%MEM", "VSZ", "RSS", "TTY", "START", "TIME", "COMMAND")) for p in process_iter(): try: pinfo = p.as_dict(attrs, ad_value='') except NoSuchProcess: pass else: if pinfo['create_time']: ctime = datetime.datetime.fromtimestamp(pinfo['create_time']) if ctime.date() == today_day: ctime = ctime.strftime("%H:%M") else: ctime = ctime.strftime("%b%d") else: ctime = '' cputime = time.strftime("%M:%S", time.localtime(sum(pinfo['cpu_times']))) try: user = p.username() except KeyError: if _POSIX: if pinfo['uids']: user = str(pinfo['uids'].real) else: user = '' else: raise except Error: user = '' if _WINDOWS and '\\' in user: user = user.split('\\')[1] vms = pinfo['memory_info'] and \ int(pinfo['memory_info'].vms / 1024) or '?' rss = pinfo['memory_info'] and \ int(pinfo['memory_info'].rss / 1024) or '?' 
memp = pinfo['memory_percent'] and \ round(pinfo['memory_percent'], 1) or '?' print(templ % ( user[:10], pinfo['pid'], pinfo['cpu_percent'], memp, vms, rss, pinfo.get('terminal', '') or '?', ctime, cputime, pinfo['name'].strip() or '?')) def _replace_module(): """Dirty hack to replace the module object in order to access deprecated module constants, see: http://www.dr-josiah.com/2013/12/properties-on-python-modules.html """ class ModuleWrapper(object): def __repr__(self): return repr(self._module) __str__ = __repr__ @property def NUM_CPUS(self): msg = "NUM_CPUS constant is deprecated; use cpu_count() instead" warnings.warn(msg, category=DeprecationWarning, stacklevel=2) return cpu_count() @property def BOOT_TIME(self): msg = "BOOT_TIME constant is deprecated; use boot_time() instead" warnings.warn(msg, category=DeprecationWarning, stacklevel=2) return boot_time() @property def TOTAL_PHYMEM(self): msg = "TOTAL_PHYMEM constant is deprecated; " \ "use virtual_memory().total instead" warnings.warn(msg, category=DeprecationWarning, stacklevel=2) return virtual_memory().total mod = ModuleWrapper() mod.__dict__ = globals() mod._module = sys.modules[__name__] sys.modules[__name__] = mod _replace_module() del memoize, division, _replace_module if sys.version_info < (3, 0): del num if __name__ == "__main__": test() ================================================ FILE: Common/libpsutil/py2.7-glibc-2.12+/psutil/_common.py ================================================ # /usr/bin/env python # Copyright (c) 2009, Giampaolo Rodola'. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. 
"""Common objects shared by all _ps* modules.""" from __future__ import division import errno import functools import os import socket import stat import warnings try: import threading except ImportError: import dummy_threading as threading from collections import namedtuple from socket import AF_INET, SOCK_STREAM, SOCK_DGRAM # --- constants AF_INET6 = getattr(socket, 'AF_INET6', None) AF_UNIX = getattr(socket, 'AF_UNIX', None) STATUS_RUNNING = "running" STATUS_SLEEPING = "sleeping" STATUS_DISK_SLEEP = "disk-sleep" STATUS_STOPPED = "stopped" STATUS_TRACING_STOP = "tracing-stop" STATUS_ZOMBIE = "zombie" STATUS_DEAD = "dead" STATUS_WAKE_KILL = "wake-kill" STATUS_WAKING = "waking" STATUS_IDLE = "idle" # BSD STATUS_LOCKED = "locked" # BSD STATUS_WAITING = "waiting" # BSD CONN_ESTABLISHED = "ESTABLISHED" CONN_SYN_SENT = "SYN_SENT" CONN_SYN_RECV = "SYN_RECV" CONN_FIN_WAIT1 = "FIN_WAIT1" CONN_FIN_WAIT2 = "FIN_WAIT2" CONN_TIME_WAIT = "TIME_WAIT" CONN_CLOSE = "CLOSE" CONN_CLOSE_WAIT = "CLOSE_WAIT" CONN_LAST_ACK = "LAST_ACK" CONN_LISTEN = "LISTEN" CONN_CLOSING = "CLOSING" CONN_NONE = "NONE" # --- functions def usage_percent(used, total, _round=None): """Calculate percentage usage of 'used' against 'total'.""" try: ret = (used / total) * 100 except ZeroDivisionError: ret = 0 if _round is not None: return round(ret, _round) else: return ret def memoize(fun): """A simple memoize decorator for functions supporting (hashable) positional arguments. It also provides a cache_clear() function for clearing the cache: >>> @memoize ... def foo() ... return 1 ... 
>>> foo() 1 >>> foo.cache_clear() >>> """ @functools.wraps(fun) def wrapper(*args, **kwargs): key = (args, frozenset(sorted(kwargs.items()))) lock.acquire() try: try: return cache[key] except KeyError: ret = cache[key] = fun(*args, **kwargs) finally: lock.release() return ret def cache_clear(): """Clear cache.""" lock.acquire() try: cache.clear() finally: lock.release() lock = threading.RLock() cache = {} wrapper.cache_clear = cache_clear return wrapper # http://code.activestate.com/recipes/577819-deprecated-decorator/ def deprecated(replacement=None): """A decorator which can be used to mark functions as deprecated.""" def outer(fun): msg = "psutil.%s is deprecated" % fun.__name__ if replacement is not None: msg += "; use %s instead" % replacement if fun.__doc__ is None: fun.__doc__ = msg @functools.wraps(fun) def inner(*args, **kwargs): warnings.warn(msg, category=DeprecationWarning, stacklevel=2) return fun(*args, **kwargs) return inner return outer def deprecated_method(replacement): """A decorator which can be used to mark a method as deprecated 'replcement' is the method name which will be called instead. 
""" def outer(fun): msg = "%s() is deprecated; use %s() instead" % ( fun.__name__, replacement) if fun.__doc__ is None: fun.__doc__ = msg @functools.wraps(fun) def inner(self, *args, **kwargs): warnings.warn(msg, category=DeprecationWarning, stacklevel=2) return getattr(self, replacement)(*args, **kwargs) return inner return outer def isfile_strict(path): """Same as os.path.isfile() but does not swallow EACCES / EPERM exceptions, see: http://mail.python.org/pipermail/python-dev/2012-June/120787.html """ try: st = os.stat(path) except OSError as err: if err.errno in (errno.EPERM, errno.EACCES): raise return False else: return stat.S_ISREG(st.st_mode) # --- Process.connections() 'kind' parameter mapping conn_tmap = { "all": ([AF_INET, AF_INET6, AF_UNIX], [SOCK_STREAM, SOCK_DGRAM]), "tcp": ([AF_INET, AF_INET6], [SOCK_STREAM]), "tcp4": ([AF_INET], [SOCK_STREAM]), "udp": ([AF_INET, AF_INET6], [SOCK_DGRAM]), "udp4": ([AF_INET], [SOCK_DGRAM]), "inet": ([AF_INET, AF_INET6], [SOCK_STREAM, SOCK_DGRAM]), "inet4": ([AF_INET], [SOCK_STREAM, SOCK_DGRAM]), "inet6": ([AF_INET6], [SOCK_STREAM, SOCK_DGRAM]), } if AF_INET6 is not None: conn_tmap.update({ "tcp6": ([AF_INET6], [SOCK_STREAM]), "udp6": ([AF_INET6], [SOCK_DGRAM]), }) if AF_UNIX is not None: conn_tmap.update({ "unix": ([AF_UNIX], [SOCK_STREAM, SOCK_DGRAM]), }) del AF_INET, AF_INET6, AF_UNIX, SOCK_STREAM, SOCK_DGRAM, socket # --- namedtuples for psutil.* system-related functions # psutil.swap_memory() sswap = namedtuple('sswap', ['total', 'used', 'free', 'percent', 'sin', 'sout']) # psutil.disk_usage() sdiskusage = namedtuple('sdiskusage', ['total', 'used', 'free', 'percent']) # psutil.disk_io_counters() sdiskio = namedtuple('sdiskio', ['read_count', 'write_count', 'read_bytes', 'write_bytes', 'read_time', 'write_time']) # psutil.disk_partitions() sdiskpart = namedtuple('sdiskpart', ['device', 'mountpoint', 'fstype', 'opts']) # psutil.net_io_counters() snetio = namedtuple('snetio', ['bytes_sent', 'bytes_recv', 
'packets_sent', 'packets_recv', 'errin', 'errout', 'dropin', 'dropout']) # psutil.users() suser = namedtuple('suser', ['name', 'terminal', 'host', 'started']) # psutil.net_connections() sconn = namedtuple('sconn', ['fd', 'family', 'type', 'laddr', 'raddr', 'status', 'pid']) # --- namedtuples for psutil.Process methods # psutil.Process.memory_info() pmem = namedtuple('pmem', ['rss', 'vms']) # psutil.Process.cpu_times() pcputimes = namedtuple('pcputimes', ['user', 'system']) # psutil.Process.open_files() popenfile = namedtuple('popenfile', ['path', 'fd']) # psutil.Process.threads() pthread = namedtuple('pthread', ['id', 'user_time', 'system_time']) # psutil.Process.uids() puids = namedtuple('puids', ['real', 'effective', 'saved']) # psutil.Process.gids() pgids = namedtuple('pgids', ['real', 'effective', 'saved']) # psutil.Process.io_counters() pio = namedtuple('pio', ['read_count', 'write_count', 'read_bytes', 'write_bytes']) # psutil.Process.ionice() pionice = namedtuple('pionice', ['ioclass', 'value']) # psutil.Process.ctx_switches() pctxsw = namedtuple('pctxsw', ['voluntary', 'involuntary']) # --- misc # backward compatibility layer for Process.connections() ntuple class pconn( namedtuple('pconn', ['fd', 'family', 'type', 'laddr', 'raddr', 'status'])): __slots__ = () @property def local_address(self): warnings.warn("'local_address' field is deprecated; use 'laddr'" "instead", category=DeprecationWarning, stacklevel=2) return self.laddr @property def remote_address(self): warnings.warn("'remote_address' field is deprecated; use 'raddr'" "instead", category=DeprecationWarning, stacklevel=2) return self.raddr ================================================ FILE: Common/libpsutil/py2.7-glibc-2.12+/psutil/_compat.py ================================================ #!/usr/bin/env python # Copyright (c) 2009, Giampaolo Rodola'. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. 
"""Module which provides compatibility with older Python versions.""" __all__ = ["PY3", "int", "long", "xrange", "exec_", "callable", "lru_cache"] import collections import functools import sys try: import __builtin__ except ImportError: import builtins as __builtin__ # py3 PY3 = sys.version_info[0] == 3 if PY3: int = int long = int xrange = range unicode = str basestring = str exec_ = getattr(__builtin__, "exec") else: int = int long = long xrange = xrange unicode = unicode basestring = basestring def exec_(code, globs=None, locs=None): if globs is None: frame = sys._getframe(1) globs = frame.f_globals if locs is None: locs = frame.f_locals del frame elif locs is None: locs = globs exec("""exec code in globs, locs""") # removed in 3.0, reintroduced in 3.2 try: callable = callable except NameError: def callable(obj): return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) # --- stdlib additions # py 3.2 functools.lru_cache # Taken from: http://code.activestate.com/recipes/578078 # Credit: Raymond Hettinger try: from functools import lru_cache except ImportError: try: from threading import RLock except ImportError: from dummy_threading import RLock _CacheInfo = collections.namedtuple( "CacheInfo", ["hits", "misses", "maxsize", "currsize"]) class _HashedSeq(list): __slots__ = 'hashvalue' def __init__(self, tup, hash=hash): self[:] = tup self.hashvalue = hash(tup) def __hash__(self): return self.hashvalue def _make_key(args, kwds, typed, kwd_mark=(object(), ), fasttypes=set((int, str, frozenset, type(None))), sorted=sorted, tuple=tuple, type=type, len=len): key = args if kwds: sorted_items = sorted(kwds.items()) key += kwd_mark for item in sorted_items: key += item if typed: key += tuple(type(v) for v in args) if kwds: key += tuple(type(v) for k, v in sorted_items) elif len(key) == 1 and type(key[0]) in fasttypes: return key[0] return _HashedSeq(key) def lru_cache(maxsize=100, typed=False): """Least-recently-used cache decorator, see: 
http://docs.python.org/3/library/functools.html#functools.lru_cache """ def decorating_function(user_function): cache = dict() stats = [0, 0] HITS, MISSES = 0, 1 make_key = _make_key cache_get = cache.get _len = len lock = RLock() root = [] root[:] = [root, root, None, None] nonlocal_root = [root] PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 if maxsize == 0: def wrapper(*args, **kwds): result = user_function(*args, **kwds) stats[MISSES] += 1 return result elif maxsize is None: def wrapper(*args, **kwds): key = make_key(args, kwds, typed) result = cache_get(key, root) if result is not root: stats[HITS] += 1 return result result = user_function(*args, **kwds) cache[key] = result stats[MISSES] += 1 return result else: def wrapper(*args, **kwds): if kwds or typed: key = make_key(args, kwds, typed) else: key = args lock.acquire() try: link = cache_get(key) if link is not None: root, = nonlocal_root link_prev, link_next, key, result = link link_prev[NEXT] = link_next link_next[PREV] = link_prev last = root[PREV] last[NEXT] = root[PREV] = link link[PREV] = last link[NEXT] = root stats[HITS] += 1 return result finally: lock.release() result = user_function(*args, **kwds) lock.acquire() try: root, = nonlocal_root if key in cache: pass elif _len(cache) >= maxsize: oldroot = root oldroot[KEY] = key oldroot[RESULT] = result root = nonlocal_root[0] = oldroot[NEXT] oldkey = root[KEY] root[KEY] = root[RESULT] = None del cache[oldkey] cache[key] = oldroot else: last = root[PREV] link = [last, root, key, result] last[NEXT] = root[PREV] = cache[key] = link stats[MISSES] += 1 finally: lock.release() return result def cache_info(): """Report cache statistics""" lock.acquire() try: return _CacheInfo(stats[HITS], stats[MISSES], maxsize, len(cache)) finally: lock.release() def cache_clear(): """Clear the cache and cache statistics""" lock.acquire() try: cache.clear() root = nonlocal_root[0] root[:] = [root, root, None, None] stats[:] = [0, 0] finally: lock.release() wrapper.__wrapped__ = 
user_function wrapper.cache_info = cache_info wrapper.cache_clear = cache_clear return functools.update_wrapper(wrapper, user_function) return decorating_function ================================================ FILE: Common/libpsutil/py2.7-glibc-2.12+/psutil/_psbsd.py ================================================ #!/usr/bin/env python # Copyright (c) 2009, Giampaolo Rodola'. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. """FreeBSD platform implementation.""" import errno import functools import os import sys from collections import namedtuple from psutil import _common from psutil import _psposix from psutil._common import conn_tmap, usage_percent import _psutil_bsd as cext import _psutil_posix __extra__all__ = [] # --- constants PROC_STATUSES = { cext.SSTOP: _common.STATUS_STOPPED, cext.SSLEEP: _common.STATUS_SLEEPING, cext.SRUN: _common.STATUS_RUNNING, cext.SIDL: _common.STATUS_IDLE, cext.SWAIT: _common.STATUS_WAITING, cext.SLOCK: _common.STATUS_LOCKED, cext.SZOMB: _common.STATUS_ZOMBIE, } TCP_STATUSES = { cext.TCPS_ESTABLISHED: _common.CONN_ESTABLISHED, cext.TCPS_SYN_SENT: _common.CONN_SYN_SENT, cext.TCPS_SYN_RECEIVED: _common.CONN_SYN_RECV, cext.TCPS_FIN_WAIT_1: _common.CONN_FIN_WAIT1, cext.TCPS_FIN_WAIT_2: _common.CONN_FIN_WAIT2, cext.TCPS_TIME_WAIT: _common.CONN_TIME_WAIT, cext.TCPS_CLOSED: _common.CONN_CLOSE, cext.TCPS_CLOSE_WAIT: _common.CONN_CLOSE_WAIT, cext.TCPS_LAST_ACK: _common.CONN_LAST_ACK, cext.TCPS_LISTEN: _common.CONN_LISTEN, cext.TCPS_CLOSING: _common.CONN_CLOSING, cext.PSUTIL_CONN_NONE: _common.CONN_NONE, } PAGESIZE = os.sysconf("SC_PAGE_SIZE") # extend base mem ntuple with BSD-specific memory metrics svmem = namedtuple( 'svmem', ['total', 'available', 'percent', 'used', 'free', 'active', 'inactive', 'buffers', 'cached', 'shared', 'wired']) scputimes = namedtuple( 'scputimes', ['user', 'nice', 'system', 'idle', 'irq']) pextmem = namedtuple('pextmem', ['rss', 'vms', 
'text', 'data', 'stack']) pmmap_grouped = namedtuple( 'pmmap_grouped', 'path rss, private, ref_count, shadow_count') pmmap_ext = namedtuple( 'pmmap_ext', 'addr, perms path rss, private, ref_count, shadow_count') # set later from __init__.py NoSuchProcess = None AccessDenied = None TimeoutExpired = None def virtual_memory(): """System virtual memory as a namedtuple.""" mem = cext.virtual_mem() total, free, active, inactive, wired, cached, buffers, shared = mem avail = inactive + cached + free used = active + wired + cached percent = usage_percent((total - avail), total, _round=1) return svmem(total, avail, percent, used, free, active, inactive, buffers, cached, shared, wired) def swap_memory(): """System swap memory as (total, used, free, sin, sout) namedtuple.""" total, used, free, sin, sout = [x * PAGESIZE for x in cext.swap_mem()] percent = usage_percent(used, total, _round=1) return _common.sswap(total, used, free, percent, sin, sout) def cpu_times(): """Return system per-CPU times as a namedtuple""" user, nice, system, idle, irq = cext.cpu_times() return scputimes(user, nice, system, idle, irq) if hasattr(cext, "per_cpu_times"): def per_cpu_times(): """Return system CPU times as a namedtuple""" ret = [] for cpu_t in cext.per_cpu_times(): user, nice, system, idle, irq = cpu_t item = scputimes(user, nice, system, idle, irq) ret.append(item) return ret else: # XXX # Ok, this is very dirty. # On FreeBSD < 8 we cannot gather per-cpu information, see: # https://github.com/giampaolo/psutil/issues/226 # If num cpus > 1, on first call we return single cpu times to avoid a # crash at psutil import time. 
# Next calls will fail with NotImplementedError def per_cpu_times(): if cpu_count_logical() == 1: return [cpu_times()] if per_cpu_times.__called__: raise NotImplementedError("supported only starting from FreeBSD 8") per_cpu_times.__called__ = True return [cpu_times()] per_cpu_times.__called__ = False def cpu_count_logical(): """Return the number of logical CPUs in the system.""" return cext.cpu_count_logical() def cpu_count_physical(): """Return the number of physical CPUs in the system.""" # From the C module we'll get an XML string similar to this: # http://manpages.ubuntu.com/manpages/precise/man4/smp.4freebsd.html # We may get None in case "sysctl kern.sched.topology_spec" # is not supported on this BSD version, in which case we'll mimic # os.cpu_count() and return None. s = cext.cpu_count_phys() if s is not None: # get rid of padding chars appended at the end of the string index = s.rfind("") if index != -1: s = s[:index + 9] if sys.version_info >= (2, 5): import xml.etree.ElementTree as ET root = ET.fromstring(s) return len(root.findall('group/children/group/cpu')) or None else: s = s[s.find(''):] return s.count("> if err.errno in (errno.EINVAL, errno.EDEADLK): allcpus = tuple(range(len(per_cpu_times()))) for cpu in cpus: if cpu not in allcpus: raise ValueError("invalid CPU #%i (choose between %s)" % (cpu, allcpus)) raise ================================================ FILE: Common/libpsutil/py2.7-glibc-2.12+/psutil/_pslinux.py ================================================ #!/usr/bin/env python # Copyright (c) 2009, Giampaolo Rodola'. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. 
"""Linux platform implementation.""" from __future__ import division import base64 import errno import functools import os import re import socket import struct import sys import warnings from collections import namedtuple, defaultdict from psutil import _common from psutil import _psposix from psutil._common import (isfile_strict, usage_percent, deprecated) from psutil._compat import PY3 import _psutil_linux as cext import _psutil_posix __extra__all__ = [ # io prio constants "IOPRIO_CLASS_NONE", "IOPRIO_CLASS_RT", "IOPRIO_CLASS_BE", "IOPRIO_CLASS_IDLE", # connection status constants "CONN_ESTABLISHED", "CONN_SYN_SENT", "CONN_SYN_RECV", "CONN_FIN_WAIT1", "CONN_FIN_WAIT2", "CONN_TIME_WAIT", "CONN_CLOSE", "CONN_CLOSE_WAIT", "CONN_LAST_ACK", "CONN_LISTEN", "CONN_CLOSING", # other "phymem_buffers", "cached_phymem"] # --- constants HAS_PRLIMIT = hasattr(cext, "linux_prlimit") # RLIMIT_* constants, not guaranteed to be present on all kernels if HAS_PRLIMIT: for name in dir(cext): if name.startswith('RLIM'): __extra__all__.append(name) # Number of clock ticks per second CLOCK_TICKS = os.sysconf("SC_CLK_TCK") PAGESIZE = os.sysconf("SC_PAGE_SIZE") BOOT_TIME = None # set later DEFAULT_ENCODING = sys.getdefaultencoding() # ioprio_* constants http://linux.die.net/man/2/ioprio_get IOPRIO_CLASS_NONE = 0 IOPRIO_CLASS_RT = 1 IOPRIO_CLASS_BE = 2 IOPRIO_CLASS_IDLE = 3 # taken from /fs/proc/array.c PROC_STATUSES = { "R": _common.STATUS_RUNNING, "S": _common.STATUS_SLEEPING, "D": _common.STATUS_DISK_SLEEP, "T": _common.STATUS_STOPPED, "t": _common.STATUS_TRACING_STOP, "Z": _common.STATUS_ZOMBIE, "X": _common.STATUS_DEAD, "x": _common.STATUS_DEAD, "K": _common.STATUS_WAKE_KILL, "W": _common.STATUS_WAKING } # http://students.mimuw.edu.pl/lxr/source/include/net/tcp_states.h TCP_STATUSES = { "01": _common.CONN_ESTABLISHED, "02": _common.CONN_SYN_SENT, "03": _common.CONN_SYN_RECV, "04": _common.CONN_FIN_WAIT1, "05": _common.CONN_FIN_WAIT2, "06": _common.CONN_TIME_WAIT, "07": 
_common.CONN_CLOSE, "08": _common.CONN_CLOSE_WAIT, "09": _common.CONN_LAST_ACK, "0A": _common.CONN_LISTEN, "0B": _common.CONN_CLOSING } # set later from __init__.py NoSuchProcess = None AccessDenied = None TimeoutExpired = None # --- named tuples def _get_cputimes_fields(): """Return a namedtuple of variable fields depending on the CPU times available on this Linux kernel version which may be: (user, nice, system, idle, iowait, irq, softirq, [steal, [guest, [guest_nice]]]) """ with open('/proc/stat', 'rb') as f: values = f.readline().split()[1:] fields = ['user', 'nice', 'system', 'idle', 'iowait', 'irq', 'softirq'] vlen = len(values) if vlen >= 8: # Linux >= 2.6.11 fields.append('steal') if vlen >= 9: # Linux >= 2.6.24 fields.append('guest') if vlen >= 10: # Linux >= 3.2.0 fields.append('guest_nice') return fields scputimes = namedtuple('scputimes', _get_cputimes_fields()) svmem = namedtuple( 'svmem', ['total', 'available', 'percent', 'used', 'free', 'active', 'inactive', 'buffers', 'cached']) pextmem = namedtuple('pextmem', 'rss vms shared text lib data dirty') pmmap_grouped = namedtuple( 'pmmap_grouped', ['path', 'rss', 'size', 'pss', 'shared_clean', 'shared_dirty', 'private_clean', 'private_dirty', 'referenced', 'anonymous', 'swap']) pmmap_ext = namedtuple( 'pmmap_ext', 'addr perms ' + ' '.join(pmmap_grouped._fields)) # --- system memory def virtual_memory(): total, free, buffers, shared, _, _ = cext.linux_sysinfo() cached = active = inactive = None with open('/proc/meminfo', 'rb') as f: for line in f: if line.startswith(b"Cached:"): cached = int(line.split()[1]) * 1024 elif line.startswith(b"Active:"): active = int(line.split()[1]) * 1024 elif line.startswith(b"Inactive:"): inactive = int(line.split()[1]) * 1024 if (cached is not None and active is not None and inactive is not None): break else: # we might get here when dealing with exotic Linux flavors, see: # https://github.com/giampaolo/psutil/issues/313 msg = "'cached', 'active' and 'inactive' memory stats 
couldn't " \ "be determined and were set to 0" warnings.warn(msg, RuntimeWarning) cached = active = inactive = 0 avail = free + buffers + cached used = total - free percent = usage_percent((total - avail), total, _round=1) return svmem(total, avail, percent, used, free, active, inactive, buffers, cached) def swap_memory(): _, _, _, _, total, free = cext.linux_sysinfo() used = total - free percent = usage_percent(used, total, _round=1) # get pgin/pgouts with open("/proc/vmstat", "rb") as f: sin = sout = None for line in f: # values are expressed in 4 kilo bytes, we want bytes instead if line.startswith(b'pswpin'): sin = int(line.split(b' ')[1]) * 4 * 1024 elif line.startswith(b'pswpout'): sout = int(line.split(b' ')[1]) * 4 * 1024 if sin is not None and sout is not None: break else: # we might get here when dealing with exotic Linux flavors, see: # https://github.com/giampaolo/psutil/issues/313 msg = "'sin' and 'sout' swap memory stats couldn't " \ "be determined and were set to 0" warnings.warn(msg, RuntimeWarning) sin = sout = 0 return _common.sswap(total, used, free, percent, sin, sout) @deprecated(replacement='psutil.virtual_memory().cached') def cached_phymem(): return virtual_memory().cached @deprecated(replacement='psutil.virtual_memory().buffers') def phymem_buffers(): return virtual_memory().buffers # --- CPUs def cpu_times(): """Return a named tuple representing the following system-wide CPU times: (user, nice, system, idle, iowait, irq, softirq [steal, [guest, [guest_nice]]]) Last 3 fields may not be available on all Linux kernel versions. """ with open('/proc/stat', 'rb') as f: values = f.readline().split() fields = values[1:len(scputimes._fields) + 1] fields = [float(x) / CLOCK_TICKS for x in fields] return scputimes(*fields) def per_cpu_times(): """Return a list of namedtuple representing the CPU times for every CPU available on the system. 
""" cpus = [] with open('/proc/stat', 'rb') as f: # get rid of the first line which refers to system wide CPU stats f.readline() for line in f: if line.startswith(b'cpu'): values = line.split() fields = values[1:len(scputimes._fields) + 1] fields = [float(x) / CLOCK_TICKS for x in fields] entry = scputimes(*fields) cpus.append(entry) return cpus def cpu_count_logical(): """Return the number of logical CPUs in the system.""" try: return os.sysconf("SC_NPROCESSORS_ONLN") except ValueError: # as a second fallback we try to parse /proc/cpuinfo num = 0 with open('/proc/cpuinfo', 'rb') as f: for line in f: if line.lower().startswith(b'processor'): num += 1 # unknown format (e.g. amrel/sparc architectures), see: # https://github.com/giampaolo/psutil/issues/200 # try to parse /proc/stat as a last resort if num == 0: search = re.compile('cpu\d') with open('/proc/stat', 'rt') as f: for line in f: line = line.split(' ')[0] if search.match(line): num += 1 if num == 0: # mimic os.cpu_count() return None return num def cpu_count_physical(): """Return the number of physical CPUs in the system.""" with open('/proc/cpuinfo', 'rb') as f: found = set() for line in f: if line.lower().startswith(b'physical id'): found.add(line.strip()) # mimic os.cpu_count() return len(found) if found else None # --- other system functions def users(): """Return currently connected users as a list of namedtuples.""" retlist = [] rawlist = cext.users() for item in rawlist: user, tty, hostname, tstamp, user_process = item # note: the underlying C function includes entries about # system boot, run level and others. We might want # to use them in the future. 
if not user_process: continue if hostname == ':0.0': hostname = 'localhost' nt = _common.suser(user, tty or None, hostname, tstamp) retlist.append(nt) return retlist def boot_time(): """Return the system boot time expressed in seconds since the epoch.""" global BOOT_TIME with open('/proc/stat', 'rb') as f: for line in f: if line.startswith(b'btime'): ret = float(line.strip().split()[1]) BOOT_TIME = ret return ret raise RuntimeError("line 'btime' not found") # --- processes def pids(): """Returns a list of PIDs currently running on the system.""" return [int(x) for x in os.listdir(b'/proc') if x.isdigit()] def pid_exists(pid): """Check For the existence of a unix pid.""" return _psposix.pid_exists(pid) # --- network class Connections: """A wrapper on top of /proc/net/* files, retrieving per-process and system-wide open connections (TCP, UDP, UNIX) similarly to "netstat -an". Note: in case of UNIX sockets we're only able to determine the local endpoint/path, not the one it's connected to. According to [1] it would be possible but not easily. 
[1] http://serverfault.com/a/417946 """ def __init__(self): tcp4 = ("tcp", socket.AF_INET, socket.SOCK_STREAM) tcp6 = ("tcp6", socket.AF_INET6, socket.SOCK_STREAM) udp4 = ("udp", socket.AF_INET, socket.SOCK_DGRAM) udp6 = ("udp6", socket.AF_INET6, socket.SOCK_DGRAM) unix = ("unix", socket.AF_UNIX, None) self.tmap = { "all": (tcp4, tcp6, udp4, udp6, unix), "tcp": (tcp4, tcp6), "tcp4": (tcp4,), "tcp6": (tcp6,), "udp": (udp4, udp6), "udp4": (udp4,), "udp6": (udp6,), "unix": (unix,), "inet": (tcp4, tcp6, udp4, udp6), "inet4": (tcp4, udp4), "inet6": (tcp6, udp6), } def get_proc_inodes(self, pid): inodes = defaultdict(list) for fd in os.listdir("/proc/%s/fd" % pid): try: inode = os.readlink("/proc/%s/fd/%s" % (pid, fd)) except OSError: # TODO: need comment here continue else: if inode.startswith('socket:['): # the process is using a socket inode = inode[8:][:-1] inodes[inode].append((pid, int(fd))) return inodes def get_all_inodes(self): inodes = {} for pid in pids(): try: inodes.update(self.get_proc_inodes(pid)) except OSError as err: # os.listdir() is gonna raise a lot of access denied # exceptions in case of unprivileged user; that's fine # as we'll just end up returning a connection with PID # and fd set to None anyway. # Both netstat -an and lsof does the same so it's # unlikely we can do any better. # ENOENT just means a PID disappeared on us. if err.errno not in ( errno.ENOENT, errno.ESRCH, errno.EPERM, errno.EACCES): raise return inodes def decode_address(self, addr, family): """Accept an "ip:port" address as displayed in /proc/net/* and convert it into a human readable form, like: "0500000A:0016" -> ("10.0.0.5", 22) "0000000000000000FFFF00000100007F:9E49" -> ("::ffff:127.0.0.1", 40521) The IP address portion is a little or big endian four-byte hexadecimal number; that is, the least significant byte is listed first, so we need to reverse the order of the bytes to convert it to an IP address. The port is represented as a two-byte hexadecimal number. 
        Reference:
        http://linuxdevcenter.com/pub/a/linux/2000/11/16/LinuxAdmin.html
        """
        ip, port = addr.split(':')
        port = int(port, 16)
        # this usually refers to a local socket in listen mode with
        # no end-points connected
        if not port:
            return ()
        if PY3:
            # base64.b16decode() wants bytes on Python 3
            ip = ip.encode('ascii')
        if family == socket.AF_INET:
            # see: https://github.com/giampaolo/psutil/issues/201
            if sys.byteorder == 'little':
                ip = socket.inet_ntop(family, base64.b16decode(ip)[::-1])
            else:
                ip = socket.inet_ntop(family, base64.b16decode(ip))
        else:  # IPv6
            # old version - let's keep it, just in case...
            # ip = ip.decode('hex')
            # return socket.inet_ntop(socket.AF_INET6,
            #     ''.join(ip[i:i+4][::-1] for i in xrange(0, 16, 4)))
            ip = base64.b16decode(ip)
            # see: https://github.com/giampaolo/psutil/issues/201
            if sys.byteorder == 'little':
                # swap the byte order of each 32-bit word
                ip = socket.inet_ntop(
                    socket.AF_INET6,
                    struct.pack('>4I', *struct.unpack('<4I', ip)))
            else:
                ip = socket.inet_ntop(
                    socket.AF_INET6,
                    struct.pack('<4I', *struct.unpack('<4I', ip)))
        return (ip, port)

    def process_inet(self, file, family, type_, inodes, filter_pid=None):
        """Parse /proc/net/tcp* and /proc/net/udp* files."""
        if file.endswith('6') and not os.path.exists(file):
            # IPv6 not supported
            return
        with open(file, 'rt') as f:
            f.readline()  # skip the first line
            for line in f:
                _, laddr, raddr, status, _, _, _, _, _, inode = \
                    line.split()[:10]
                if inode in inodes:
                    # We assume inet sockets are unique, so we error
                    # out if there are multiple references to the
                    # same inode. We won't do this for UNIX sockets.
                    if len(inodes[inode]) > 1 and family != socket.AF_UNIX:
                        raise ValueError("ambiguos inode with multiple "
                                         "PIDs references")
                    pid, fd = inodes[inode][0]
                else:
                    pid, fd = None, -1
                if filter_pid is not None and filter_pid != pid:
                    continue
                else:
                    if type_ == socket.SOCK_STREAM:
                        status = TCP_STATUSES[status]
                    else:
                        # UDP has no connection state
                        status = _common.CONN_NONE
                    laddr = self.decode_address(laddr, family)
                    raddr = self.decode_address(raddr, family)
                    yield (fd, family, type_, laddr, raddr, status, pid)

    def process_unix(self, file, family, inodes, filter_pid=None):
        """Parse /proc/net/unix files."""
        with open(file, 'rt') as f:
            f.readline()  # skip the first line
            for line in f:
                tokens = line.split()
                _, _, _, _, type_, _, inode = tokens[0:7]
                if inode in inodes:
                    # With UNIX sockets we can have a single inode
                    # referencing many file descriptors.
                    pairs = inodes[inode]
                else:
                    pairs = [(None, -1)]
                for pid, fd in pairs:
                    if filter_pid is not None and filter_pid != pid:
                        continue
                    else:
                        # the 8th token (the bound path) is optional
                        if len(tokens) == 8:
                            path = tokens[-1]
                        else:
                            path = ""
                        type_ = int(type_)
                        raddr = None
                        status = _common.CONN_NONE
                        yield (fd, family, type_, path, raddr, status, pid)

    def retrieve(self, kind, pid=None):
        """Return connections of the given *kind*, either system-wide
        (pid=None) or for a single process.
        """
        if kind not in self.tmap:
            raise ValueError("invalid %r kind argument; choose between %s"
                             % (kind, ', '.join([repr(x) for x in self.tmap])))
        if pid is not None:
            inodes = self.get_proc_inodes(pid)
            if not inodes:
                # no connections for this process
                return []
        else:
            inodes = self.get_all_inodes()
        ret = []
        for f, family, type_ in self.tmap[kind]:
            if family in (socket.AF_INET, socket.AF_INET6):
                ls = self.process_inet(
                    "/proc/net/%s" % f, family, type_, inodes, filter_pid=pid)
            else:
                ls = self.process_unix(
                    "/proc/net/%s" % f, family, inodes, filter_pid=pid)
            for fd, family, type_, laddr, raddr, status, bound_pid in ls:
                if pid:
                    # per-process variant: no pid field in the tuple
                    conn = _common.pconn(fd, family, type_, laddr, raddr,
                                         status)
                else:
                    conn = _common.sconn(fd, family, type_, laddr, raddr,
                                         status, bound_pid)
                ret.append(conn)
        return ret


_connections = Connections()


def \
net_connections(kind='inet'):
    """Return system-wide open connections."""
    return _connections.retrieve(kind)


def net_io_counters():
    """Return network I/O statistics for every network interface
    installed on the system as a dict of raw tuples.
    """
    with open("/proc/net/dev", "rt") as f:
        lines = f.readlines()
    retdict = {}
    # first two lines of /proc/net/dev are table headers
    for line in lines[2:]:
        colon = line.rfind(':')
        assert colon > 0, repr(line)
        name = line[:colon].strip()
        fields = line[colon + 1:].strip().split()
        bytes_recv = int(fields[0])
        packets_recv = int(fields[1])
        errin = int(fields[2])
        dropin = int(fields[3])
        bytes_sent = int(fields[8])
        packets_sent = int(fields[9])
        errout = int(fields[10])
        dropout = int(fields[11])
        retdict[name] = (bytes_sent, bytes_recv, packets_sent, packets_recv,
                         errin, errout, dropin, dropout)
    return retdict


# --- disks

def disk_io_counters():
    """Return disk I/O statistics for every disk installed on the
    system as a dict of raw tuples.
    """
    # man iostat states that sectors are equivalent with blocks and
    # have a size of 512 bytes since 2.4 kernels. This value is
    # needed to calculate the amount of disk I/O in bytes.
    SECTOR_SIZE = 512

    # determine partitions we want to look for
    partitions = []
    with open("/proc/partitions", "rt") as f:
        lines = f.readlines()[2:]
    for line in reversed(lines):
        _, _, _, name = line.split()
        if name[-1].isdigit():
            # we're dealing with a partition (e.g. 'sda1'); 'sda' will
            # also be around but we want to omit it
            partitions.append(name)
        else:
            if not partitions or not partitions[-1].startswith(name):
                # we're dealing with a disk entity for which no
                # partitions have been defined (e.g.
                # 'sda' but 'sda1' was not around), see:
                # https://github.com/giampaolo/psutil/issues/338
                partitions.append(name)

    retdict = {}
    with open("/proc/diskstats", "rt") as f:
        lines = f.readlines()
    for line in lines:
        # http://www.mjmwired.net/kernel/Documentation/iostats.txt
        fields = line.split()
        if len(fields) > 7:
            _, _, name, reads, _, rbytes, rtime, writes, _, wbytes, wtime = \
                fields[:11]
        else:
            # from kernel 2.6.0 to 2.6.25
            _, _, name, reads, rbytes, writes, wbytes = fields
            rtime, wtime = 0, 0
        if name in partitions:
            # sector counts -> bytes
            rbytes = int(rbytes) * SECTOR_SIZE
            wbytes = int(wbytes) * SECTOR_SIZE
            reads = int(reads)
            writes = int(writes)
            rtime = int(rtime)
            wtime = int(wtime)
            retdict[name] = (reads, writes, rbytes, wbytes, rtime, wtime)
    return retdict


def disk_partitions(all=False):
    """Return mounted disk partitions as a list of namedtuples"""
    phydevs = []
    # "nodev" lines in /proc/filesystems are virtual filesystems
    with open("/proc/filesystems", "r") as f:
        for line in f:
            if not line.startswith("nodev"):
                phydevs.append(line.strip())

    retlist = []
    partitions = cext.disk_partitions()
    for partition in partitions:
        device, mountpoint, fstype, opts = partition
        if device == 'none':
            device = ''
        if not all:
            if device == '' or fstype not in phydevs:
                continue
        ntuple = _common.sdiskpart(device, mountpoint, fstype, opts)
        retlist.append(ntuple)
    return retlist


disk_usage = _psposix.disk_usage


# --- decorators

def wrap_exceptions(fun):
    """Decorator which translates bare OSError and IOError exceptions
    into NoSuchProcess and AccessDenied.
    """
    @functools.wraps(fun)
    def wrapper(self, *args, **kwargs):
        try:
            return fun(self, *args, **kwargs)
        except EnvironmentError as err:
            # support for private module import
            if NoSuchProcess is None or AccessDenied is None:
                raise
            # ENOENT (no such file or directory) gets raised on open().
            # ESRCH (no such process) can get raised on read() if
            # process is gone in meantime.
            if err.errno in (errno.ENOENT, errno.ESRCH):
                raise NoSuchProcess(self.pid, self._name)
            if err.errno in (errno.EPERM, errno.EACCES):
                raise AccessDenied(self.pid, self._name)
            raise
    return wrapper


class Process(object):
    """Linux process implementation."""

    __slots__ = ["pid", "_name"]

    def __init__(self, pid):
        self.pid = pid
        self._name = None

    @wrap_exceptions
    def name(self):
        fname = "/proc/%s/stat" % self.pid
        kw = dict(encoding=DEFAULT_ENCODING) if PY3 else dict()
        with open(fname, "rt", **kw) as f:
            # XXX - gets changed later and probably needs refactoring
            return f.read().split(' ')[1].replace('(', '').replace(')', '')

    # NOTE: not decorated with wrap_exceptions: exception translation is
    # done by hand below because ENOENT needs special treatment
    def exe(self):
        try:
            exe = os.readlink("/proc/%s/exe" % self.pid)
        except (OSError, IOError) as err:
            if err.errno in (errno.ENOENT, errno.ESRCH):
                # no such file error; might be raised also if the
                # path actually exists for system processes with
                # low pids (about 0-20)
                if os.path.lexists("/proc/%s" % self.pid):
                    return ""
                else:
                    # ok, it is a process which has gone away
                    raise NoSuchProcess(self.pid, self._name)
            if err.errno in (errno.EPERM, errno.EACCES):
                raise AccessDenied(self.pid, self._name)
            raise

        # readlink() might return paths containing null bytes ('\x00').
        # Certain names have ' (deleted)' appended. Usually this is
        # bogus as the file actually exists. Either way that's not
        # important as we don't want to discriminate executables which
        # have been deleted.
        exe = exe.split('\x00')[0]
        if exe.endswith(' (deleted)') and not os.path.exists(exe):
            exe = exe[:-10]
        return exe

    @wrap_exceptions
    def cmdline(self):
        fname = "/proc/%s/cmdline" % self.pid
        kw = dict(encoding=DEFAULT_ENCODING) if PY3 else dict()
        with open(fname, "rt", **kw) as f:
            # arguments are NUL-separated
            return [x for x in f.read().split('\x00') if x]

    @wrap_exceptions
    def terminal(self):
        tmap = _psposix._get_terminal_map()
        with open("/proc/%s/stat" % self.pid, 'rb') as f:
            tty_nr = int(f.read().split(b' ')[6])
        try:
            return tmap[tty_nr]
        except KeyError:
            return None

    # /proc/<pid>/io only exists if the kernel was compiled with
    # task I/O accounting; probe our own pid once at class-creation time
    if os.path.exists('/proc/%s/io' % os.getpid()):
        @wrap_exceptions
        def io_counters(self):
            fname = "/proc/%s/io" % self.pid
            with open(fname, 'rb') as f:
                rcount = wcount = rbytes = wbytes = None
                for line in f:
                    if rcount is None and line.startswith(b"syscr"):
                        rcount = int(line.split()[1])
                    elif wcount is None and line.startswith(b"syscw"):
                        wcount = int(line.split()[1])
                    elif rbytes is None and line.startswith(b"read_bytes"):
                        rbytes = int(line.split()[1])
                    elif wbytes is None and line.startswith(b"write_bytes"):
                        wbytes = int(line.split()[1])
                for x in (rcount, wcount, rbytes, wbytes):
                    if x is None:
                        raise NotImplementedError(
                            "couldn't read all necessary info from %r" % fname)
                return _common.pio(rcount, wcount, rbytes, wbytes)
    else:
        def io_counters(self):
            raise NotImplementedError("couldn't find /proc/%s/io (kernel "
                                      "too old?)" % self.pid)

    @wrap_exceptions
    def cpu_times(self):
        with open("/proc/%s/stat" % self.pid, 'rb') as f:
            st = f.read().strip()
        # ignore the first two values ("pid (exe)")
        st = st[st.find(b')') + 2:]
        values = st.split(b' ')
        utime = float(values[11]) / CLOCK_TICKS
        stime = float(values[12]) / CLOCK_TICKS
        return _common.pcputimes(utime, stime)

    @wrap_exceptions
    def wait(self, timeout=None):
        try:
            return _psposix.wait_pid(self.pid, timeout)
        except _psposix.TimeoutExpired:
            # support for private module import
            if TimeoutExpired is None:
                raise
            raise TimeoutExpired(timeout, self.pid, self._name)

    @wrap_exceptions
    def create_time(self):
with open("/proc/%s/stat" % self.pid, 'rb') as f: st = f.read().strip() # ignore the first two values ("pid (exe)") st = st[st.rfind(b')') + 2:] values = st.split(b' ') # According to documentation, starttime is in field 21 and the # unit is jiffies (clock ticks). # We first divide it for clock ticks and then add uptime returning # seconds since the epoch, in UTC. # Also use cached value if available. bt = BOOT_TIME or boot_time() return (float(values[19]) / CLOCK_TICKS) + bt @wrap_exceptions def memory_info(self): with open("/proc/%s/statm" % self.pid, 'rb') as f: vms, rss = f.readline().split()[:2] return _common.pmem(int(rss) * PAGESIZE, int(vms) * PAGESIZE) @wrap_exceptions def memory_info_ex(self): # ============================================================ # | FIELD | DESCRIPTION | AKA | TOP | # ============================================================ # | rss | resident set size | | RES | # | vms | total program size | size | VIRT | # | shared | shared pages (from shared mappings) | | SHR | # | text | text ('code') | trs | CODE | # | lib | library (unused in Linux 2.6) | lrs | | # | data | data + stack | drs | DATA | # | dirty | dirty pages (unused in Linux 2.6) | dt | | # ============================================================ with open("/proc/%s/statm" % self.pid, "rb") as f: vms, rss, shared, text, lib, data, dirty = \ [int(x) * PAGESIZE for x in f.readline().split()[:7]] return pextmem(rss, vms, shared, text, lib, data, dirty) if os.path.exists('/proc/%s/smaps' % os.getpid()): @wrap_exceptions def memory_maps(self): """Return process's mapped memory regions as a list of nameduples. 
            Fields are explained in 'man proc'; here is an updated (Apr 2012)
            version: http://goo.gl/fmebo
            """
            with open("/proc/%s/smaps" % self.pid, "rt") as f:
                first_line = f.readline()
                current_block = [first_line]

                def get_blocks():
                    # yield (header line, {field: bytes}) pairs, one per
                    # mapped region in the smaps file
                    data = {}
                    for line in f:
                        fields = line.split(None, 5)
                        if not fields[0].endswith(':'):
                            # new block section
                            yield (current_block.pop(), data)
                            current_block.append(line)
                        else:
                            try:
                                # values are reported in kB
                                data[fields[0]] = int(fields[1]) * 1024
                            except ValueError:
                                if fields[0].startswith('VmFlags:'):
                                    # see issue #369
                                    continue
                                else:
                                    raise ValueError("don't know how to inte"
                                                     "rpret line %r" % line)
                    yield (current_block.pop(), data)

                ls = []
                if first_line:  # smaps file can be empty
                    for header, data in get_blocks():
                        hfields = header.split(None, 5)
                        try:
                            addr, perms, offset, dev, inode, path = hfields
                        except ValueError:
                            # path column may be missing (anonymous mapping)
                            addr, perms, offset, dev, inode, path = \
                                hfields + ['']
                        if not path:
                            path = '[anon]'
                        else:
                            path = path.strip()
                        ls.append((
                            addr, perms, path,
                            data['Rss:'],
                            data.get('Size:', 0),
                            data.get('Pss:', 0),
                            data.get('Shared_Clean:', 0),
                            data.get('Shared_Dirty:', 0),
                            data.get('Private_Clean:', 0),
                            data.get('Private_Dirty:', 0),
                            data.get('Referenced:', 0),
                            data.get('Anonymous:', 0),
                            data.get('Swap:', 0)
                        ))
                return ls
    else:
        def memory_maps(self):
            msg = "couldn't find /proc/%s/smaps; kernel < 2.6.14 or " \
                  "CONFIG_MMU kernel configuration option is not enabled" \
                  % self.pid
            raise NotImplementedError(msg)

    @wrap_exceptions
    def cwd(self):
        # readlink() might return paths containing null bytes causing
        # problems when used with other fs-related functions (os.*,
        # open(), ...)
        path = os.readlink("/proc/%s/cwd" % self.pid)
        return path.replace('\x00', '')

    @wrap_exceptions
    def num_ctx_switches(self):
        vol = unvol = None
        with open("/proc/%s/status" % self.pid, "rb") as f:
            for line in f:
                if line.startswith(b"voluntary_ctxt_switches"):
                    vol = int(line.split()[1])
                elif line.startswith(b"nonvoluntary_ctxt_switches"):
                    unvol = int(line.split()[1])
            if vol is not None and unvol is not None:
                return _common.pctxsw(vol, unvol)
            raise NotImplementedError(
                "'voluntary_ctxt_switches' and 'nonvoluntary_ctxt_switches'"
                "fields were not found in /proc/%s/status; the kernel is "
                "probably older than 2.6.23" % self.pid)

    @wrap_exceptions
    def num_threads(self):
        with open("/proc/%s/status" % self.pid, "rb") as f:
            for line in f:
                if line.startswith(b"Threads:"):
                    return int(line.split()[1])
            raise NotImplementedError("line not found")

    @wrap_exceptions
    def threads(self):
        thread_ids = os.listdir("/proc/%s/task" % self.pid)
        thread_ids.sort()
        retlist = []
        hit_enoent = False
        for thread_id in thread_ids:
            fname = "/proc/%s/task/%s/stat" % (self.pid, thread_id)
            try:
                with open(fname, 'rb') as f:
                    st = f.read().strip()
            except EnvironmentError as err:
                if err.errno == errno.ENOENT:
                    # no such file or directory; it means thread
                    # disappeared on us
                    hit_enoent = True
                    continue
                raise
            # ignore the first two values ("pid (exe)")
            st = st[st.find(b')') + 2:]
            values = st.split(b' ')
            utime = float(values[11]) / CLOCK_TICKS
            stime = float(values[12]) / CLOCK_TICKS
            ntuple = _common.pthread(int(thread_id), utime, stime)
            retlist.append(ntuple)
        if hit_enoent:
            # raise NSP if the process disappeared on us
            os.stat('/proc/%s' % self.pid)
        return retlist

    @wrap_exceptions
    def nice_get(self):
        # with open('/proc/%s/stat' % self.pid, 'r') as f:
        #     data = f.read()
        #     return int(data.split()[18])
        # Use C implementation
        return _psutil_posix.getpriority(self.pid)

    @wrap_exceptions
    def nice_set(self, value):
        return _psutil_posix.setpriority(self.pid, value)

    @wrap_exceptions
    def cpu_affinity_get(self):
        return \
            cext.proc_cpu_affinity_get(self.pid)

    @wrap_exceptions
    def cpu_affinity_set(self, cpus):
        try:
            cext.proc_cpu_affinity_set(self.pid, cpus)
        except OSError as err:
            if err.errno == errno.EINVAL:
                # EINVAL may mean an out-of-range CPU number was passed;
                # translate that into a clearer ValueError
                allcpus = tuple(range(len(per_cpu_times())))
                for cpu in cpus:
                    if cpu not in allcpus:
                        raise ValueError("invalid CPU #%i (choose between %s)"
                                         % (cpu, allcpus))
            raise

    # only starting from kernel 2.6.13
    if hasattr(cext, "proc_ioprio_get"):

        @wrap_exceptions
        def ionice_get(self):
            ioclass, value = cext.proc_ioprio_get(self.pid)
            return _common.pionice(ioclass, value)

        @wrap_exceptions
        def ionice_set(self, ioclass, value):
            if ioclass in (IOPRIO_CLASS_NONE, None):
                if value:
                    msg = "can't specify value with IOPRIO_CLASS_NONE"
                    raise ValueError(msg)
                ioclass = IOPRIO_CLASS_NONE
                value = 0
            if ioclass in (IOPRIO_CLASS_RT, IOPRIO_CLASS_BE):
                if value is None:
                    # middle of the 0..7 priority range
                    value = 4
            elif ioclass == IOPRIO_CLASS_IDLE:
                if value:
                    msg = "can't specify value with IOPRIO_CLASS_IDLE"
                    raise ValueError(msg)
                value = 0
            else:
                value = 0
            if not 0 <= value <= 8:
                raise ValueError(
                    "value argument range expected is between 0 and 8")
            return cext.proc_ioprio_set(self.pid, ioclass, value)

    if HAS_PRLIMIT:
        @wrap_exceptions
        def rlimit(self, resource, limits=None):
            # if pid is 0 prlimit() applies to the calling process and
            # we don't want that
            if self.pid == 0:
                raise ValueError("can't use prlimit() against PID 0 process")
            if limits is None:
                # get
                return cext.linux_prlimit(self.pid, resource)
            else:
                # set
                if len(limits) != 2:
                    raise ValueError(
                        "second argument must be a (soft, hard) tuple")
                soft, hard = limits
                cext.linux_prlimit(self.pid, resource, soft, hard)

    @wrap_exceptions
    def status(self):
        with open("/proc/%s/status" % self.pid, 'rb') as f:
            for line in f:
                if line.startswith(b"State:"):
                    letter = line.split()[1]
                    if PY3:
                        letter = letter.decode()
                    # XXX is '?' legit?
                    # (we're not supposed to return it anyway)
                    return PROC_STATUSES.get(letter, '?')

    @wrap_exceptions
    def open_files(self):
        retlist = []
        files = os.listdir("/proc/%s/fd" % self.pid)
        hit_enoent = False
        for fd in files:
            file = "/proc/%s/fd/%s" % (self.pid, fd)
            try:
                file = os.readlink(file)
            except OSError as err:
                # ENOENT == file which is gone in the meantime
                if err.errno in (errno.ENOENT, errno.ESRCH):
                    hit_enoent = True
                    continue
                elif err.errno == errno.EINVAL:
                    # not a link
                    continue
                else:
                    raise
            else:
                # If file is not an absolute path there's no way
                # to tell whether it's a regular file or not,
                # so we skip it. A regular file is always supposed
                # to be absolutized though.
                if file.startswith('/') and isfile_strict(file):
                    ntuple = _common.popenfile(file, int(fd))
                    retlist.append(ntuple)
        if hit_enoent:
            # raise NSP if the process disappeared on us
            os.stat('/proc/%s' % self.pid)
        return retlist

    @wrap_exceptions
    def connections(self, kind='inet'):
        ret = _connections.retrieve(kind, self.pid)
        # raise NSP if the process disappeared on us
        os.stat('/proc/%s' % self.pid)
        return ret

    @wrap_exceptions
    def num_fds(self):
        return len(os.listdir("/proc/%s/fd" % self.pid))

    @wrap_exceptions
    def ppid(self):
        with open("/proc/%s/status" % self.pid, 'rb') as f:
            for line in f:
                if line.startswith(b"PPid:"):
                    # PPid: nnnn
                    return int(line.split()[1])
            raise NotImplementedError("line not found")

    @wrap_exceptions
    def uids(self):
        with open("/proc/%s/status" % self.pid, 'rb') as f:
            for line in f:
                if line.startswith(b'Uid:'):
                    _, real, effective, saved, fs = line.split()
                    return _common.puids(int(real), int(effective), int(saved))
            raise NotImplementedError("line not found")

    @wrap_exceptions
    def gids(self):
        with open("/proc/%s/status" % self.pid, 'rb') as f:
            for line in f:
                if line.startswith(b'Gid:'):
                    _, real, effective, saved, fs = line.split()
                    return _common.pgids(int(real), int(effective), int(saved))
            raise NotImplementedError("line not found")


================================================ FILE:
Common/libpsutil/py2.7-glibc-2.12+/psutil/_psosx.py
================================================

#!/usr/bin/env python

# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""OSX platform implementation."""

import errno
import functools
import os
from collections import namedtuple

from psutil import _common
from psutil import _psposix
from psutil._common import conn_tmap, usage_percent, isfile_strict
import _psutil_osx as cext
import _psutil_posix

__extra__all__ = []


# --- constants

PAGESIZE = os.sysconf("SC_PAGE_SIZE")

# http://students.mimuw.edu.pl/lxr/source/include/net/tcp_states.h
TCP_STATUSES = {
    cext.TCPS_ESTABLISHED: _common.CONN_ESTABLISHED,
    cext.TCPS_SYN_SENT: _common.CONN_SYN_SENT,
    cext.TCPS_SYN_RECEIVED: _common.CONN_SYN_RECV,
    cext.TCPS_FIN_WAIT_1: _common.CONN_FIN_WAIT1,
    cext.TCPS_FIN_WAIT_2: _common.CONN_FIN_WAIT2,
    cext.TCPS_TIME_WAIT: _common.CONN_TIME_WAIT,
    cext.TCPS_CLOSED: _common.CONN_CLOSE,
    cext.TCPS_CLOSE_WAIT: _common.CONN_CLOSE_WAIT,
    cext.TCPS_LAST_ACK: _common.CONN_LAST_ACK,
    cext.TCPS_LISTEN: _common.CONN_LISTEN,
    cext.TCPS_CLOSING: _common.CONN_CLOSING,
    cext.PSUTIL_CONN_NONE: _common.CONN_NONE,
}

PROC_STATUSES = {
    cext.SIDL: _common.STATUS_IDLE,
    cext.SRUN: _common.STATUS_RUNNING,
    cext.SSLEEP: _common.STATUS_SLEEPING,
    cext.SSTOP: _common.STATUS_STOPPED,
    cext.SZOMB: _common.STATUS_ZOMBIE,
}

scputimes = namedtuple('scputimes', ['user', 'nice', 'system', 'idle'])

svmem = namedtuple(
    'svmem', ['total', 'available', 'percent', 'used', 'free',
              'active', 'inactive', 'wired'])

pextmem = namedtuple('pextmem', ['rss', 'vms', 'pfaults', 'pageins'])

pmmap_grouped = namedtuple(
    'pmmap_grouped',
    'path rss private swapped dirtied ref_count shadow_depth')

pmmap_ext = namedtuple(
    'pmmap_ext', 'addr perms ' + ' '.join(pmmap_grouped._fields))

# set later from __init__.py
NoSuchProcess = None
AccessDenied = None
TimeoutExpired = None


# --- functions
def virtual_memory():
    """System virtual memory as a namedtuple."""
    total, active, inactive, wired, free = cext.virtual_mem()
    avail = inactive + free
    used = active + inactive + wired
    percent = usage_percent((total - avail), total, _round=1)
    return svmem(total, avail, percent, used, free,
                 active, inactive, wired)


def swap_memory():
    """Swap system memory as a (total, used, free, sin, sout) tuple."""
    total, used, free, sin, sout = cext.swap_mem()
    percent = usage_percent(used, total, _round=1)
    return _common.sswap(total, used, free, percent, sin, sout)


def cpu_times():
    """Return system CPU times as a namedtuple."""
    user, nice, system, idle = cext.cpu_times()
    return scputimes(user, nice, system, idle)


def per_cpu_times():
    """Return system CPU times as a named tuple"""
    ret = []
    for cpu_t in cext.per_cpu_times():
        user, nice, system, idle = cpu_t
        item = scputimes(user, nice, system, idle)
        ret.append(item)
    return ret


def cpu_count_logical():
    """Return the number of logical CPUs in the system."""
    return cext.cpu_count_logical()


def cpu_count_physical():
    """Return the number of physical CPUs in the system."""
    return cext.cpu_count_phys()


def boot_time():
    """The system boot time expressed in seconds since the epoch."""
    return cext.boot_time()


def disk_partitions(all=False):
    """Return mounted disk partitions as a list of namedtuples."""
    retlist = []
    partitions = cext.disk_partitions()
    for partition in partitions:
        device, mountpoint, fstype, opts = partition
        if device == 'none':
            device = ''
        if not all:
            # without all=True only report real, existing devices
            if not os.path.isabs(device) or not os.path.exists(device):
                continue
        ntuple = _common.sdiskpart(device, mountpoint, fstype, opts)
        retlist.append(ntuple)
    return retlist


def users():
    """Return currently connected users as a list of namedtuples."""
    retlist = []
    rawlist = cext.users()
    for item in rawlist:
        user, tty, hostname, tstamp = item
        if tty == '~':
            continue  # reboot or shutdown
        if not tstamp:
            continue
        nt = _common.suser(user, tty or None, hostname or None, tstamp)
        retlist.append(nt)
    return retlist


def net_connections(kind='inet'):
    # Note: on OSX this will fail with AccessDenied unless
    # the process is
    # owned by root.
    ret = []
    for pid in pids():
        try:
            cons = Process(pid).connections(kind)
        except NoSuchProcess:
            continue
        else:
            if cons:
                for c in cons:
                    # append the owning pid to turn pconn into sconn
                    c = list(c) + [pid]
                    ret.append(_common.sconn(*c))
    return ret


pids = cext.pids
pid_exists = _psposix.pid_exists
disk_usage = _psposix.disk_usage
net_io_counters = cext.net_io_counters
disk_io_counters = cext.disk_io_counters


def wrap_exceptions(fun):
    """Decorator which translates bare OSError exceptions into
    NoSuchProcess and AccessDenied.
    """
    @functools.wraps(fun)
    def wrapper(self, *args, **kwargs):
        try:
            return fun(self, *args, **kwargs)
        except OSError as err:
            # support for private module import
            if NoSuchProcess is None or AccessDenied is None:
                raise
            if err.errno == errno.ESRCH:
                raise NoSuchProcess(self.pid, self._name)
            if err.errno in (errno.EPERM, errno.EACCES):
                raise AccessDenied(self.pid, self._name)
            raise
    return wrapper


class Process(object):
    """Wrapper class around underlying C implementation."""

    __slots__ = ["pid", "_name"]

    def __init__(self, pid):
        self.pid = pid
        self._name = None

    @wrap_exceptions
    def name(self):
        return cext.proc_name(self.pid)

    @wrap_exceptions
    def exe(self):
        return cext.proc_exe(self.pid)

    @wrap_exceptions
    def cmdline(self):
        if not pid_exists(self.pid):
            raise NoSuchProcess(self.pid, self._name)
        return cext.proc_cmdline(self.pid)

    @wrap_exceptions
    def ppid(self):
        return cext.proc_ppid(self.pid)

    @wrap_exceptions
    def cwd(self):
        return cext.proc_cwd(self.pid)

    @wrap_exceptions
    def uids(self):
        real, effective, saved = cext.proc_uids(self.pid)
        return _common.puids(real, effective, saved)

    @wrap_exceptions
    def gids(self):
        real, effective, saved = cext.proc_gids(self.pid)
        return _common.pgids(real, effective, saved)

    @wrap_exceptions
    def terminal(self):
        tty_nr = cext.proc_tty_nr(self.pid)
        tmap = _psposix._get_terminal_map()
        try:
            return tmap[tty_nr]
        except KeyError:
            return None

    @wrap_exceptions
    def memory_info(self):
        rss, vms = cext.proc_memory_info(self.pid)[:2]
        return _common.pmem(rss, vms)

    @wrap_exceptions
    def memory_info_ex(self):
        rss, vms, pfaults, pageins = cext.proc_memory_info(self.pid)
        # pfaults/pageins are reported in pages; convert to bytes
        return pextmem(rss, vms, pfaults * PAGESIZE, pageins * PAGESIZE)

    @wrap_exceptions
    def cpu_times(self):
        user, system = cext.proc_cpu_times(self.pid)
        return _common.pcputimes(user, system)

    @wrap_exceptions
    def create_time(self):
        return cext.proc_create_time(self.pid)

    @wrap_exceptions
    def num_ctx_switches(self):
        return _common.pctxsw(*cext.proc_num_ctx_switches(self.pid))

    @wrap_exceptions
    def num_threads(self):
        return cext.proc_num_threads(self.pid)

    @wrap_exceptions
    def open_files(self):
        if self.pid == 0:
            return []
        files = []
        rawlist = cext.proc_open_files(self.pid)
        for path, fd in rawlist:
            if isfile_strict(path):
                ntuple = _common.popenfile(path, fd)
                files.append(ntuple)
        return files

    @wrap_exceptions
    def connections(self, kind='inet'):
        if kind not in conn_tmap:
            raise ValueError("invalid %r kind argument; choose between %s"
                             % (kind, ', '.join([repr(x) for x in conn_tmap])))
        families, types = conn_tmap[kind]
        rawlist = cext.proc_connections(self.pid, families, types)
        ret = []
        for item in rawlist:
            fd, fam, type, laddr, raddr, status = item
            status = TCP_STATUSES[status]
            nt = _common.pconn(fd, fam, type, laddr, raddr, status)
            ret.append(nt)
        return ret

    @wrap_exceptions
    def num_fds(self):
        if self.pid == 0:
            return 0
        return cext.proc_num_fds(self.pid)

    @wrap_exceptions
    def wait(self, timeout=None):
        try:
            return _psposix.wait_pid(self.pid, timeout)
        except _psposix.TimeoutExpired:
            # support for private module import
            if TimeoutExpired is None:
                raise
            raise TimeoutExpired(timeout, self.pid, self._name)

    @wrap_exceptions
    def nice_get(self):
        return _psutil_posix.getpriority(self.pid)

    @wrap_exceptions
    def nice_set(self, value):
        return _psutil_posix.setpriority(self.pid, value)

    @wrap_exceptions
    def status(self):
        code = cext.proc_status(self.pid)
        # XXX is '?' legit?
        # (we're not supposed to return it anyway)
        return PROC_STATUSES.get(code, '?')

    @wrap_exceptions
    def threads(self):
        rawlist = cext.proc_threads(self.pid)
        retlist = []
        for thread_id, utime, stime in rawlist:
            ntuple = _common.pthread(thread_id, utime, stime)
            retlist.append(ntuple)
        return retlist

    @wrap_exceptions
    def memory_maps(self):
        return cext.proc_memory_maps(self.pid)


================================================
FILE: Common/libpsutil/py2.7-glibc-2.12+/psutil/_psposix.py
================================================

#!/usr/bin/env python

# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Routines common to all posix systems."""

import errno
import glob
import os
import sys
import time

from psutil._common import sdiskusage, usage_percent, memoize
from psutil._compat import PY3, unicode


class TimeoutExpired(Exception):
    pass


def pid_exists(pid):
    """Check whether pid exists in the current process table."""
    if pid == 0:
        # According to "man 2 kill" PID 0 has a special meaning:
        # it refers to every process in the calling process' process
        # group, so we don't want to go any further.
        # If we get here it means this UNIX platform *does* have
        # a process with id 0.
        return True
    try:
        os.kill(pid, 0)
    except OSError as err:
        if err.errno == errno.ESRCH:
            # ESRCH == No such process
            return False
        elif err.errno == errno.EPERM:
            # EPERM clearly means there's a process to deny access to
            return True
        else:
            # According to "man 2 kill" possible error values are
            # (EINVAL, EPERM, ESRCH) therefore we should never get
            # here. If we do let's be explicit in considering this
            # an error.
            raise err
    else:
        return True


def wait_pid(pid, timeout=None):
    """Wait for process with pid 'pid' to terminate and return its
    exit status code as an integer.

    If pid is not a children of os.getpid() (current process) just
    waits until the process disappears and return None.

    If pid does not exist at all return None immediately.
    Raise TimeoutExpired on timeout expired.
    """
    def check_timeout(delay):
        # raise if the deadline passed, otherwise sleep and return the
        # next (exponentially increased, capped) polling delay
        if timeout is not None:
            if timer() >= stop_at:
                raise TimeoutExpired()
        time.sleep(delay)
        return min(delay * 2, 0.04)

    timer = getattr(time, 'monotonic', time.time)
    if timeout is not None:
        waitcall = lambda: os.waitpid(pid, os.WNOHANG)
        stop_at = timer() + timeout
    else:
        waitcall = lambda: os.waitpid(pid, 0)

    delay = 0.0001
    while True:
        try:
            retpid, status = waitcall()
        except OSError as err:
            if err.errno == errno.EINTR:
                delay = check_timeout(delay)
                continue
            elif err.errno == errno.ECHILD:
                # This has two meanings:
                # - pid is not a child of os.getpid() in which case
                #   we keep polling until it's gone
                # - pid never existed in the first place
                # In both cases we'll eventually return None as we
                # can't determine its exit status code.
                while True:
                    if pid_exists(pid):
                        delay = check_timeout(delay)
                    else:
                        return
            else:
                raise
        else:
            if retpid == 0:
                # WNOHANG was used, pid is still running
                delay = check_timeout(delay)
                continue
            # process exited due to a signal; return the integer of
            # that signal
            if os.WIFSIGNALED(status):
                return os.WTERMSIG(status)
            # process exited using exit(2) system call; return the
            # integer exit(2) system call has been called with
            elif os.WIFEXITED(status):
                return os.WEXITSTATUS(status)
            else:
                # should never happen
                raise RuntimeError("unknown process exit status")


def disk_usage(path):
    """Return disk usage associated with path."""
    try:
        st = os.statvfs(path)
    except UnicodeEncodeError:
        if not PY3 and isinstance(path, unicode):
            # this is a bug with os.statvfs() and unicode on
            # Python 2, see:
            # - https://github.com/giampaolo/psutil/issues/416
            # - http://bugs.python.org/issue18695
            try:
                path = path.encode(sys.getfilesystemencoding())
            except UnicodeEncodeError:
                pass
            st = os.statvfs(path)
        else:
            raise
    free = (st.f_bavail * st.f_frsize)
    total = (st.f_blocks * st.f_frsize)
    used = (st.f_blocks - st.f_bfree) * st.f_frsize
    percent = usage_percent(used, total, _round=1)
    # NB: the percentage is -5% than what
    # shown by df due to reserved blocks that we are
    # currently not considering:
    # http://goo.gl/sWGbH
    return sdiskusage(total, used, free, percent)


@memoize
def _get_terminal_map():
    """Return a {st_rdev: device path} map for all tty/pts devices."""
    ret = {}
    ls = glob.glob('/dev/tty*') + glob.glob('/dev/pts/*')
    for name in ls:
        assert name not in ret
        try:
            ret[os.stat(name).st_rdev] = name
        except OSError as err:
            # device may disappear between glob() and stat()
            if err.errno != errno.ENOENT:
                raise
    return ret


================================================
FILE: Common/libpsutil/py2.7-glibc-2.12+/psutil/_pssunos.py
================================================

#!/usr/bin/env python

# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Sun OS Solaris platform implementation."""

import errno
import os
import socket
import subprocess
import sys
from collections import namedtuple

from psutil import _common
from psutil import _psposix
from psutil._common import usage_percent, isfile_strict
from psutil._compat import PY3
import _psutil_posix
import _psutil_sunos as cext

__extra__all__ = ["CONN_IDLE", "CONN_BOUND"]

PAGE_SIZE = os.sysconf('SC_PAGE_SIZE')

CONN_IDLE = "IDLE"
CONN_BOUND = "BOUND"

PROC_STATUSES = {
    cext.SSLEEP: _common.STATUS_SLEEPING,
    cext.SRUN: _common.STATUS_RUNNING,
    cext.SZOMB: _common.STATUS_ZOMBIE,
    cext.SSTOP: _common.STATUS_STOPPED,
    cext.SIDL: _common.STATUS_IDLE,
    cext.SONPROC: _common.STATUS_RUNNING,  # same as run
    cext.SWAIT: _common.STATUS_WAITING,
}

TCP_STATUSES = {
    cext.TCPS_ESTABLISHED: _common.CONN_ESTABLISHED,
    cext.TCPS_SYN_SENT: _common.CONN_SYN_SENT,
    cext.TCPS_SYN_RCVD: _common.CONN_SYN_RECV,
    cext.TCPS_FIN_WAIT_1: _common.CONN_FIN_WAIT1,
    cext.TCPS_FIN_WAIT_2: _common.CONN_FIN_WAIT2,
    cext.TCPS_TIME_WAIT: _common.CONN_TIME_WAIT,
    cext.TCPS_CLOSED: _common.CONN_CLOSE,
    cext.TCPS_CLOSE_WAIT: _common.CONN_CLOSE_WAIT,
    cext.TCPS_LAST_ACK: _common.CONN_LAST_ACK,
    cext.TCPS_LISTEN: _common.CONN_LISTEN,
    cext.TCPS_CLOSING: _common.CONN_CLOSING,
    cext.PSUTIL_CONN_NONE:
def swap_memory():
    """Return swap memory metrics as an sswap namedtuple
    (total, used, free, percent, sin, sout).

    Pages swapped in/out come from the C extension; total/free are
    parsed from the 'swap -l -k' command line utility output.

    Raises RuntimeError if 'swap -l -k' fails or reports no swap
    device.
    """
    sin, sout = cext.swap_mem()
    # XXX
    # we are supposed to get total/free by doing so:
    # http://cvs.opensolaris.org/source/xref/onnv/onnv-gate/
    # usr/src/cmd/swap/swap.c
    # ...nevertheless I can't manage to obtain the same numbers as 'swap'
    # cmdline utility, so let's parse its output (sigh!)
    p = subprocess.Popen(['swap', '-l', '-k'], stdout=subprocess.PIPE)
    stdout, stderr = p.communicate()
    if PY3:
        # BUG FIX: sys.stdout.encoding is None when stdout is not a
        # terminal (e.g. piped), which made decode() raise TypeError;
        # fall back to the filesystem encoding in that case.
        encoding = sys.stdout.encoding or sys.getfilesystemencoding()
        stdout = stdout.decode(encoding)
    if p.returncode != 0:
        raise RuntimeError("'swap -l -k' failed (retcode=%s)"
                           % p.returncode)

    # skip the header line; one line per swap device follows
    lines = stdout.strip().split('\n')[1:]
    if not lines:
        raise RuntimeError('no swap device(s) configured')
    total = free = 0
    for line in lines:
        line = line.split()
        t, f = line[-2:]
        t = t.replace('K', '')
        f = f.replace('K', '')
        total += int(int(t) * 1024)
        free += int(int(f) * 1024)
    used = total - free
    percent = usage_percent(used, total, _round=1)
    return _common.sswap(total, used, free, percent,
                         sin * PAGE_SIZE, sout * PAGE_SIZE)
def net_connections(kind, _pid=-1):
    """Return socket connections.  If pid == -1 return system-wide
    connections (as opposed to connections opened by one process only).
    Only INET sockets are returned (UNIX are not).
    """
    cmap = _common.conn_tmap.copy()
    # system-wide mode cannot report UNIX sockets
    if _pid == -1:
        cmap.pop('unix', 0)
    if kind not in cmap:
        raise ValueError("invalid %r kind argument; choose between %s"
                         % (kind, ', '.join([repr(x) for x in cmap])))
    families, types = _common.conn_tmap[kind]
    rawlist = cext.net_connections(_pid, families, types)
    ret = []
    for fd, fam, type_, laddr, raddr, status, pid in rawlist:
        # skip families/types not requested by the caller
        if fam not in families or type_ not in types:
            continue
        status = TCP_STATUSES[status]
        if _pid == -1:
            ret.append(
                _common.sconn(fd, fam, type_, laddr, raddr, status, pid))
        else:
            ret.append(
                _common.pconn(fd, fam, type_, laddr, raddr, status))
    return ret
""" def wrapper(self, *args, **kwargs): try: return fun(self, *args, **kwargs) except EnvironmentError as err: # support for private module import if NoSuchProcess is None or AccessDenied is None: raise # ENOENT (no such file or directory) gets raised on open(). # ESRCH (no such process) can get raised on read() if # process is gone in meantime. if err.errno in (errno.ENOENT, errno.ESRCH): raise NoSuchProcess(self.pid, self._name) if err.errno in (errno.EPERM, errno.EACCES): raise AccessDenied(self.pid, self._name) raise return wrapper class Process(object): """Wrapper class around underlying C implementation.""" __slots__ = ["pid", "_name"] def __init__(self, pid): self.pid = pid self._name = None @wrap_exceptions def name(self): # note: max len == 15 return cext.proc_name_and_args(self.pid)[0] @wrap_exceptions def exe(self): # Will be guess later from cmdline but we want to explicitly # invoke cmdline here in order to get an AccessDenied # exception if the user has not enough privileges. self.cmdline() return "" @wrap_exceptions def cmdline(self): return cext.proc_name_and_args(self.pid)[1].split(' ') @wrap_exceptions def create_time(self): return cext.proc_basic_info(self.pid)[3] @wrap_exceptions def num_threads(self): return cext.proc_basic_info(self.pid)[5] @wrap_exceptions def nice_get(self): # For some reason getpriority(3) return ESRCH (no such process) # for certain low-pid processes, no matter what (even as root). # The process actually exists though, as it has a name, # creation time, etc. # The best thing we can do here appears to be raising AD. # Note: tested on Solaris 11; on Open Solaris 5 everything is # fine. 
try: return _psutil_posix.getpriority(self.pid) except EnvironmentError as err: if err.errno in (errno.ENOENT, errno.ESRCH): if pid_exists(self.pid): raise AccessDenied(self.pid, self._name) raise @wrap_exceptions def nice_set(self, value): if self.pid in (2, 3): # Special case PIDs: internally setpriority(3) return ESRCH # (no such process), no matter what. # The process actually exists though, as it has a name, # creation time, etc. raise AccessDenied(self.pid, self._name) return _psutil_posix.setpriority(self.pid, value) @wrap_exceptions def ppid(self): return cext.proc_basic_info(self.pid)[0] @wrap_exceptions def uids(self): real, effective, saved, _, _, _ = cext.proc_cred(self.pid) return _common.puids(real, effective, saved) @wrap_exceptions def gids(self): _, _, _, real, effective, saved = cext.proc_cred(self.pid) return _common.puids(real, effective, saved) @wrap_exceptions def cpu_times(self): user, system = cext.proc_cpu_times(self.pid) return _common.pcputimes(user, system) @wrap_exceptions def terminal(self): hit_enoent = False tty = wrap_exceptions( cext.proc_basic_info(self.pid)[0]) if tty != cext.PRNODEV: for x in (0, 1, 2, 255): try: return os.readlink('/proc/%d/path/%d' % (self.pid, x)) except OSError as err: if err.errno == errno.ENOENT: hit_enoent = True continue raise if hit_enoent: # raise NSP if the process disappeared on us os.stat('/proc/%s' % self.pid) @wrap_exceptions def cwd(self): # /proc/PID/path/cwd may not be resolved by readlink() even if # it exists (ls shows it). If that's the case and the process # is still alive return None (we can return None also on BSD). 
# Reference: http://goo.gl/55XgO try: return os.readlink("/proc/%s/path/cwd" % self.pid) except OSError as err: if err.errno == errno.ENOENT: os.stat("/proc/%s" % self.pid) return None raise @wrap_exceptions def memory_info(self): ret = cext.proc_basic_info(self.pid) rss, vms = ret[1] * 1024, ret[2] * 1024 return _common.pmem(rss, vms) # it seems Solaris uses rss and vms only memory_info_ex = memory_info @wrap_exceptions def status(self): code = cext.proc_basic_info(self.pid)[6] # XXX is '?' legit? (we're not supposed to return it anyway) return PROC_STATUSES.get(code, '?') @wrap_exceptions def threads(self): ret = [] tids = os.listdir('/proc/%d/lwp' % self.pid) hit_enoent = False for tid in tids: tid = int(tid) try: utime, stime = cext.query_process_thread( self.pid, tid) except EnvironmentError as err: # ENOENT == thread gone in meantime if err.errno == errno.ENOENT: hit_enoent = True continue raise else: nt = _common.pthread(tid, utime, stime) ret.append(nt) if hit_enoent: # raise NSP if the process disappeared on us os.stat('/proc/%s' % self.pid) return ret @wrap_exceptions def open_files(self): retlist = [] hit_enoent = False pathdir = '/proc/%d/path' % self.pid for fd in os.listdir('/proc/%d/fd' % self.pid): path = os.path.join(pathdir, fd) if os.path.islink(path): try: file = os.readlink(path) except OSError as err: # ENOENT == file which is gone in the meantime if err.errno == errno.ENOENT: hit_enoent = True continue raise else: if isfile_strict(file): retlist.append(_common.popenfile(file, int(fd))) if hit_enoent: # raise NSP if the process disappeared on us os.stat('/proc/%s' % self.pid) return retlist def _get_unix_sockets(self, pid): """Get UNIX sockets used by process by parsing 'pfiles' output.""" # TODO: rewrite this in C (...but the damn netstat source code # does not include this part! Argh!!) 
cmd = "pfiles %s" % pid p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, stderr = p.communicate() if PY3: stdout, stderr = [x.decode(sys.stdout.encoding) for x in (stdout, stderr)] if p.returncode != 0: if 'permission denied' in stderr.lower(): raise AccessDenied(self.pid, self._name) if 'no such process' in stderr.lower(): raise NoSuchProcess(self.pid, self._name) raise RuntimeError("%r command error\n%s" % (cmd, stderr)) lines = stdout.split('\n')[2:] for i, line in enumerate(lines): line = line.lstrip() if line.startswith('sockname: AF_UNIX'): path = line.split(' ', 2)[2] type = lines[i - 2].strip() if type == 'SOCK_STREAM': type = socket.SOCK_STREAM elif type == 'SOCK_DGRAM': type = socket.SOCK_DGRAM else: type = -1 yield (-1, socket.AF_UNIX, type, path, "", _common.CONN_NONE) @wrap_exceptions def connections(self, kind='inet'): ret = net_connections(kind, _pid=self.pid) # The underlying C implementation retrieves all OS connections # and filters them by PID. At this point we can't tell whether # an empty list means there were no connections for process or # process is no longer active so we force NSP in case the PID # is no longer there. 
if not ret: os.stat('/proc/%s' % self.pid) # will raise NSP if process is gone # UNIX sockets if kind in ('all', 'unix'): ret.extend([_common.pconn(*conn) for conn in self._get_unix_sockets(self.pid)]) return ret nt_mmap_grouped = namedtuple('mmap', 'path rss anon locked') nt_mmap_ext = namedtuple('mmap', 'addr perms path rss anon locked') @wrap_exceptions def memory_maps(self): def toaddr(start, end): return '%s-%s' % (hex(start)[2:].strip('L'), hex(end)[2:].strip('L')) retlist = [] rawlist = cext.proc_memory_maps(self.pid) hit_enoent = False for item in rawlist: addr, addrsize, perm, name, rss, anon, locked = item addr = toaddr(addr, addrsize) if not name.startswith('['): try: name = os.readlink('/proc/%s/path/%s' % (self.pid, name)) except OSError as err: if err.errno == errno.ENOENT: # sometimes the link may not be resolved by # readlink() even if it exists (ls shows it). # If that's the case we just return the # unresolved link path. # This seems an incosistency with /proc similar # to: http://goo.gl/55XgO name = '/proc/%s/path/%s' % (self.pid, name) hit_enoent = True else: raise retlist.append((addr, perm, name, rss, anon, locked)) if hit_enoent: # raise NSP if the process disappeared on us os.stat('/proc/%s' % self.pid) return retlist @wrap_exceptions def num_fds(self): return len(os.listdir("/proc/%s/fd" % self.pid)) @wrap_exceptions def num_ctx_switches(self): return _common.pctxsw(*cext.proc_num_ctx_switches(self.pid)) @wrap_exceptions def wait(self, timeout=None): try: return _psposix.wait_pid(self.pid, timeout) except _psposix.TimeoutExpired: # support for private module import if TimeoutExpired is None: raise raise TimeoutExpired(timeout, self.pid, self._name) ================================================ FILE: Common/libpsutil/py2.7-glibc-2.12+/psutil/_pswindows.py ================================================ #!/usr/bin/env python # Copyright (c) 2009, Giampaolo Rodola'. All rights reserved. 
# Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. """Windows platform implementation.""" import errno import functools import os from collections import namedtuple from psutil import _common from psutil._common import conn_tmap, usage_percent, isfile_strict from psutil._compat import PY3, xrange, lru_cache import _psutil_windows as cext # process priority constants, import from __init__.py: # http://msdn.microsoft.com/en-us/library/ms686219(v=vs.85).aspx __extra__all__ = ["ABOVE_NORMAL_PRIORITY_CLASS", "BELOW_NORMAL_PRIORITY_CLASS", "HIGH_PRIORITY_CLASS", "IDLE_PRIORITY_CLASS", "NORMAL_PRIORITY_CLASS", "REALTIME_PRIORITY_CLASS", # "CONN_DELETE_TCB", ] # --- module level constants (gets pushed up to psutil module) CONN_DELETE_TCB = "DELETE_TCB" WAIT_TIMEOUT = 0x00000102 # 258 in decimal ACCESS_DENIED_SET = frozenset([errno.EPERM, errno.EACCES, cext.ERROR_ACCESS_DENIED]) TCP_STATUSES = { cext.MIB_TCP_STATE_ESTAB: _common.CONN_ESTABLISHED, cext.MIB_TCP_STATE_SYN_SENT: _common.CONN_SYN_SENT, cext.MIB_TCP_STATE_SYN_RCVD: _common.CONN_SYN_RECV, cext.MIB_TCP_STATE_FIN_WAIT1: _common.CONN_FIN_WAIT1, cext.MIB_TCP_STATE_FIN_WAIT2: _common.CONN_FIN_WAIT2, cext.MIB_TCP_STATE_TIME_WAIT: _common.CONN_TIME_WAIT, cext.MIB_TCP_STATE_CLOSED: _common.CONN_CLOSE, cext.MIB_TCP_STATE_CLOSE_WAIT: _common.CONN_CLOSE_WAIT, cext.MIB_TCP_STATE_LAST_ACK: _common.CONN_LAST_ACK, cext.MIB_TCP_STATE_LISTEN: _common.CONN_LISTEN, cext.MIB_TCP_STATE_CLOSING: _common.CONN_CLOSING, cext.MIB_TCP_STATE_DELETE_TCB: CONN_DELETE_TCB, cext.PSUTIL_CONN_NONE: _common.CONN_NONE, } scputimes = namedtuple('scputimes', ['user', 'system', 'idle']) svmem = namedtuple('svmem', ['total', 'available', 'percent', 'used', 'free']) pextmem = namedtuple( 'pextmem', ['num_page_faults', 'peak_wset', 'wset', 'peak_paged_pool', 'paged_pool', 'peak_nonpaged_pool', 'nonpaged_pool', 'pagefile', 'peak_pagefile', 'private']) pmmap_grouped = namedtuple('pmmap_grouped', ['path', 
def disk_usage(path):
    """Return disk usage associated with path."""
    try:
        total, free = cext.disk_usage(path)
    except WindowsError:
        # The C extension raises a bare WindowsError for any failure;
        # translate a missing path into a conventional ENOENT OSError.
        if not os.path.exists(path):
            raise OSError(errno.ENOENT,
                          "No such file or directory: '%s'" % path)
        raise
    used = total - free
    return _common.sdiskusage(
        total, used, free, usage_percent(used, total, _round=1))
def net_connections(kind, _pid=-1):
    """Return socket connections.  If pid == -1 return system-wide
    connections (as opposed to connections opened by one process only).
    """
    if kind not in conn_tmap:
        raise ValueError("invalid %r kind argument; choose between %s"
                         % (kind, ', '.join([repr(x) for x in conn_tmap])))
    families, types = conn_tmap[kind]
    rawlist = cext.net_connections(_pid, families, types)
    ret = []
    for item in rawlist:
        # 'type_' (not 'type'): avoid shadowing the builtin and stay
        # consistent with the SunOS implementation in this package.
        fd, fam, type_, laddr, raddr, status, pid = item
        status = TCP_STATUSES[status]
        if _pid == -1:
            nt = _common.sconn(fd, fam, type_, laddr, raddr, status, pid)
        else:
            nt = _common.pconn(fd, fam, type_, laddr, raddr, status)
        ret.append(nt)
    return ret
""" @functools.wraps(fun) def wrapper(self, *args, **kwargs): try: return fun(self, *args, **kwargs) except OSError as err: # support for private module import if NoSuchProcess is None or AccessDenied is None: raise if err.errno in ACCESS_DENIED_SET: raise AccessDenied(self.pid, self._name) if err.errno == errno.ESRCH: raise NoSuchProcess(self.pid, self._name) raise return wrapper class Process(object): """Wrapper class around underlying C implementation.""" __slots__ = ["pid", "_name"] def __init__(self, pid): self.pid = pid self._name = None @wrap_exceptions def name(self): """Return process name, which on Windows is always the final part of the executable. """ # This is how PIDs 0 and 4 are always represented in taskmgr # and process-hacker. if self.pid == 0: return "System Idle Process" elif self.pid == 4: return "System" else: return os.path.basename(self.exe()) @wrap_exceptions def exe(self): # Note: os.path.exists(path) may return False even if the file # is there, see: # http://stackoverflow.com/questions/3112546/os-path-exists-lies # see https://github.com/giampaolo/psutil/issues/414 # see https://github.com/giampaolo/psutil/issues/528 if self.pid in (0, 4): raise AccessDenied(self.pid, self._name) return _convert_raw_path(cext.proc_exe(self.pid)) @wrap_exceptions def cmdline(self): return cext.proc_cmdline(self.pid) def ppid(self): try: return ppid_map()[self.pid] except KeyError: raise NoSuchProcess(self.pid, self._name) def _get_raw_meminfo(self): try: return cext.proc_memory_info(self.pid) except OSError as err: if err.errno in ACCESS_DENIED_SET: return cext.proc_memory_info_2(self.pid) raise @wrap_exceptions def memory_info(self): # on Windows RSS == WorkingSetSize and VSM == PagefileUsage # fields of PROCESS_MEMORY_COUNTERS struct: # http://msdn.microsoft.com/en-us/library/windows/desktop/ # ms684877(v=vs.85).aspx t = self._get_raw_meminfo() return _common.pmem(t[2], t[7]) @wrap_exceptions def memory_info_ex(self): return 
pextmem(*self._get_raw_meminfo()) def memory_maps(self): try: raw = cext.proc_memory_maps(self.pid) except OSError as err: # XXX - can't use wrap_exceptions decorator as we're # returning a generator; probably needs refactoring. if err.errno in ACCESS_DENIED_SET: raise AccessDenied(self.pid, self._name) if err.errno == errno.ESRCH: raise NoSuchProcess(self.pid, self._name) raise else: for addr, perm, path, rss in raw: path = _convert_raw_path(path) addr = hex(addr) yield (addr, perm, path, rss) @wrap_exceptions def kill(self): return cext.proc_kill(self.pid) @wrap_exceptions def wait(self, timeout=None): if timeout is None: timeout = cext.INFINITE else: # WaitForSingleObject() expects time in milliseconds timeout = int(timeout * 1000) ret = cext.proc_wait(self.pid, timeout) if ret == WAIT_TIMEOUT: # support for private module import if TimeoutExpired is None: raise RuntimeError("timeout expired") raise TimeoutExpired(timeout, self.pid, self._name) return ret @wrap_exceptions def username(self): if self.pid in (0, 4): return 'NT AUTHORITY\\SYSTEM' return cext.proc_username(self.pid) @wrap_exceptions def create_time(self): # special case for kernel process PIDs; return system boot time if self.pid in (0, 4): return boot_time() try: return cext.proc_create_time(self.pid) except OSError as err: if err.errno in ACCESS_DENIED_SET: return cext.proc_create_time_2(self.pid) raise @wrap_exceptions def num_threads(self): return cext.proc_num_threads(self.pid) @wrap_exceptions def threads(self): rawlist = cext.proc_threads(self.pid) retlist = [] for thread_id, utime, stime in rawlist: ntuple = _common.pthread(thread_id, utime, stime) retlist.append(ntuple) return retlist @wrap_exceptions def cpu_times(self): try: ret = cext.proc_cpu_times(self.pid) except OSError as err: if err.errno in ACCESS_DENIED_SET: ret = cext.proc_cpu_times_2(self.pid) else: raise return _common.pcputimes(*ret) @wrap_exceptions def suspend(self): return cext.proc_suspend(self.pid) @wrap_exceptions def 
resume(self): return cext.proc_resume(self.pid) @wrap_exceptions def cwd(self): if self.pid in (0, 4): raise AccessDenied(self.pid, self._name) # return a normalized pathname since the native C function appends # "\\" at the and of the path path = cext.proc_cwd(self.pid) return os.path.normpath(path) @wrap_exceptions def open_files(self): if self.pid in (0, 4): return [] retlist = [] # Filenames come in in native format like: # "\Device\HarddiskVolume1\Windows\systemew\file.txt" # Convert the first part in the corresponding drive letter # (e.g. "C:\") by using Windows's QueryDosDevice() raw_file_names = cext.proc_open_files(self.pid) for file in raw_file_names: file = _convert_raw_path(file) if isfile_strict(file) and file not in retlist: ntuple = _common.popenfile(file, -1) retlist.append(ntuple) return retlist @wrap_exceptions def connections(self, kind='inet'): return net_connections(kind, _pid=self.pid) @wrap_exceptions def nice_get(self): return cext.proc_priority_get(self.pid) @wrap_exceptions def nice_set(self, value): return cext.proc_priority_set(self.pid, value) # available on Windows >= Vista if hasattr(cext, "proc_io_priority_get"): @wrap_exceptions def ionice_get(self): return cext.proc_io_priority_get(self.pid) @wrap_exceptions def ionice_set(self, value, _): if _: raise TypeError("set_proc_ionice() on Windows takes only " "1 argument (2 given)") if value not in (2, 1, 0): raise ValueError("value must be 2 (normal), 1 (low) or 0 " "(very low); got %r" % value) return cext.proc_io_priority_set(self.pid, value) @wrap_exceptions def io_counters(self): try: ret = cext.proc_io_counters(self.pid) except OSError as err: if err.errno in ACCESS_DENIED_SET: ret = cext.proc_io_counters_2(self.pid) else: raise return _common.pio(*ret) @wrap_exceptions def status(self): suspended = cext.proc_is_suspended(self.pid) if suspended: return _common.STATUS_STOPPED else: return _common.STATUS_RUNNING @wrap_exceptions def cpu_affinity_get(self): from_bitmask = lambda x: [i 
for i in xrange(64) if (1 << i) & x] bitmask = cext.proc_cpu_affinity_get(self.pid) return from_bitmask(bitmask) @wrap_exceptions def cpu_affinity_set(self, value): def to_bitmask(l): if not l: raise ValueError("invalid argument %r" % l) out = 0 for b in l: out |= 2 ** b return out # SetProcessAffinityMask() states that ERROR_INVALID_PARAMETER # is returned for an invalid CPU but this seems not to be true, # therefore we check CPUs validy beforehand. allcpus = list(range(len(per_cpu_times()))) for cpu in value: if cpu not in allcpus: raise ValueError("invalid CPU %r" % cpu) bitmask = to_bitmask(value) cext.proc_cpu_affinity_set(self.pid, bitmask) @wrap_exceptions def num_handles(self): try: return cext.proc_num_handles(self.pid) except OSError as err: if err.errno in ACCESS_DENIED_SET: return cext.proc_num_handles_2(self.pid) raise @wrap_exceptions def num_ctx_switches(self): tupl = cext.proc_num_ctx_switches(self.pid) return _common.pctxsw(*tupl) ================================================ FILE: Common/waagentloader.py ================================================ # Wrapper module for waagent # # waagent is not written as a module. This wrapper module is created # to use the waagent code as a module. # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import sys
import os


def load_waagent(path=None):
    """Load the waagent script file as a Python module and return it.

    waagent ships as a plain script, not a package, so it is loaded
    straight from its file path.  After loading, the module's logger,
    distro object and configuration are initialised the same way the
    daemon does, so callers can use them immediately.

    :param path: file path of the waagent script; defaults to a file
                 named 'waagent' next to this module.
    :return: the initialised waagent module object.
    """
    if path is None:
        base_dir = os.path.dirname(os.path.abspath(__file__))
        path = os.path.join(base_dir, 'waagent')

    if sys.version_info >= (3, 12):
        # The imp module was removed in Python 3.12; use the importlib
        # machinery to load the script from an explicit file location.
        import importlib.util
        spec = importlib.util.spec_from_file_location('waagent', path)
        agent = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(agent)
    else:
        import imp
        agent = imp.load_source('waagent', path)

    # Initialise logging plus the distro/config globals that the rest of
    # the waagent code expects to find on the module.
    agent.LoggerInit('/var/log/waagent.log', '/dev/stdout')
    agent.MyDistro = agent.GetMyDistro()
    agent.Config = agent.ConfigurationProvider(None)
    return agent
================================================ # CustomScript Extension Allow the owner of the Azure Virtual Machines to run customized scripts in the VM. # :warning: New Version Notice :warning: A new version of **Custom Script Extension** is available at https://github.com/Azure/custom-script-extension-linux. The new `v2.0` version offers better reliability and wider Linux distro support. Please consider switching your new deployments to use the new version (`Microsoft.Azure.Extensions.CustomScript`) instead. The new version is intended to be a drop-in replacement. Therefore, the migration is as easy as changing the name and version, you do not need to change your extension configuration. ----------------------------- This user guide is for `Microsoft.OSTCExtensions.CustomScript` extension. You can read the User Guide below. * [Automate Linux VM Customization Tasks Using CustomScript Extension (outdated, needs to update)](https://azure.microsoft.com/en-us/blog/automate-linux-vm-customization-tasks-using-customscript-extension/) CustomScript Extension can: * If provided, download the customized scripts from Azure Storage or external public storage (e.g. Github) * Run the entrypoint script * Support inline command * Convert Windows style newline in Shell and Python scripts automatically * Remove BOM in Shell and Python scripts automatically * Protect sensitive data in `commandToExecute` **Note:** The timeout for script download is 200 seconds. There is no timeout period for script execution. # User Guide ## 1. Configuration schema ### 1.1. Public configuration Schema for the public configuration file looks like this: * `fileUris`: (optional, string array) the uri list of the scripts * `commandToExecute`: (required, string) the entrypoint script to execute * `enableInternalDNSCheck`: (optional, bool) default is True, set to False to disable DNS check. ```json { "fileUris": [""], "commandToExecute": "" } ``` ### 1.2. 
Protected configuration Schema for the protected configuration file looks like this: * `commandToExecute`: (optional, string) the entrypoint script to execute * `storageAccountName`: (optional, string) the name of storage account * `storageAccountKey`: (optional, string) the access key of storage account ```json { "commandToExecute": "", "storageAccountName": "", "storageAccountKey": "" } ``` **NOTE:** 1. The storage account here is to store the scripts in `fileUris`. If the scripts are stored in the private Azure Storage, you should provide `storageAccountName` and `storageAccountKey`. You can get these two values from Azure Portal. *Currently only general purpose storage accounts are supported. We intend to add support for the new [Azure Cool Blob Storage](https://azure.microsoft.com/en-us/blog/introducing-azure-cool-storage/) in the near future. See #161* 2. `commandToExecute` in protected settings can protect your sensitive data. But `commandToExecute` should not be specified both in public and protected configurations. ## 2. Deploying the Extension to a VM You can deploy it using Azure CLI, Azure Powershell and ARM template. **NOTE:** Creating VM in Azure has two deployment model: Classic and [Resource Manager][arm-overview]. In different models, the deploy commands have different syntaxes. Please select the right one in section 2.1 and 2.2 below. ### 2.1. Using [**Azure CLI**][azure-cli] Before deploying CustomScript Extension, you should configure your `public.json` and `protected.json` (in section 1.1 and 1.2 above). #### 2.1.1 Classic The Classic mode is also called Azure Service Management mode. 
You can change to it by running: ``` $ azure config mode asm ``` You can deploy CustomScript Extension by running: ``` $ azure vm extension set \ CustomScriptForLinux Microsoft.OSTCExtensions \ --public-config-path public.json \ --private-config-path protected.json ``` In the command above, you can change version with `'*'` to use latest version available, or `'1.*'` to get newest version that does not introduce breaking schema changes. To learn the latest version available, run: ``` $ azure vm extension list ``` You can also omit `--private-config-path` if you do not want to configure those settings. #### 2.1.2 Resource Manager You can change to Azure Resource Manager mode by running: ``` $ azure config mode arm ``` You can deploy CustomScript Extension by running: ``` $ azure vm extension set \ CustomScriptForLinux Microsoft.OSTCExtensions \ --public-config-path public.json \ --private-config-path protected.json ``` > **NOTE:** In ARM mode, `azure vm extension list` is not available for now. ### 2.2. 
Using [**Azure Powershell**][azure-powershell] #### 2.2.1 Classic You can login to your Azure account (Azure Service Management mode) by running: ```powershell Add-AzureAccount ``` You can deploy CustomScript Extension by running: ```powershell $VmName = '' $vm = Get-AzureVM -ServiceName $VmName -Name $VmName $ExtensionName = 'CustomScriptForLinux' $Publisher = 'Microsoft.OSTCExtensions' $Version = '' $PublicConf = '{ "fileUris": [""], "commandToExecute": "" }' $PrivateConf = '{ "storageAccountName": "", "storageAccountKey": "" }' Set-AzureVMExtension -ExtensionName $ExtensionName -VM $vm ` -Publisher $Publisher -Version $Version ` -PrivateConfiguration $PrivateConf -PublicConfiguration $PublicConf | Update-AzureVM ``` #### 2.2.2 Resource Manager You can login to your Azure account (Azure Resource Manager mode) by running: ```powershell Login-AzureRmAccount ``` Click [**HERE**](https://azure.microsoft.com/en-us/documentation/articles/powershell-azure-resource-manager/) to learn more about how to use Azure Powershell with Azure Resource Manager. You can deploy CustomScript Extension by running: ```powershell $RGName = '' $VmName = '' $Location = '' $ExtensionName = 'CustomScriptForLinux' $Publisher = 'Microsoft.OSTCExtensions' $Version = '' $PublicConf = '{ "fileUris": [""], "commandToExecute": "" }' $PrivateConf = '{ "storageAccountName": "", "storageAccountKey": "" }' Set-AzureRmVMExtension -ResourceGroupName $RGName -VMName $VmName -Location $Location ` -Name $ExtensionName -Publisher $Publisher ` -ExtensionType $ExtensionName -TypeHandlerVersion $Version ` -Settingstring $PublicConf -ProtectedSettingString $PrivateConf ``` ### 2.3. 
Using [**ARM Template**][arm-template] ```json { "type": "Microsoft.Compute/virtualMachines/extensions", "name": "", "apiVersion": "", "location": "", "dependsOn": [ "[concat('Microsoft.Compute/virtualMachines/', )]" ], "properties": { "publisher": "Microsoft.OSTCExtensions", "type": "CustomScriptForLinux", "typeHandlerVersion": "1.5", "autoUpgradeMinorVersion": true, "settings": { "fileUris": [ "" ], "commandToExecute": "" }, "protectedSettings": { "storageAccountName": "", "storageAccountKey": "" } } } ``` There are two sample templates in [Azure/azure-quickstart-templates](https://github.com/Azure/azure-quickstart-templates). * [201-customscript-extension-public-storage-on-ubuntu](https://github.com/Azure/azure-quickstart-templates/tree/master/201-customscript-extension-public-storage-on-ubuntu) * [201-customscript-extension-azure-storage-on-ubuntu](https://github.com/Azure/azure-quickstart-templates/tree/master/201-customscript-extension-azure-storage-on-ubuntu) For more details about ARM template, please visit [Authoring Azure Resource Manager templates](https://azure.microsoft.com/en-us/documentation/articles/resource-group-authoring-templates/). ## 3. Scenarios ### 3.1 Run scripts stored in Azure Storage * Public configuration ```json { "fileUris": ["http://MyAccount.blob.core.windows.net/vhds/MyShellScript.sh"], "commandToExecute": " sh MyShellScript.sh" } ``` * Protected configuration ```json { "storageAccountName": "MyAccount", "storageAccountKey": "Mykey" } ``` ### 3.2 Run scripts stored in GitHub * Public configuration ```json { "fileUris": ["https://github.com/MyProject/Archive/MyPythonScript.py"], "commandToExecute": "python MyPythonScript.py" } ``` No need to provide protected settings. 
### 3.3 Run inline scripts * Public configuration ```json "commandToExecute": "echo Hello" "commandToExecute": "python -c \"print 1.4\"" ``` ### 3.4 Run scripts with unchanged configurations Running scripts with the exactly same configurations is unaccepted in current design. If you need to run scripts repeatly, you can add a timestamp. * Public configuration ```json { "fileUris": [""], "commandToExecute": "", "timestamp": 123456789 } ``` ### 3.5 Run scripts with sensitive data * Public configuration ```json { "fileUris": ["https://github.com/MyProject/Archive/MyPythonScript.py"] } ``` * Protected configuration ```json { "commandToExecute": "python MyPythonScript.py " } ``` ## Supported Linux Distributions - CentOS 6.5 and higher - Debian 8 and higher - Debian 8.7 does not ship with Python2 in the latest images, which breaks CustomScriptForLinux. - FreeBSD - OpenSUSE 13.1 and higher - Oracle Linux 6.4 and higher - SUSE Linux Enterprise Server 11 SP3 and higher - Ubuntu 12.04 and higher ## Debug * The status of the extension is reported back to Azure so that user can see the status on Azure Portal * All the execution output and error of the scripts are logged into the download directory of the scripts `/var/lib/waagent//download//`, and the tail of the output is logged into the log directory specified in HandlerEnvironment.json and reported back to Azure * The operation log of the extension is `/var/log/azure///extension.log` file. 
[azure-powershell]: https://azure.microsoft.com/en-us/documentation/articles/powershell-install-configure/ [azure-cli]: https://azure.microsoft.com/en-us/documentation/articles/xplat-cli/ [arm-template]: http://azure.microsoft.com/en-us/documentation/templates/ [arm-overview]: https://azure.microsoft.com/en-us/documentation/articles/resource-group-overview/ ================================================ FILE: CustomScript/azure/__init__.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- import ast import base64 import hashlib import hmac import sys import types import warnings import inspect if sys.version_info < (3,): from urllib2 import quote as url_quote from urllib2 import unquote as url_unquote _strtype = basestring else: from urllib.parse import quote as url_quote from urllib.parse import unquote as url_unquote _strtype = str from datetime import datetime from xml.dom import minidom from xml.sax.saxutils import escape as xml_escape #-------------------------------------------------------------------------- # constants __author__ = 'Microsoft Corp. 
' __version__ = '0.8.4' # Live ServiceClient URLs BLOB_SERVICE_HOST_BASE = '.blob.core.windows.net' QUEUE_SERVICE_HOST_BASE = '.queue.core.windows.net' TABLE_SERVICE_HOST_BASE = '.table.core.windows.net' SERVICE_BUS_HOST_BASE = '.servicebus.windows.net' MANAGEMENT_HOST = 'management.core.windows.net' # Development ServiceClient URLs DEV_BLOB_HOST = '127.0.0.1:10000' DEV_QUEUE_HOST = '127.0.0.1:10001' DEV_TABLE_HOST = '127.0.0.1:10002' # Default credentials for Development Storage Service DEV_ACCOUNT_NAME = 'devstoreaccount1' DEV_ACCOUNT_KEY = 'Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==' # All of our error messages _ERROR_CANNOT_FIND_PARTITION_KEY = 'Cannot find partition key in request.' _ERROR_CANNOT_FIND_ROW_KEY = 'Cannot find row key in request.' _ERROR_INCORRECT_TABLE_IN_BATCH = \ 'Table should be the same in a batch operations' _ERROR_INCORRECT_PARTITION_KEY_IN_BATCH = \ 'Partition Key should be the same in a batch operations' _ERROR_DUPLICATE_ROW_KEY_IN_BATCH = \ 'Row Keys should not be the same in a batch operations' _ERROR_BATCH_COMMIT_FAIL = 'Batch Commit Fail' _ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_DELETE = \ 'Message is not peek locked and cannot be deleted.' _ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_UNLOCK = \ 'Message is not peek locked and cannot be unlocked.' _ERROR_QUEUE_NOT_FOUND = 'Queue was not found' _ERROR_TOPIC_NOT_FOUND = 'Topic was not found' _ERROR_CONFLICT = 'Conflict ({0})' _ERROR_NOT_FOUND = 'Not found ({0})' _ERROR_UNKNOWN = 'Unknown error ({0})' _ERROR_SERVICEBUS_MISSING_INFO = \ 'You need to provide servicebus namespace, access key and Issuer' _ERROR_STORAGE_MISSING_INFO = \ 'You need to provide both account name and access key' _ERROR_ACCESS_POLICY = \ 'share_access_policy must be either SignedIdentifier or AccessPolicy ' + \ 'instance' _WARNING_VALUE_SHOULD_BE_BYTES = \ 'Warning: {0} must be bytes data type. It will be converted ' + \ 'automatically, with utf-8 text encoding.' 
_ERROR_VALUE_SHOULD_BE_BYTES = '{0} should be of type bytes.' _ERROR_VALUE_NONE = '{0} should not be None.' _ERROR_VALUE_NEGATIVE = '{0} should not be negative.' _ERROR_CANNOT_SERIALIZE_VALUE_TO_ENTITY = \ 'Cannot serialize the specified value ({0}) to an entity. Please use ' + \ 'an EntityProperty (which can specify custom types), int, str, bool, ' + \ 'or datetime.' _ERROR_PAGE_BLOB_SIZE_ALIGNMENT = \ 'Invalid page blob size: {0}. ' + \ 'The size must be aligned to a 512-byte boundary.' _USER_AGENT_STRING = 'pyazure/' + __version__ METADATA_NS = 'http://schemas.microsoft.com/ado/2007/08/dataservices/metadata' class WindowsAzureData(object): ''' This is the base of data class. It is only used to check whether it is instance or not. ''' pass class WindowsAzureError(Exception): ''' WindowsAzure Excpetion base class. ''' def __init__(self, message): super(WindowsAzureError, self).__init__(message) class WindowsAzureConflictError(WindowsAzureError): '''Indicates that the resource could not be created because it already exists''' def __init__(self, message): super(WindowsAzureConflictError, self).__init__(message) class WindowsAzureMissingResourceError(WindowsAzureError): '''Indicates that a request for a request for a resource (queue, table, container, etc...) 
failed because the specified resource does not exist''' def __init__(self, message): super(WindowsAzureMissingResourceError, self).__init__(message) class WindowsAzureBatchOperationError(WindowsAzureError): '''Indicates that a batch operation failed''' def __init__(self, message, code): super(WindowsAzureBatchOperationError, self).__init__(message) self.code = code class Feed(object): pass class _Base64String(str): pass class HeaderDict(dict): def __getitem__(self, index): return super(HeaderDict, self).__getitem__(index.lower()) def _encode_base64(data): if isinstance(data, _unicode_type): data = data.encode('utf-8') encoded = base64.b64encode(data) return encoded.decode('utf-8') def _decode_base64_to_bytes(data): if isinstance(data, _unicode_type): data = data.encode('utf-8') return base64.b64decode(data) def _decode_base64_to_text(data): decoded_bytes = _decode_base64_to_bytes(data) return decoded_bytes.decode('utf-8') def _get_readable_id(id_name, id_prefix_to_skip): """simplified an id to be more friendly for us people""" # id_name is in the form 'https://namespace.host.suffix/name' # where name may contain a forward slash! 
pos = id_name.find('//') if pos != -1: pos += 2 if id_prefix_to_skip: pos = id_name.find(id_prefix_to_skip, pos) if pos != -1: pos += len(id_prefix_to_skip) pos = id_name.find('/', pos) if pos != -1: return id_name[pos + 1:] return id_name def _get_entry_properties_from_node(entry, include_id, id_prefix_to_skip=None, use_title_as_id=False): ''' get properties from entry xml ''' properties = {} etag = entry.getAttributeNS(METADATA_NS, 'etag') if etag: properties['etag'] = etag for updated in _get_child_nodes(entry, 'updated'): properties['updated'] = updated.firstChild.nodeValue for name in _get_children_from_path(entry, 'author', 'name'): if name.firstChild is not None: properties['author'] = name.firstChild.nodeValue if include_id: if use_title_as_id: for title in _get_child_nodes(entry, 'title'): properties['name'] = title.firstChild.nodeValue else: for id in _get_child_nodes(entry, 'id'): properties['name'] = _get_readable_id( id.firstChild.nodeValue, id_prefix_to_skip) return properties def _get_entry_properties(xmlstr, include_id, id_prefix_to_skip=None): ''' get properties from entry xml ''' xmldoc = minidom.parseString(xmlstr) properties = {} for entry in _get_child_nodes(xmldoc, 'entry'): properties.update(_get_entry_properties_from_node(entry, include_id, id_prefix_to_skip)) return properties def _get_first_child_node_value(parent_node, node_name): xml_attrs = _get_child_nodes(parent_node, node_name) if xml_attrs: xml_attr = xml_attrs[0] if xml_attr.firstChild: value = xml_attr.firstChild.nodeValue return value def _get_child_nodes(node, tagName): return [childNode for childNode in node.getElementsByTagName(tagName) if childNode.parentNode == node] def _get_children_from_path(node, *path): '''descends through a hierarchy of nodes returning the list of children at the inner most level. 
Only returns children who share a common parent, not cousins.''' cur = node for index, child in enumerate(path): if isinstance(child, _strtype): next = _get_child_nodes(cur, child) else: next = _get_child_nodesNS(cur, *child) if index == len(path) - 1: return next elif not next: break cur = next[0] return [] def _get_child_nodesNS(node, ns, tagName): return [childNode for childNode in node.getElementsByTagNameNS(ns, tagName) if childNode.parentNode == node] def _create_entry(entry_body): ''' Adds common part of entry to a given entry body and return the whole xml. ''' updated_str = datetime.utcnow().isoformat() if datetime.utcnow().utcoffset() is None: updated_str += '+00:00' entry_start = ''' <updated>{updated}</updated><author><name /></author><id /> <content type="application/xml"> {body}</content></entry>''' return entry_start.format(updated=updated_str, body=entry_body) def _to_datetime(strtime): return datetime.strptime(strtime, "%Y-%m-%dT%H:%M:%S.%f") _KNOWN_SERIALIZATION_XFORMS = { 'include_apis': 'IncludeAPIs', 'message_id': 'MessageId', 'content_md5': 'Content-MD5', 'last_modified': 'Last-Modified', 'cache_control': 'Cache-Control', 'account_admin_live_email_id': 'AccountAdminLiveEmailId', 'service_admin_live_email_id': 'ServiceAdminLiveEmailId', 'subscription_id': 'SubscriptionID', 'fqdn': 'FQDN', 'private_id': 'PrivateID', 'os_virtual_hard_disk': 'OSVirtualHardDisk', 'logical_disk_size_in_gb': 'LogicalDiskSizeInGB', 'logical_size_in_gb': 'LogicalSizeInGB', 'os': 'OS', 'persistent_vm_downtime_info': 'PersistentVMDowntimeInfo', 'copy_id': 'CopyId', } def _get_serialization_name(element_name): """converts a Python name into a serializable name""" known = _KNOWN_SERIALIZATION_XFORMS.get(element_name) if known is not None: return known if element_name.startswith('x_ms_'): return element_name.replace('_', '-') if element_name.endswith('_id'): element_name = element_name.replace('_id', 'ID') for name in ['content_', 'last_modified', 'if_', 'cache_control']: if 
element_name.startswith(name): element_name = element_name.replace('_', '-_') return ''.join(name.capitalize() for name in element_name.split('_')) if sys.version_info < (3,): _unicode_type = unicode def _str(value): if isinstance(value, unicode): return value.encode('utf-8') return str(value) else: _str = str _unicode_type = str def _str_or_none(value): if value is None: return None return _str(value) def _int_or_none(value): if value is None: return None return str(int(value)) def _bool_or_none(value): if value is None: return None if isinstance(value, bool): if value: return 'true' else: return 'false' return str(value) def _convert_class_to_xml(source, xml_prefix=True): if source is None: return '' xmlstr = '' if xml_prefix: xmlstr = '<?xml version="1.0" encoding="utf-8"?>' if isinstance(source, list): for value in source: xmlstr += _convert_class_to_xml(value, False) elif isinstance(source, WindowsAzureData): class_name = source.__class__.__name__ xmlstr += '<' + class_name + '>' for name, value in vars(source).items(): if value is not None: if isinstance(value, list) or \ isinstance(value, WindowsAzureData): xmlstr += _convert_class_to_xml(value, False) else: xmlstr += ('<' + _get_serialization_name(name) + '>' + xml_escape(str(value)) + '</' + _get_serialization_name(name) + '>') xmlstr += '</' + class_name + '>' return xmlstr def _find_namespaces_from_child(parent, child, namespaces): """Recursively searches from the parent to the child, gathering all the applicable namespaces along the way""" for cur_child in parent.childNodes: if cur_child is child: return True if _find_namespaces_from_child(cur_child, child, namespaces): # we are the parent node for key in cur_child.attributes.keys(): if key.startswith('xmlns:') or key == 'xmlns': namespaces[key] = cur_child.attributes[key] break return False def _find_namespaces(parent, child): res = {} for key in parent.documentElement.attributes.keys(): if key.startswith('xmlns:') or key == 'xmlns': res[key] = 
parent.documentElement.attributes[key] _find_namespaces_from_child(parent, child, res) return res def _clone_node_with_namespaces(node_to_clone, original_doc): clone = node_to_clone.cloneNode(True) for key, value in _find_namespaces(original_doc, node_to_clone).items(): clone.attributes[key] = value return clone def _convert_response_to_feeds(response, convert_callback): if response is None: return None feeds = _list_of(Feed) x_ms_continuation = HeaderDict() for name, value in response.headers: if 'x-ms-continuation' in name: x_ms_continuation[name[len('x-ms-continuation') + 1:]] = value if x_ms_continuation: setattr(feeds, 'x_ms_continuation', x_ms_continuation) xmldoc = minidom.parseString(response.body) xml_entries = _get_children_from_path(xmldoc, 'feed', 'entry') if not xml_entries: # in some cases, response contains only entry but no feed xml_entries = _get_children_from_path(xmldoc, 'entry') if inspect.isclass(convert_callback) and issubclass(convert_callback, WindowsAzureData): for xml_entry in xml_entries: return_obj = convert_callback() for node in _get_children_from_path(xml_entry, 'content', convert_callback.__name__): _fill_data_to_return_object(node, return_obj) for name, value in _get_entry_properties_from_node(xml_entry, include_id=True, use_title_as_id=True).items(): setattr(return_obj, name, value) feeds.append(return_obj) else: for xml_entry in xml_entries: new_node = _clone_node_with_namespaces(xml_entry, xmldoc) feeds.append(convert_callback(new_node.toxml('utf-8'))) return feeds def _validate_type_bytes(param_name, param): if not isinstance(param, bytes): raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES.format(param_name)) def _validate_not_none(param_name, param): if param is None: raise TypeError(_ERROR_VALUE_NONE.format(param_name)) def _fill_list_of(xmldoc, element_type, xml_element_name): xmlelements = _get_child_nodes(xmldoc, xml_element_name) return [_parse_response_body_from_xml_node(xmlelement, element_type) \ for xmlelement in 
xmlelements] def _fill_scalar_list_of(xmldoc, element_type, parent_xml_element_name, xml_element_name): '''Converts an xml fragment into a list of scalar types. The parent xml element contains a flat list of xml elements which are converted into the specified scalar type and added to the list. Example: xmldoc= <Endpoints> <Endpoint>http://{storage-service-name}.blob.core.windows.net/</Endpoint> <Endpoint>http://{storage-service-name}.queue.core.windows.net/</Endpoint> <Endpoint>http://{storage-service-name}.table.core.windows.net/</Endpoint> </Endpoints> element_type=str parent_xml_element_name='Endpoints' xml_element_name='Endpoint' ''' xmlelements = _get_child_nodes(xmldoc, parent_xml_element_name) if xmlelements: xmlelements = _get_child_nodes(xmlelements[0], xml_element_name) return [_get_node_value(xmlelement, element_type) \ for xmlelement in xmlelements] def _fill_dict(xmldoc, element_name): xmlelements = _get_child_nodes(xmldoc, element_name) if xmlelements: return_obj = {} for child in xmlelements[0].childNodes: if child.firstChild: return_obj[child.nodeName] = child.firstChild.nodeValue return return_obj def _fill_dict_of(xmldoc, parent_xml_element_name, pair_xml_element_name, key_xml_element_name, value_xml_element_name): '''Converts an xml fragment into a dictionary. The parent xml element contains a list of xml elements where each element has a child element for the key, and another for the value. 
Example: xmldoc= <ExtendedProperties> <ExtendedProperty> <Name>Ext1</Name> <Value>Val1</Value> </ExtendedProperty> <ExtendedProperty> <Name>Ext2</Name> <Value>Val2</Value> </ExtendedProperty> </ExtendedProperties> element_type=str parent_xml_element_name='ExtendedProperties' pair_xml_element_name='ExtendedProperty' key_xml_element_name='Name' value_xml_element_name='Value' ''' return_obj = {} xmlelements = _get_child_nodes(xmldoc, parent_xml_element_name) if xmlelements: xmlelements = _get_child_nodes(xmlelements[0], pair_xml_element_name) for pair in xmlelements: keys = _get_child_nodes(pair, key_xml_element_name) values = _get_child_nodes(pair, value_xml_element_name) if keys and values: key = keys[0].firstChild.nodeValue value = values[0].firstChild.nodeValue return_obj[key] = value return return_obj def _fill_instance_child(xmldoc, element_name, return_type): '''Converts a child of the current dom element to the specified type. ''' xmlelements = _get_child_nodes( xmldoc, _get_serialization_name(element_name)) if not xmlelements: return None return_obj = return_type() _fill_data_to_return_object(xmlelements[0], return_obj) return return_obj def _fill_instance_element(element, return_type): """Converts a DOM element into the specified object""" return _parse_response_body_from_xml_node(element, return_type) def _fill_data_minidom(xmldoc, element_name, data_member): xmlelements = _get_child_nodes( xmldoc, _get_serialization_name(element_name)) if not xmlelements or not xmlelements[0].childNodes: return None value = xmlelements[0].firstChild.nodeValue if data_member is None: return value elif isinstance(data_member, datetime): return _to_datetime(value) elif type(data_member) is bool: return value.lower() != 'false' else: return type(data_member)(value) def _get_node_value(xmlelement, data_type): value = xmlelement.firstChild.nodeValue if data_type is datetime: return _to_datetime(value) elif data_type is bool: return value.lower() != 'false' else: return 
data_type(value) def _get_request_body_bytes_only(param_name, param_value): '''Validates the request body passed in and converts it to bytes if our policy allows it.''' if param_value is None: return b'' if isinstance(param_value, bytes): return param_value # Previous versions of the SDK allowed data types other than bytes to be # passed in, and they would be auto-converted to bytes. We preserve this # behavior when running under 2.7, but issue a warning. # Python 3 support is new, so we reject anything that's not bytes. if sys.version_info < (3,): warnings.warn(_WARNING_VALUE_SHOULD_BE_BYTES.format(param_name)) return _get_request_body(param_value) raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES.format(param_name)) def _get_request_body(request_body): '''Converts an object into a request body. If it's None we'll return an empty string, if it's one of our objects it'll convert it to XML and return it. Otherwise we just use the object directly''' if request_body is None: return b'' if isinstance(request_body, WindowsAzureData): request_body = _convert_class_to_xml(request_body) if isinstance(request_body, bytes): return request_body if isinstance(request_body, _unicode_type): return request_body.encode('utf-8') request_body = str(request_body) if isinstance(request_body, _unicode_type): return request_body.encode('utf-8') return request_body def _parse_enum_results_list(response, return_type, resp_type, item_type): """resp_body is the XML we received resp_type is a string, such as Containers, return_type is the type we're constructing, such as ContainerEnumResults item_type is the type object of the item to be created, such as Container This function then returns a ContainerEnumResults object with the containers member populated with the results. """ # parsing something like: # <EnumerationResults ... 
> # <Queues> # <Queue> # <Something /> # <SomethingElse /> # </Queue> # </Queues> # </EnumerationResults> respbody = response.body return_obj = return_type() doc = minidom.parseString(respbody) items = [] for enum_results in _get_child_nodes(doc, 'EnumerationResults'): # path is something like Queues, Queue for child in _get_children_from_path(enum_results, resp_type, resp_type[:-1]): items.append(_fill_instance_element(child, item_type)) for name, value in vars(return_obj).items(): # queues, Queues, this is the list its self which we populated # above if name == resp_type.lower(): # the list its self. continue value = _fill_data_minidom(enum_results, name, value) if value is not None: setattr(return_obj, name, value) setattr(return_obj, resp_type.lower(), items) return return_obj def _parse_simple_list(response, type, item_type, list_name): respbody = response.body res = type() res_items = [] doc = minidom.parseString(respbody) type_name = type.__name__ item_name = item_type.__name__ for item in _get_children_from_path(doc, type_name, item_name): res_items.append(_fill_instance_element(item, item_type)) setattr(res, list_name, res_items) return res def _parse_response(response, return_type): ''' Parse the HTTPResponse's body and fill all the data into a class of return_type. ''' return _parse_response_body_from_xml_text(response.body, return_type) def _parse_service_resources_response(response, return_type): ''' Parse the HTTPResponse's body and fill all the data into a class of return_type. 
def _fill_data_to_return_object(node, return_obj):
    '''Populate return_obj's attributes from the XML *node*.

    Each attribute's current value acts as a type marker that selects the
    deserialization strategy: _list_of / _scalar_list_of / _dict_of
    containers, _xml_attribute accessors, nested WindowsAzureData
    objects, plain dicts, _Base64String fields, or simple scalars.
    '''
    # Snapshot the attributes first: the loop below rebinds them.
    members = dict(vars(return_obj))
    for name, value in members.items():
        if isinstance(value, _list_of):
            setattr(return_obj,
                    name,
                    _fill_list_of(node,
                                  value.list_type,
                                  value.xml_element_name))
        elif isinstance(value, _scalar_list_of):
            setattr(return_obj,
                    name,
                    _fill_scalar_list_of(node,
                                         value.list_type,
                                         _get_serialization_name(name),
                                         value.xml_element_name))
        elif isinstance(value, _dict_of):
            setattr(return_obj,
                    name,
                    _fill_dict_of(node,
                                  _get_serialization_name(name),
                                  value.pair_xml_element_name,
                                  value.key_xml_element_name,
                                  value.value_xml_element_name))
        elif isinstance(value, _xml_attribute):
            # Value comes from an XML attribute rather than a child node.
            real_value = None
            if node.hasAttribute(value.xml_element_name):
                real_value = node.getAttribute(value.xml_element_name)
            if real_value is not None:
                setattr(return_obj, name, real_value)
        elif isinstance(value, WindowsAzureData):
            # Nested complex type: recurse into the matching child node.
            setattr(return_obj,
                    name,
                    _fill_instance_child(node, name, value.__class__))
        elif isinstance(value, dict):
            setattr(return_obj,
                    name,
                    _fill_dict(node, _get_serialization_name(name)))
        elif isinstance(value, _Base64String):
            value = _fill_data_minidom(node, name, '')
            if value is not None:
                value = _decode_base64_to_text(value)
            # always set the attribute, so we don't end up returning an object
            # with type _Base64String
            setattr(return_obj, name, value)
        else:
            # Plain scalar; only overwrite when the node provided a value.
            value = _fill_data_minidom(node, name, value)
            if value is not None:
                setattr(return_obj, name, value)


def _parse_response_body_from_xml_node(node, return_type):
    ''' parse the xml and fill all the data into a class of return_type '''
    return_obj = return_type()
    _fill_data_to_return_object(node, return_obj)
    return return_obj
def _parse_response_body_from_service_resources_xml_text(respbody,
                                                         return_type):
    '''Deserialize a ServiceResources XML document into a _list_of
    *return_type*, one entry per <ServiceResource> element.'''
    doc = minidom.parseString(respbody)
    result = _list_of(return_type)
    for resource_node in _get_children_from_path(doc,
                                                 "ServiceResources",
                                                 "ServiceResource"):
        entry = return_type()
        _fill_data_to_return_object(resource_node, entry)
        result.append(entry)
    return result


class _dict_of(dict):
    """A dict that remembers the XML element names of its pair, key and
    value elements.  Used for deserialization and list construction."""

    def __init__(self, pair_xml_element_name, key_xml_element_name,
                 value_xml_element_name):
        super(_dict_of, self).__init__()
        self.pair_xml_element_name = pair_xml_element_name
        self.key_xml_element_name = key_xml_element_name
        self.value_xml_element_name = value_xml_element_name


class _list_of(list):
    """A list that carries the expected element type and its XML element
    name (the type's own name when not given).  Used for deserialization
    and list construction."""

    def __init__(self, list_type, xml_element_name=None):
        super(_list_of, self).__init__()
        self.list_type = list_type
        if xml_element_name is None:
            # Default to the element type's class name.
            self.xml_element_name = list_type.__name__
        else:
            self.xml_element_name = xml_element_name


class _scalar_list_of(list):
    """A list of scalars that carries the scalar type and its XML element
    name.  Used for deserialization and list construction."""

    def __init__(self, list_type, xml_element_name):
        super(_scalar_list_of, self).__init__()
        self.list_type = list_type
        self.xml_element_name = xml_element_name
def _update_request_uri_query_local_storage(request, use_local_storage):
    '''Compute the final (uri, query) for the request; when targeting the
    local storage emulator, prefix the path with the dev account name.'''
    uri, query = _update_request_uri_query(request)
    if use_local_storage:
        return '/' + DEV_ACCOUNT_NAME + uri, query
    return uri, query


def _update_request_uri_query(request):
    '''Move any query string embedded in request.path into request.query,
    URL-quote the path, then re-append every (name, value) pair whose
    value is not None.

    Parameters that were already in request.query end up after the ones
    pulled out of the URI.  Returns the updated (path, query) pair.
    '''
    if '?' in request.path:
        request.path, _, query_string = request.path.partition('?')
        if query_string:
            for piece in query_string.split('&'):
                if '=' in piece:
                    name, _, value = piece.partition('=')
                    request.query.append((name, value))

    request.path = url_quote(request.path, '/()$=\',')

    # Re-encode the accumulated query parameters onto the path,
    # skipping pairs whose value is None.
    if request.query:
        encoded_pairs = [name + '=' + url_quote(value, '/()$=\',')
                         for name, value in request.query
                         if value is not None]
        if encoded_pairs:
            request.path += '?' + '&'.join(encoded_pairs)

    return request.path, request.query


def _dont_fail_on_exist(error):
    '''Suppress "already exists" conflicts.  Called by create_* APIs with
    fail_on_exist=False: returns False for WindowsAzureConflictError and
    re-raises anything else.'''
    if not isinstance(error, WindowsAzureConflictError):
        raise error
    return False
def _general_error_handler(http_error):
    '''Map an HTTPError onto the matching WindowsAzure* exception.'''
    status = http_error.status
    if status == 409:
        raise WindowsAzureConflictError(
            _ERROR_CONFLICT.format(str(http_error)))
    if status == 404:
        raise WindowsAzureMissingResourceError(
            _ERROR_NOT_FOUND.format(str(http_error)))
    message = _ERROR_UNKNOWN.format(str(http_error))
    if http_error.respbody is not None:
        # Append the server's response body for diagnostics.
        message += '\n' + http_error.respbody.decode('utf-8')
    raise WindowsAzureError(message)


def _parse_response_for_dict(response):
    '''Collect the response's non-standard headers into a HeaderDict,
    filtering out the usual HTTP transport headers.'''
    if response is None:
        return None
    standard_headers = ('server', 'date', 'location', 'host',
                        'via', 'proxy-connection', 'connection')
    collected = HeaderDict()
    if response.headers:
        for name, value in response.headers:
            if name.lower() not in standard_headers:
                collected[name] = value
    return collected


def _parse_response_for_dict_prefix(response, prefixes):
    '''Like _parse_response_for_dict, but keep only headers whose name
    starts (case-insensitively) with one of *prefixes*.'''
    if response is None:
        return None
    all_headers = _parse_response_for_dict(response)
    if not all_headers:
        return None
    lowered_prefixes = [p.lower() for p in prefixes]
    matched = {}
    for name, value in all_headers.items():
        if any(name.lower().startswith(p) for p in lowered_prefixes):
            matched[name] = value
    return matched
def _sign_string(key, string_to_sign, key_is_base64=True):
    '''HMAC-SHA256 sign *string_to_sign* with *key* and return the
    base64-encoded digest.

    key: signing key; base64-encoded text by default, otherwise raw
        text/bytes (text is UTF-8 encoded first).
    string_to_sign: the canonicalized string; UTF-8 encoded if text.
    '''
    if key_is_base64:
        key = _decode_base64_to_bytes(key)
    elif isinstance(key, _unicode_type):
        key = key.encode('utf-8')
    if isinstance(string_to_sign, _unicode_type):
        string_to_sign = string_to_sign.encode('utf-8')
    mac = hmac.HMAC(key, string_to_sign, hashlib.sha256)
    return _encode_base64(mac.digest())
<EnableUnmanagedDebugging>false</EnableUnmanagedDebugging> </PropertyGroup> <PropertyGroup Condition=" '$(Configuration)' == 'Release' "> <DebugSymbols>true</DebugSymbols> <EnableUnmanagedDebugging>false</EnableUnmanagedDebugging> </PropertyGroup> <ItemGroup> <Compile Include="http\batchclient.py" /> <Compile Include="http\httpclient.py" /> <Compile Include="http\winhttp.py" /> <Compile Include="http\__init__.py" /> <Compile Include="servicemanagement\servicebusmanagementservice.py" /> <Compile Include="servicemanagement\servicemanagementclient.py" /> <Compile Include="servicemanagement\servicemanagementservice.py" /> <Compile Include="servicemanagement\sqldatabasemanagementservice.py" /> <Compile Include="servicemanagement\websitemanagementservice.py" /> <Compile Include="servicemanagement\__init__.py" /> <Compile Include="servicebus\servicebusservice.py" /> <Compile Include="storage\blobservice.py" /> <Compile Include="storage\queueservice.py" /> <Compile Include="storage\cloudstorageaccount.py" /> <Compile Include="storage\tableservice.py" /> <Compile Include="storage\sharedaccesssignature.py" /> <Compile Include="__init__.py" /> <Compile Include="servicebus\__init__.py" /> <Compile Include="storage\storageclient.py" /> <Compile Include="storage\__init__.py" /> </ItemGroup> <ItemGroup> <Folder Include="http" /> <Folder Include="servicemanagement" /> <Folder Include="servicebus\" /> <Folder Include="storage" /> </ItemGroup> <ItemGroup> <InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\2.6" /> <InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\2.7" /> <InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\3.3" /> <InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\3.4" /> <InterpreterReference Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\2.7" /> <InterpreterReference Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\3.3" /> <InterpreterReference 
# Status code returned for successful requests that carry no body.
HTTP_RESPONSE_NO_CONTENT = 204


class HTTPError(Exception):
    '''Raised when an HTTP response comes back with status code >= 300.'''

    def __init__(self, status, message, respheader, respbody):
        '''Create an HTTPError carrying the status code, reason message,
        response headers and raw response body.'''
        self.status = status
        self.respheader = respheader
        self.respbody = respbody
        Exception.__init__(self, message)


class HTTPResponse(object):
    """A response from an HTTP request.

    Attributes:
        status:  integer status code of the response
        message: reason phrase returned by the server
        headers: returned headers, as a list of (name, value) pairs
        body:    raw body of the response
    """

    def __init__(self, status, message, headers, body):
        self.status = status
        self.message = message
        self.headers = headers
        self.body = body


class HTTPRequest(object):
    '''An HTTP request to be issued by _HTTPClient.

    Attributes:
        host:    host name to connect to
        method:  HTTP verb (GET, POST, PUT, ...)
        path:    uri fragment
        query:   query parameters as a list of (name, value) pairs
        headers: request headers as a list of (name, value) pairs
        body:    request payload
        protocol_override: protocol to use for this request instead of
            the client-wide default stored in _HTTPClient
    '''

    def __init__(self):
        self.host = ''
        self.method = ''
        self.path = ''
        self.query = []      # list of (name, value)
        self.headers = []    # list of (header name, header value)
        self.body = ''
        self.protocol_override = None
def get_request_table(self, request):
    '''Extract the table name from request.path.

    The path has either the "/mytable(...)" or the "/mytable" form; in
    both cases the leading slash is dropped and everything from the
    first "(" onwards is ignored.

    request: the request to insert, update or delete entity
    '''
    path = request.path
    paren = path.find('(')
    if paren != -1:
        return path[1:paren]
    return path[1:]
def get_request_row_key(self, request):
    '''Extract the RowKey for a batched entity operation.

    Insert operations (POST) carry the RowKey inside the Atom entry in
    request.body; every other verb encodes it in request.path as
    RowKey='...'.

    request: the request to insert, update or delete entity
    '''
    if request.method == 'POST':
        doc = minidom.parseString(request.body)
        key_nodes = _get_children_from_path(
            doc, 'entry', 'content', (METADATA_NS, 'properties'),
            (_DATASERVICES_NS, 'RowKey'))
        if not key_nodes:
            raise WindowsAzureError(_ERROR_CANNOT_FIND_ROW_KEY)
        return key_nodes[0].firstChild.nodeValue
    uri = url_unquote(request.path)
    start = uri.find('RowKey=\'')
    end = uri.find('\')', start)
    if start == -1 or end == -1:
        raise WindowsAzureError(_ERROR_CANNOT_FIND_ROW_KEY)
    return uri[start + len('RowKey=\''):end]


def validate_request_table(self, request):
    '''Ensure the request targets the same table as the rest of the
    batch; the first request of the batch fixes the table name.

    request: the request to insert, update or delete entity
    '''
    table = self.get_request_table(request)
    if self.batch_table:
        if table != self.batch_table:
            raise WindowsAzureError(_ERROR_INCORRECT_TABLE_IN_BATCH)
    else:
        self.batch_table = table
def validate_request_row_key(self, request):
    '''Validate that the request's RowKey was not already used in this
    batch, then record it.

    Bug fix: previously the append lived on the else-branch of the
    outer "if self.batch_row_keys:", so only the FIRST request's RowKey
    was ever recorded and duplicates of later keys were never detected,
    contradicting the documented behaviour ("adds RowKey to existing
    RowKey list").  The RowKey is now appended after every successful
    uniqueness check.

    request: the request to insert, update or delete entity
    Raises WindowsAzureError when the RowKey already appears in the batch.
    '''
    row_key = self.get_request_row_key(request)
    if row_key in self.batch_row_keys:
        raise WindowsAzureError(_ERROR_DUPLICATE_ROW_KEY_IN_BATCH)
    self.batch_row_keys.append(row_key)


def begin_batch(self):
    '''Start a batch operation by resetting all batch state.

    is_batch:            batch operation flag
    batch_table:         table name shared by the whole batch
    batch_partition_key: PartitionKey shared by every batched request
    batch_row_keys:      RowKeys seen so far (must stay unique)
    batch_requests:      the queued requests
    '''
    self.is_batch = True
    self.batch_table = ''
    self.batch_partition_key = ''
    self.batch_row_keys = []
    self.batch_requests = []


def insert_request_to_batch(self, request):
    '''Validate a request against the current batch and queue it.

    request: the request to insert, update or delete entity
    '''
    self.validate_request_table(request)
    self.validate_request_partition_key(request)
    self.validate_request_row_key(request)
    self.batch_requests.append(request)


def commit_batch(self):
    '''Leave batch mode and submit the queued requests.'''
    if self.is_batch:
        self.is_batch = False
        self.commit_batch_requests()


def commit_batch_requests(self):
    '''Build the multipart MIME payload for the queued requests and POST
    it to the $batch endpoint, raising on a failed commit or on any
    failed sub-response.'''
    batch_boundary = b'batch_' + _new_boundary()
    changeset_boundary = b'changeset_' + _new_boundary()

    # Commit only when the request list is not empty.
    if self.batch_requests:
        request = HTTPRequest()
        request.method = 'POST'
        request.host = self.batch_requests[0].host
        request.path = '/$batch'
        request.headers = [
            ('Content-Type', 'multipart/mixed; boundary=' + \
                batch_boundary.decode('utf-8')),
            ('Accept', 'application/atom+xml,application/xml'),
            ('Accept-Charset', 'UTF-8')]

        request.body = b'--' + batch_boundary + b'\n'
        request.body += b'Content-Type: multipart/mixed; boundary='
        request.body += changeset_boundary + b'\n\n'

        content_id = 1

        # Add each queued request as one part of the changeset.
        for batch_request in self.batch_requests:
            request.body += b'--' + changeset_boundary + b'\n'
            request.body += b'Content-Type: application/http\n'
            request.body += b'Content-Transfer-Encoding: binary\n\n'

            request.body += batch_request.method.encode('utf-8')
            request.body += b' http://'
            request.body += batch_request.host.encode('utf-8')
            request.body += batch_request.path.encode('utf-8')
            request.body += b' HTTP/1.1\n'
            request.body += b'Content-ID: '
            request.body += str(content_id).encode('utf-8') + b'\n'
            content_id += 1

            # Inserts/updates carry an Atom entry body; deletes do not.
            if not batch_request.method == 'DELETE':
                request.body += \
                    b'Content-Type: application/atom+xml;type=entry\n'
                for name, value in batch_request.headers:
                    if name == 'If-Match':
                        request.body += name.encode('utf-8') + b': '
                        request.body += value.encode('utf-8') + b'\n'
                        break
                request.body += b'Content-Length: '
                request.body += str(len(batch_request.body)).encode('utf-8')
                request.body += b'\n\n'
                request.body += batch_request.body + b'\n'
            else:
                for name, value in batch_request.headers:
                    # If-Match should be already included in
                    # batch_request.headers, but in case it is missing,
                    # just add it.
                    if name == 'If-Match':
                        request.body += name.encode('utf-8') + b': '
                        request.body += value.encode('utf-8') + b'\n\n'
                        break
                else:
                    request.body += b'If-Match: *\n\n'

        request.body += b'--' + changeset_boundary + b'--' + b'\n'
        request.body += b'--' + batch_boundary + b'--'

        request.path, request.query = _update_request_uri_query(request)
        request.headers = _update_storage_table_header(request)
        auth = _sign_storage_table_request(request,
                                           self.account_name,
                                           self.account_key)
        request.headers.append(('Authorization', auth))

        # Submit the whole payload as a single batch request.
        response = self.perform_request(request)
        if response.status >= 300:
            raise HTTPError(response.status,
                            _ERROR_BATCH_COMMIT_FAIL,
                            self.respheader,
                            response.body)

        # http://www.odata.org/documentation/odata-version-2-0/batch-processing/
        # A ChangeSet response is either one response per successfully
        # processed change request, formatted exactly as it would appear
        # outside a batch, or a single response reporting failure of the
        # entire ChangeSet.
        responses = self._parse_batch_response(response.body)
        if responses and responses[0].status >= 300:
            self._report_batch_error(responses[0])
def _parse_batch_response(self, body):
    '''Split a multipart batch response body on the changeset boundary
    and parse each "HTTP/..." part into an HTTPResponse.'''
    responses = []
    for part in body.split(b'--changesetresponse_'):
        http_pos = part.find(b'HTTP/')
        if http_pos > 0:
            responses.append(
                self._parse_batch_response_part(part[http_pos:]))
    return responses


def _parse_batch_response_part(self, part):
    '''Parse one "HTTP/..." part of a batch response into an
    HTTPResponse carrying (status, reason, headers, body).'''
    lines = part.splitlines()
    # First line is the status line: "HTTP/x.y <status> <reason>".
    status, _, reason = lines[0].partition(b' ')[2].partition(b' ')
    # Headers run until the first blank line; the rest is the body.
    headers = []
    body = b''
    in_body = False
    for line in lines[1:]:
        if line == b'' and not in_body:
            in_body = True
        elif in_body:
            # NOTE(review): body lines are concatenated without their
            # original line breaks — long-standing behaviour, preserved.
            body += line
        else:
            header_name, _, header_value = line.partition(b':')
            headers.append((header_name.lower(), header_value))
    return HTTPResponse(int(status), reason.strip(), headers, body)


def _report_batch_error(self, response):
    '''Extract the OData error code and message from a failed changeset
    response and raise WindowsAzureBatchOperationError.'''
    xml = response.body.decode('utf-8')
    doc = minidom.parseString(xml)
    nodes = _get_children_from_path(doc, (METADATA_NS, 'error'), 'code')
    code = nodes[0].firstChild.nodeValue if nodes and nodes[0].firstChild else ''
    nodes = _get_children_from_path(doc, (METADATA_NS, 'error'), 'message')
    message = nodes[0].firstChild.nodeValue if nodes and nodes[0].firstChild else xml
    raise WindowsAzureBatchOperationError(message, code)
# See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- import base64 import os import sys if sys.version_info < (3,): from httplib import ( HTTPSConnection, HTTPConnection, HTTP_PORT, HTTPS_PORT, ) from urlparse import urlparse else: from http.client import ( HTTPSConnection, HTTPConnection, HTTP_PORT, HTTPS_PORT, ) from urllib.parse import urlparse from azure.http import HTTPError, HTTPResponse from azure import _USER_AGENT_STRING, _update_request_uri_query class _HTTPClient(object): ''' Takes the request and sends it to cloud service and returns the response. ''' def __init__(self, service_instance, cert_file=None, account_name=None, account_key=None, protocol='https'): ''' service_instance: service client instance. cert_file: certificate file name/location. This is only used in hosted service management. account_name: the storage account. account_key: the storage account access key. ''' self.service_instance = service_instance self.status = None self.respheader = None self.message = None self.cert_file = cert_file self.account_name = account_name self.account_key = account_key self.protocol = protocol self.proxy_host = None self.proxy_port = None self.proxy_user = None self.proxy_password = None self.use_httplib = self.should_use_httplib() def should_use_httplib(self): if sys.platform.lower().startswith('win') and self.cert_file: # On Windows, auto-detect between Windows Store Certificate # (winhttp) and OpenSSL .pem certificate file (httplib). # # We used to only support certificates installed in the Windows # Certificate Store. # cert_file example: CURRENT_USER\my\CertificateName # # We now support using an OpenSSL .pem certificate file, # for a consistent experience across all platforms. # cert_file example: account\certificate.pem # # When using OpenSSL .pem certificate file on Windows, make sure # you are on CPython 2.7.4 or later. 
def set_proxy(self, host, port, user, password):
    '''Sets the proxy server host and port for the HTTP CONNECT
    Tunnelling.

    host: Address of the proxy. Ex: '192.168.0.100'
    port: Port of the proxy. Ex: 6000
    user: User for proxy authorization.
    password: Password for proxy authorization.
    '''
    self.proxy_host = host
    self.proxy_port = port
    self.proxy_user = user
    self.proxy_password = password


def get_uri(self, request):
    '''Return the full target URI for the request, honouring any
    per-request protocol override and the protocol's default port.'''
    scheme = request.protocol_override \
        if request.protocol_override else self.protocol
    port = HTTP_PORT if scheme == 'http' else HTTPS_PORT
    return scheme + '://' + request.host + ':' + str(port) + request.path
def send_request_headers(self, connection, request_headers):
    '''Write the request headers (plus User-Agent) to the connection and
    terminate the header section.'''
    if self.use_httplib:
        if self.proxy_host:
            # When tunnelling through a proxy, replace the buffered Host
            # header with the tunnel target.
            for buffered_line in connection._buffer:
                if buffered_line.startswith("Host: "):
                    connection._buffer.remove(buffered_line)
            connection.putheader(
                'Host',
                "{0}:{1}".format(connection._tunnel_host,
                                 connection._tunnel_port))

    for name, value in request_headers:
        if value:
            connection.putheader(name, value)

    connection.putheader('User-Agent', _USER_AGENT_STRING)
    connection.endheaders()


def send_request_body(self, connection, request_body):
    '''Send the request payload.  Non-httplib (winhttp) connections need
    an explicit send(None) to terminate a body-less request.'''
    if request_body:
        assert isinstance(request_body, bytes)
        connection.send(request_body)
    elif not isinstance(connection, (HTTPSConnection, HTTPConnection)):
        connection.send(None)
response finally: connection.close() ================================================ FILE: CustomScript/azure/http/winhttp.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- from ctypes import ( c_void_p, c_long, c_ulong, c_longlong, c_ulonglong, c_short, c_ushort, c_wchar_p, c_byte, byref, Structure, Union, POINTER, WINFUNCTYPE, HRESULT, oledll, WinDLL, ) import ctypes import sys if sys.version_info >= (3,): def unicode(text): return text #------------------------------------------------------------------------------ # Constants that are used in COM operations VT_EMPTY = 0 VT_NULL = 1 VT_I2 = 2 VT_I4 = 3 VT_BSTR = 8 VT_BOOL = 11 VT_I1 = 16 VT_UI1 = 17 VT_UI2 = 18 VT_UI4 = 19 VT_I8 = 20 VT_UI8 = 21 VT_ARRAY = 8192 HTTPREQUEST_PROXYSETTING_PROXY = 2 HTTPREQUEST_SETCREDENTIALS_FOR_PROXY = 1 HTTPREQUEST_PROXY_SETTING = c_long HTTPREQUEST_SETCREDENTIALS_FLAGS = c_long #------------------------------------------------------------------------------ # Com related APIs that are used. 
#------------------------------------------------------------------------------
# ctypes bindings for the Win32 COM/OLE entry points used by this module.
# NOTE: definition order matters here - BSTR.__del__ and VARIANT.__del__
# call _SysFreeString/_VariantClear, so those bindings must already be
# resolved before any instance can be garbage collected.
_ole32 = oledll.ole32
_oleaut32 = WinDLL('oleaut32')

# HRESULT CLSIDFromString(LPCOLESTR lpsz, LPCLSID pclsid);
_CLSIDFromString = _ole32.CLSIDFromString
_CoInitialize = _ole32.CoInitialize
_CoInitialize.argtypes = [c_void_p]
_CoCreateInstance = _ole32.CoCreateInstance

# BSTR allocation/release pair: every BSTR produced through SysAllocString
# must eventually be released with SysFreeString.
_SysAllocString = _oleaut32.SysAllocString
_SysAllocString.restype = c_void_p
_SysAllocString.argtypes = [c_wchar_p]
_SysFreeString = _oleaut32.SysFreeString
_SysFreeString.argtypes = [c_void_p]

# SAFEARRAY*
# SafeArrayCreateVector(_In_ VARTYPE vt,_In_ LONG lLbound,_In_ ULONG
# cElements);
_SafeArrayCreateVector = _oleaut32.SafeArrayCreateVector
_SafeArrayCreateVector.restype = c_void_p
_SafeArrayCreateVector.argtypes = [c_ushort, c_long, c_ulong]

# HRESULT
# SafeArrayAccessData(_In_ SAFEARRAY *psa, _Out_ void **ppvData);
_SafeArrayAccessData = _oleaut32.SafeArrayAccessData
_SafeArrayAccessData.argtypes = [c_void_p, POINTER(c_void_p)]

# HRESULT
# SafeArrayUnaccessData(_In_ SAFEARRAY *psa);
_SafeArrayUnaccessData = _oleaut32.SafeArrayUnaccessData
_SafeArrayUnaccessData.argtypes = [c_void_p]

# HRESULT
# SafeArrayGetUBound(_In_ SAFEARRAY *psa, _In_ UINT nDim, _Out_ LONG
# *plUbound);
_SafeArrayGetUBound = _oleaut32.SafeArrayGetUBound
_SafeArrayGetUBound.argtypes = [c_void_p, c_ulong, POINTER(c_long)]


#------------------------------------------------------------------------------
class BSTR(c_wchar_p):

    '''
    BSTR class in python.

    Owns a COM BSTR copy of the given python string; the copy is released
    in __del__, so an instance must stay referenced for as long as the
    callee uses the pointer.
    '''

    def __init__(self, value):
        super(BSTR, self).__init__(_SysAllocString(value))

    def __del__(self):
        _SysFreeString(self)


class VARIANT(Structure):

    '''
    VARIANT structure in python. Does not match the definition in
    MSDN exactly & it is only mapping the used fields.  Field names are also
    slighty different.
    '''

    class _tagData(Union):

        class _tagRecord(Structure):
            _fields_ = [('pvoid', c_void_p),
                        ('precord', c_void_p)]

        _fields_ = [('llval', c_longlong),
                    ('ullval', c_ulonglong),
                    ('lval', c_long),
                    ('ulval', c_ulong),
                    ('ival', c_short),
                    ('boolval', c_ushort),
                    ('bstrval', BSTR),
                    ('parray', c_void_p),
                    ('record', _tagRecord)]

    _fields_ = [('vt', c_ushort),
                ('wReserved1', c_ushort),
                ('wReserved2', c_ushort),
                ('wReserved3', c_ushort),
                ('vdata', _tagData)]

    @staticmethod
    def create_empty():
        # VT_EMPTY variant - used where COM expects an "omitted" argument.
        variant = VARIANT()
        variant.vt = VT_EMPTY
        variant.vdata.llval = 0
        return variant

    @staticmethod
    def create_safearray_from_str(text):
        # Wrap the raw bytes of `text` in a SAFEARRAY of VT_UI1 (bytes);
        # the data is copied into the array, so `text` need not outlive it.
        variant = VARIANT()
        variant.vt = VT_ARRAY | VT_UI1

        length = len(text)
        variant.vdata.parray = _SafeArrayCreateVector(VT_UI1, 0, length)
        pvdata = c_void_p()
        _SafeArrayAccessData(variant.vdata.parray, byref(pvdata))
        ctypes.memmove(pvdata, text, length)
        _SafeArrayUnaccessData(variant.vdata.parray)

        return variant

    @staticmethod
    def create_bstr_from_str(text):
        variant = VARIANT()
        variant.vt = VT_BSTR
        variant.vdata.bstrval = BSTR(text)
        return variant

    @staticmethod
    def create_bool_false():
        variant = VARIANT()
        variant.vt = VT_BOOL
        variant.vdata.boolval = 0
        return variant

    def is_safearray_of_bytes(self):
        return self.vt == VT_ARRAY | VT_UI1

    def str_from_safearray(self):
        # Copy the SAFEARRAY payload back out into a python byte string.
        assert self.vt == VT_ARRAY | VT_UI1
        pvdata = c_void_p()
        count = c_long()
        _SafeArrayGetUBound(self.vdata.parray, 1, byref(count))
        count = c_long(count.value + 1)  # upper bound is inclusive
        _SafeArrayAccessData(self.vdata.parray, byref(pvdata))
        text = ctypes.string_at(pvdata, count)
        _SafeArrayUnaccessData(self.vdata.parray)
        return text

    def __del__(self):
        _VariantClear(self)

# HRESULT VariantClear(_Inout_ VARIANTARG *pvarg);
# Bound after the VARIANT class because argtypes references it.
_VariantClear = _oleaut32.VariantClear
_VariantClear.argtypes = [POINTER(VARIANT)]
class GUID(Structure):

    ''' GUID structure in python. '''

    _fields_ = [("data1", c_ulong),
                ("data2", c_ushort),
                ("data3", c_ushort),
                ("data4", c_byte * 8)]

    def __init__(self, name=None):
        # Parse a '{...}' CLSID/IID string into this structure in place.
        if name is not None:
            _CLSIDFromString(unicode(name), byref(self))


class _WinHttpRequest(c_void_p):

    '''
    Maps the Com API to Python class functions. Not all methods in
    IWinHttpWebRequest are mapped - only the methods we use.

    The c_void_p value is the raw IWinHttpRequest interface pointer; each
    WINFUNCTYPE below is bound to a fixed vtable slot (the leading integer)
    of that interface, and `self` is passed as the implicit `this` pointer.
    '''
    _AddRef = WINFUNCTYPE(c_long) \
        (1, 'AddRef')
    _Release = WINFUNCTYPE(c_long) \
        (2, 'Release')
    _SetProxy = WINFUNCTYPE(HRESULT,
                            HTTPREQUEST_PROXY_SETTING,
                            VARIANT,
                            VARIANT) \
        (7, 'SetProxy')
    _SetCredentials = WINFUNCTYPE(HRESULT,
                                  BSTR,
                                  BSTR,
                                  HTTPREQUEST_SETCREDENTIALS_FLAGS) \
        (8, 'SetCredentials')
    _Open = WINFUNCTYPE(HRESULT, BSTR, BSTR, VARIANT) \
        (9, 'Open')
    _SetRequestHeader = WINFUNCTYPE(HRESULT, BSTR, BSTR) \
        (10, 'SetRequestHeader')
    _GetResponseHeader = WINFUNCTYPE(HRESULT, BSTR, POINTER(c_void_p)) \
        (11, 'GetResponseHeader')
    _GetAllResponseHeaders = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \
        (12, 'GetAllResponseHeaders')
    _Send = WINFUNCTYPE(HRESULT, VARIANT) \
        (13, 'Send')
    _Status = WINFUNCTYPE(HRESULT, POINTER(c_long)) \
        (14, 'Status')
    _StatusText = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \
        (15, 'StatusText')
    _ResponseText = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \
        (16, 'ResponseText')
    _ResponseBody = WINFUNCTYPE(HRESULT, POINTER(VARIANT)) \
        (17, 'ResponseBody')
    _ResponseStream = WINFUNCTYPE(HRESULT, POINTER(VARIANT)) \
        (18, 'ResponseStream')
    _WaitForResponse = WINFUNCTYPE(HRESULT, VARIANT, POINTER(c_ushort)) \
        (21, 'WaitForResponse')
    _Abort = WINFUNCTYPE(HRESULT) \
        (22, 'Abort')
    _SetTimeouts = WINFUNCTYPE(HRESULT, c_long, c_long, c_long, c_long) \
        (23, 'SetTimeouts')
    _SetClientCertificate = WINFUNCTYPE(HRESULT, BSTR) \
        (24, 'SetClientCertificate')

    def open(self, method, url):
        '''
        Opens the request.

        method: the request VERB 'GET', 'POST', etc.
        url: the url to connect
        '''
        # Timeouts: resolve=0, connect/send/receive=65s each (SetTimeouts
        # argument order per IWinHttpRequest - confirm against MSDN).
        _WinHttpRequest._SetTimeouts(self, 0, 65000, 65000, 65000)

        # Third Open argument is the async flag; VT_BOOL false keeps the
        # request synchronous.
        flag = VARIANT.create_bool_false()
        _method = BSTR(method)
        _url = BSTR(url)
        _WinHttpRequest._Open(self, _method, _url, flag)

    def set_request_header(self, name, value):
        ''' Sets the request header. '''
        _name = BSTR(name)
        _value = BSTR(value)
        _WinHttpRequest._SetRequestHeader(self, _name, _value)

    def get_all_response_headers(self):
        ''' Gets back all response headers as one newline-separated string. '''
        bstr_headers = c_void_p()
        _WinHttpRequest._GetAllResponseHeaders(self, byref(bstr_headers))
        bstr_headers = ctypes.cast(bstr_headers, c_wchar_p)
        headers = bstr_headers.value
        # The out-param BSTR is owned by the caller and must be freed here.
        _SysFreeString(bstr_headers)
        return headers

    def send(self, request=None):
        ''' Sends the request body. '''

        # Sends VT_EMPTY if it is GET, HEAD request.
        if request is None:
            var_empty = VARIANT.create_empty()
            _WinHttpRequest._Send(self, var_empty)
        else:   # Sends request body as SAFEArray.
            _request = VARIANT.create_safearray_from_str(request)
            _WinHttpRequest._Send(self, _request)

    def status(self):
        ''' Gets status of response. '''
        status = c_long()
        _WinHttpRequest._Status(self, byref(status))
        return int(status.value)

    def status_text(self):
        ''' Gets status text of response. '''
        bstr_status_text = c_void_p()
        _WinHttpRequest._StatusText(self, byref(bstr_status_text))
        bstr_status_text = ctypes.cast(bstr_status_text, c_wchar_p)
        status_text = bstr_status_text.value
        _SysFreeString(bstr_status_text)
        return status_text

    def response_body(self):
        '''
        Gets response body as a SAFEARRAY and converts the SAFEARRAY to str.
        If it is an xml file, it always contains 3 characters before <?xml,
        so we remove them.
        '''
        var_respbody = VARIANT()
        _WinHttpRequest._ResponseBody(self, byref(var_respbody))
        if var_respbody.is_safearray_of_bytes():
            respbody = var_respbody.str_from_safearray()
            # The 3 leading bytes are the UTF-8 BOM; strip it so xml
            # parsers see the document from '<?xml'.
            if respbody[3:].startswith(b'<?xml') and\
               respbody.startswith(b'\xef\xbb\xbf'):
                respbody = respbody[3:]
            return respbody
        else:
            return ''

    def set_client_certificate(self, certificate):
        '''Sets client certificate for the request. '''
        _certificate = BSTR(certificate)
        _WinHttpRequest._SetClientCertificate(self, _certificate)

    def set_tunnel(self, host, port):
        ''' Sets up the host and the port for the HTTP CONNECT Tunnelling.'''
        url = host
        if port:
            url = url + u':' + port

        var_host = VARIANT.create_bstr_from_str(url)
        var_empty = VARIANT.create_empty()

        _WinHttpRequest._SetProxy(
            self, HTTPREQUEST_PROXYSETTING_PROXY, var_host, var_empty)

    def set_proxy_credentials(self, user, password):
        _WinHttpRequest._SetCredentials(
            self, BSTR(user), BSTR(password),
            HTTPREQUEST_SETCREDENTIALS_FOR_PROXY)

    def __del__(self):
        # Release the COM reference created by CoCreateInstance.
        if self.value is not None:
            _WinHttpRequest._Release(self)


class _Response(object):

    ''' Response class corresponding to the response returned from httplib
    HTTPConnection. '''

    def __init__(self, _status, _status_text, _length, _headers, _respbody):
        self.status = _status
        self.reason = _status_text
        self.length = _length
        self.headers = _headers
        self.respbody = _respbody

    def getheaders(self):
        '''Returns response headers.'''
        return self.headers

    def read(self, _length):
        '''Returns response body up to _length bytes. '''
        return self.respbody[:_length]
class _HTTPConnection(object):

    ''' Class corresponding to httplib HTTPConnection class. Adapts the
    IWinHttpWebRequest COM object to the familiar
    putrequest/putheader/send/getresponse call sequence. '''

    def __init__(self, host, cert_file=None, key_file=None, protocol='http'):
        ''' initialize the IWinHttpWebRequest Com Object.

        host: host name (optionally host:port) of the remote endpoint.
        cert_file: optional client certificate name for SSL.
        key_file: unused; accepted for httplib signature compatibility.
        protocol: 'http' or 'https'.
        '''
        self.host = unicode(host)
        self.cert_file = cert_file
        self._httprequest = _WinHttpRequest()
        self.protocol = protocol
        clsid = GUID('{2087C2F4-2CEF-4953-A8AB-66779B670495}')  # WinHttpRequest
        iid = GUID('{016FE2EC-B2C8-45F8-B23B-39E53A75396B}')    # IWinHttpRequest
        _CoInitialize(None)
        _CoCreateInstance(byref(clsid), 0, 1, byref(iid),
                          byref(self._httprequest))

    def close(self):
        ''' No operation; the COM object is released when the
        _WinHttpRequest wrapper is garbage collected. '''
        pass

    def set_tunnel(self, host, port=None, headers=None):
        ''' Sets up the host and the port for the HTTP CONNECT Tunnelling.

        headers is unused; accepted for httplib signature compatibility.
        '''
        # Fix: str(port) used to turn the default port=None into the literal
        # string 'None', yielding a proxy address like 'myproxy:None'.
        # Forward an empty (falsy) port instead so no ':port' is appended.
        port_text = unicode(str(port)) if port is not None else u''
        self._httprequest.set_tunnel(unicode(host), port_text)

    def set_proxy_credentials(self, user, password):
        ''' Sets the user name and password used for proxy authentication. '''
        self._httprequest.set_proxy_credentials(
            unicode(user), unicode(password))

    def putrequest(self, method, uri):
        ''' Connects to host and sends the request. '''
        protocol = unicode(self.protocol + '://')
        url = protocol + self.host + unicode(uri)
        self._httprequest.open(unicode(method), url)

        # sets certificate for the connection if cert_file is set.
        if self.cert_file is not None:
            self._httprequest.set_client_certificate(unicode(self.cert_file))

    def putheader(self, name, value):
        ''' Sends the headers of request. '''
        if sys.version_info < (3,):
            # COM expects unicode strings under python 2.
            name = str(name).decode('utf-8')
            value = str(value).decode('utf-8')
        self._httprequest.set_request_header(name, value)

    def endheaders(self):
        ''' No operation. Exists only to provide the same interface of httplib
        HTTPConnection.'''
        pass

    def send(self, request_body):
        ''' Sends request body. '''
        if not request_body:
            self._httprequest.send()
        else:
            self._httprequest.send(request_body)

    def getresponse(self):
        ''' Gets the response and generates the _Response object'''
        status = self._httprequest.status()
        status_text = self._httprequest.status_text()

        resp_headers = self._httprequest.get_all_response_headers()
        fixed_headers = []
        for resp_header in resp_headers.split('\n'):
            # Lines starting with whitespace are RFC 2616 continuations of
            # the previous header value; fold them back together.
            if (resp_header.startswith('\t') or\
                resp_header.startswith(' ')) and fixed_headers:
                # append to previous header
                fixed_headers[-1] += resp_header
            else:
                fixed_headers.append(resp_header)

        headers = []
        for resp_header in fixed_headers:
            if ':' in resp_header:
                pos = resp_header.find(':')
                # header names are lowercased for case-insensitive lookups
                headers.append(
                    (resp_header[:pos].lower(), resp_header[pos + 1:].strip()))

        body = self._httprequest.response_body()
        length = len(body)

        return _Response(status, status_text, length, headers, body)
#--------------------------------------------------------------------------
import ast
import json
import sys

from datetime import datetime
from xml.dom import minidom

from azure import (
    WindowsAzureData,
    WindowsAzureError,
    xml_escape,
    _create_entry,
    _general_error_handler,
    _get_entry_properties,
    _get_child_nodes,
    _get_children_from_path,
    _get_first_child_node_value,
    _ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_DELETE,
    _ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_UNLOCK,
    _ERROR_QUEUE_NOT_FOUND,
    _ERROR_TOPIC_NOT_FOUND,
    )
from azure.http import HTTPError

# default rule name for subscription
DEFAULT_RULE_NAME = '$Default'

#-----------------------------------------------------------------------------
# Constants for Azure app environment settings.
AZURE_SERVICEBUS_NAMESPACE = 'AZURE_SERVICEBUS_NAMESPACE'
AZURE_SERVICEBUS_ACCESS_KEY = 'AZURE_SERVICEBUS_ACCESS_KEY'
AZURE_SERVICEBUS_ISSUER = 'AZURE_SERVICEBUS_ISSUER'

# namespace used for converting rules to objects
XML_SCHEMA_NAMESPACE = 'http://www.w3.org/2001/XMLSchema-instance'


class Queue(WindowsAzureData):

    ''' Queue class corresponding to Queue Description:
    http://msdn.microsoft.com/en-us/library/windowsazure/hh780773'''

    def __init__(self, lock_duration=None, max_size_in_megabytes=None,
                 requires_duplicate_detection=None, requires_session=None,
                 default_message_time_to_live=None,
                 dead_lettering_on_message_expiration=None,
                 duplicate_detection_history_time_window=None,
                 max_delivery_count=None, enable_batched_operations=None,
                 size_in_bytes=None, message_count=None):
        # Mirror every constructor argument onto the instance under the
        # same name; the attribute set matches the QueueDescription fields.
        for attr_name, attr_value in (
                ('lock_duration', lock_duration),
                ('max_size_in_megabytes', max_size_in_megabytes),
                ('requires_duplicate_detection',
                 requires_duplicate_detection),
                ('requires_session', requires_session),
                ('default_message_time_to_live',
                 default_message_time_to_live),
                ('dead_lettering_on_message_expiration',
                 dead_lettering_on_message_expiration),
                ('duplicate_detection_history_time_window',
                 duplicate_detection_history_time_window),
                ('max_delivery_count', max_delivery_count),
                ('enable_batched_operations', enable_batched_operations),
                ('size_in_bytes', size_in_bytes),
                ('message_count', message_count)):
            setattr(self, attr_name, attr_value)


class Topic(WindowsAzureData):

    ''' Topic class corresponding to Topic Description:
    http://msdn.microsoft.com/en-us/library/windowsazure/hh780749. '''

    def __init__(self, default_message_time_to_live=None,
                 max_size_in_megabytes=None,
                 requires_duplicate_detection=None,
                 duplicate_detection_history_time_window=None,
                 enable_batched_operations=None, size_in_bytes=None):
        # Mirror every constructor argument onto the instance under the
        # same name; the attribute set matches the TopicDescription fields.
        for attr_name, attr_value in (
                ('default_message_time_to_live',
                 default_message_time_to_live),
                ('max_size_in_megabytes', max_size_in_megabytes),
                ('requires_duplicate_detection',
                 requires_duplicate_detection),
                ('duplicate_detection_history_time_window',
                 duplicate_detection_history_time_window),
                ('enable_batched_operations', enable_batched_operations),
                ('size_in_bytes', size_in_bytes)):
            setattr(self, attr_name, attr_value)

    @property
    def max_size_in_mega_bytes(self):
        # Deprecated alias kept for backward compatibility.
        import warnings
        warnings.warn(
            'This attribute has been changed to max_size_in_megabytes.')
        return self.max_size_in_megabytes

    @max_size_in_mega_bytes.setter
    def max_size_in_mega_bytes(self, value):
        self.max_size_in_megabytes = value
class Subscription(WindowsAzureData):

    ''' Subscription class corresponding to Subscription Description:
    http://msdn.microsoft.com/en-us/library/windowsazure/hh780763. '''

    def __init__(self, lock_duration=None, requires_session=None,
                 default_message_time_to_live=None,
                 dead_lettering_on_message_expiration=None,
                 dead_lettering_on_filter_evaluation_exceptions=None,
                 enable_batched_operations=None, max_delivery_count=None,
                 message_count=None):
        self.lock_duration = lock_duration
        self.requires_session = requires_session
        self.default_message_time_to_live = default_message_time_to_live
        self.dead_lettering_on_message_expiration = \
            dead_lettering_on_message_expiration
        self.dead_lettering_on_filter_evaluation_exceptions = \
            dead_lettering_on_filter_evaluation_exceptions
        self.enable_batched_operations = enable_batched_operations
        self.max_delivery_count = max_delivery_count
        self.message_count = message_count


class Rule(WindowsAzureData):

    ''' Rule class corresponding to Rule Description:
    http://msdn.microsoft.com/en-us/library/windowsazure/hh780753. '''

    def __init__(self, filter_type=None, filter_expression=None,
                 action_type=None, action_expression=None):
        self.filter_type = filter_type
        self.filter_expression = filter_expression
        self.action_type = action_type
        # Fix: previously assigned action_type here, silently discarding the
        # action_expression argument passed by the caller.
        self.action_expression = action_expression


class Message(WindowsAzureData):

    ''' Message class used in the send message/get message apis. '''

    def __init__(self, body=None, service_bus_service=None, location=None,
                 custom_properties=None,
                 type='application/atom+xml;type=entry;charset=utf-8',
                 broker_properties=None):
        '''
        body: the message payload.
        service_bus_service: the service bus client that produced/sends this
            message; required for delete()/unlock().
        location: URL the message was received from; parsed to recover the
            queue name, or the topic and subscription names.
        custom_properties: dict of user-defined properties sent as headers.
        type: content type header value.
        broker_properties: dict of broker properties (SequenceNumber,
            LockToken, ...) needed for peek-lock operations.
        '''
        self.body = body
        self.location = location
        self.broker_properties = broker_properties
        self.custom_properties = custom_properties
        self.type = type
        self.service_bus_service = service_bus_service
        self._topic_name = None
        self._subscription_name = None
        self._queue_name = None

        if not service_bus_service:
            return

        # if location is set, then extracts the queue name for queue message
        # and extracts the topic and subscriptions name if it is topic
        # message.
        if location:
            if '/subscriptions/' in location:
                pos = location.find('/subscriptions/')
                pos1 = location.rfind('/', 0, pos - 1)
                self._topic_name = location[pos1 + 1:pos]
                pos += len('/subscriptions/')
                pos1 = location.find('/', pos)
                self._subscription_name = location[pos:pos1]
            elif '/messages/' in location:
                pos = location.find('/messages/')
                pos1 = location.rfind('/', 0, pos - 1)
                self._queue_name = location[pos1 + 1:pos]

    def delete(self):
        ''' Deletes itself if find queue name or topic name and subscription
        name. '''
        if self._queue_name:
            self.service_bus_service.delete_queue_message(
                self._queue_name,
                self.broker_properties['SequenceNumber'],
                self.broker_properties['LockToken'])
        elif self._topic_name and self._subscription_name:
            self.service_bus_service.delete_subscription_message(
                self._topic_name,
                self._subscription_name,
                self.broker_properties['SequenceNumber'],
                self.broker_properties['LockToken'])
        else:
            raise WindowsAzureError(_ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_DELETE)

    def unlock(self):
        ''' Unlocks itself if find queue name or topic name and subscription
        name. '''
        if self._queue_name:
            self.service_bus_service.unlock_queue_message(
                self._queue_name,
                self.broker_properties['SequenceNumber'],
                self.broker_properties['LockToken'])
        elif self._topic_name and self._subscription_name:
            self.service_bus_service.unlock_subscription_message(
                self._topic_name,
                self._subscription_name,
                self.broker_properties['SequenceNumber'],
                self.broker_properties['LockToken'])
        else:
            raise WindowsAzureError(_ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_UNLOCK)

    def add_headers(self, request):
        ''' add addtional headers to request for message request.'''

        # Adds custom properties
        if self.custom_properties:
            for name, value in self.custom_properties.items():
                # under python 2, unicode values are utf-8 encoded; the
                # version check short-circuits so `unicode` is never
                # evaluated under python 3.
                if sys.version_info < (3,) and isinstance(value, unicode):
                    request.headers.append(
                        (name, '"' + value.encode('utf-8') + '"'))
                elif isinstance(value, str):
                    request.headers.append((name, '"' + str(value) + '"'))
                elif isinstance(value, datetime):
                    request.headers.append(
                        (name, '"' + value.strftime(
                            '%a, %d %b %Y %H:%M:%S GMT') + '"'))
                else:
                    # only int, float or boolean
                    request.headers.append((name, str(value).lower()))

        # Adds content-type
        request.headers.append(('Content-Type', self.type))

        # Adds BrokerProperties
        if self.broker_properties:
            request.headers.append(
                ('BrokerProperties', str(self.broker_properties)))

        return request.headers
def _create_message(response, service_instance):
    ''' Create message from response.

    response: response from service bus cloud server.
    service_instance: the service bus client.
    '''
    respbody = response.body
    custom_properties = {}
    broker_properties = None
    message_type = None
    message_location = None

    # gets all information from respheaders.
    for name, value in response.headers:
        if name.lower() == 'brokerproperties':
            broker_properties = json.loads(value)
        elif name.lower() == 'content-type':
            message_type = value
        elif name.lower() == 'location':
            message_location = value
        elif name.lower() not in ['content-type',
                                  'brokerproperties',
                                  'transfer-encoding',
                                  'server',
                                  'location',
                                  'date']:
            # Everything else is a user-defined custom property.
            if '"' in value:
                # Quoted values are strings; dates are recognized by their
                # RFC 1123 format, all others kept verbatim.
                value = value[1:-1]
                try:
                    custom_properties[name] = datetime.strptime(
                        value, '%a, %d %b %Y %H:%M:%S GMT')
                except ValueError:
                    custom_properties[name] = value
            else:  # only int, float or boolean
                if value.lower() == 'true':
                    custom_properties[name] = True
                elif value.lower() == 'false':
                    custom_properties[name] = False
                # int('3.1') doesn't work so need to get float('3.14') first
                elif str(int(float(value))) == value:
                    custom_properties[name] = int(value)
                else:
                    custom_properties[name] = float(value)

    # Fix: identity-style `== None` replaced with `is None`; the default
    # content type is applied once instead of duplicating the constructor
    # call.
    if message_type is None:
        message_type = 'application/atom+xml;type=entry;charset=utf-8'
    message = Message(respbody,
                      service_instance,
                      message_location,
                      custom_properties,
                      message_type,
                      broker_properties)
    return message


# convert functions
def _convert_response_to_rule(response):
    return _convert_xml_to_rule(response.body)


def _convert_xml_to_rule(xmlstr):
    ''' Converts response xml to rule object.

    The format of xml for rule:
    <entry xmlns='http://www.w3.org/2005/Atom'>
    <content type='application/xml'>
    <RuleDescription
        xmlns:i="http://www.w3.org/2001/XMLSchema-instance"
        xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">
        <Filter i:type="SqlFilterExpression">
            <SqlExpression>MyProperty='XYZ'</SqlExpression>
        </Filter>
        <Action i:type="SqlFilterAction">
            <SqlExpression>set MyProperty2 = 'ABC'</SqlExpression>
        </Action>
    </RuleDescription>
    </content>
    </entry>
    '''
    xmldoc = minidom.parseString(xmlstr)
    rule = Rule()

    for rule_desc in _get_children_from_path(xmldoc,
                                             'entry',
                                             'content',
                                             'RuleDescription'):
        for xml_filter in _get_child_nodes(rule_desc, 'Filter'):
            filter_type = xml_filter.getAttributeNS(
                XML_SCHEMA_NAMESPACE, 'type')
            setattr(rule, 'filter_type', str(filter_type))
            if xml_filter.childNodes:
                for expr in _get_child_nodes(xml_filter, 'SqlExpression'):
                    setattr(rule, 'filter_expression',
                            expr.firstChild.nodeValue)
        for xml_action in _get_child_nodes(rule_desc, 'Action'):
            action_type = xml_action.getAttributeNS(
                XML_SCHEMA_NAMESPACE, 'type')
            setattr(rule, 'action_type', str(action_type))
            if xml_action.childNodes:
                action_expression = xml_action.childNodes[0].firstChild
                if action_expression:
                    setattr(rule, 'action_expression',
                            action_expression.nodeValue)

    # extract id, updated and name value from feed entry and set them of rule.
    for name, value in _get_entry_properties(xmlstr, True, '/rules').items():
        setattr(rule, name, value)

    return rule


def _convert_response_to_queue(response):
    return _convert_xml_to_queue(response.body)


def _parse_bool(value):
    ''' Service bus xml encodes booleans as 'true'/'false' (any case);
    everything else is treated as False. '''
    if value.lower() == 'true':
        return True
    return False
def _convert_xml_to_queue(xmlstr):
    ''' Converts xml response to queue object.

    The format of xml response for queue:
    <QueueDescription
        xmlns=\"http://schemas.microsoft.com/netservices/2010/10/servicebus/connect\">
        <MaxSizeInBytes>10000</MaxSizeInBytes>
        <DefaultMessageTimeToLive>PT5M</DefaultMessageTimeToLive>
        <LockDuration>PT2M</LockDuration>
        <RequiresGroupedReceives>False</RequiresGroupedReceives>
        <SupportsDuplicateDetection>False</SupportsDuplicateDetection>
        ...
    </QueueDescription>
    '''
    # One row per QueueDescription element: (xml tag, Queue attribute,
    # converter applied to the node text; None keeps the raw string).
    element_map = (
        ('LockDuration', 'lock_duration', None),
        ('MaxSizeInMegabytes', 'max_size_in_megabytes', int),
        ('RequiresDuplicateDetection',
         'requires_duplicate_detection', _parse_bool),
        ('RequiresSession', 'requires_session', _parse_bool),
        ('DefaultMessageTimeToLive', 'default_message_time_to_live', None),
        ('DeadLetteringOnMessageExpiration',
         'dead_lettering_on_message_expiration', _parse_bool),
        ('DuplicateDetectionHistoryTimeWindow',
         'duplicate_detection_history_time_window', None),
        ('EnableBatchedOperations',
         'enable_batched_operations', _parse_bool),
        ('MaxDeliveryCount', 'max_delivery_count', int),
        ('MessageCount', 'message_count', int),
        ('SizeInBytes', 'size_in_bytes', int),
    )

    xmldoc = minidom.parseString(xmlstr)
    queue = Queue()

    # If no known element is found at all, the response was not a valid
    # Queue feed entry.
    found_any = False
    for desc in _get_children_from_path(xmldoc,
                                        'entry',
                                        'content',
                                        'QueueDescription'):
        for tag, attr, convert in element_map:
            raw_value = _get_first_child_node_value(desc, tag)
            if raw_value is not None:
                setattr(queue, attr,
                        convert(raw_value) if convert else raw_value)
                found_any = True

    if not found_any:
        raise WindowsAzureError(_ERROR_QUEUE_NOT_FOUND)

    # extract id, updated and name value from feed entry and set them of
    # queue.
    for name, value in _get_entry_properties(xmlstr, True).items():
        setattr(queue, name, value)

    return queue


def _convert_response_to_topic(response):
    return _convert_xml_to_topic(response.body)
def _convert_xml_to_topic(xmlstr):
    '''Converts xml response to topic

    The xml format for topic:
    <entry xmlns='http://www.w3.org/2005/Atom'>
    <content type='application/xml'>
    <TopicDescription
        xmlns:i="http://www.w3.org/2001/XMLSchema-instance"
        xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">
        <DefaultMessageTimeToLive>P10675199DT2H48M5.4775807S</DefaultMessageTimeToLive>
        <MaxSizeInMegabytes>1024</MaxSizeInMegabytes>
        <RequiresDuplicateDetection>false</RequiresDuplicateDetection>
        <DuplicateDetectionHistoryTimeWindow>P7D</DuplicateDetectionHistoryTimeWindow>
        <DeadLetteringOnFilterEvaluationExceptions>true</DeadLetteringOnFilterEvaluationExceptions>
    </TopicDescription>
    </content>
    </entry>

    Raises WindowsAzureError(_ERROR_TOPIC_NOT_FOUND) when no known
    TopicDescription element is present.
    '''
    xmldoc = minidom.parseString(xmlstr)
    topic = Topic()

    invalid_topic = True
    # get node for each attribute in Topic class, if nothing found then the
    # response is not valid xml for Topic.
    for desc in _get_children_from_path(xmldoc,
                                        'entry',
                                        'content',
                                        'TopicDescription'):
        invalid_topic = True
        node_value = _get_first_child_node_value(
            desc, 'DefaultMessageTimeToLive')
        if node_value is not None:
            topic.default_message_time_to_live = node_value
            invalid_topic = False
        node_value = _get_first_child_node_value(desc, 'MaxSizeInMegabytes')
        if node_value is not None:
            topic.max_size_in_megabytes = int(node_value)
            invalid_topic = False
        node_value = _get_first_child_node_value(
            desc, 'RequiresDuplicateDetection')
        if node_value is not None:
            topic.requires_duplicate_detection = _parse_bool(node_value)
            invalid_topic = False
        node_value = _get_first_child_node_value(
            desc, 'DuplicateDetectionHistoryTimeWindow')
        if node_value is not None:
            topic.duplicate_detection_history_time_window = node_value
            invalid_topic = False
        node_value = _get_first_child_node_value(
            desc, 'EnableBatchedOperations')
        if node_value is not None:
            topic.enable_batched_operations = _parse_bool(node_value)
            invalid_topic = False
        node_value = _get_first_child_node_value(desc, 'SizeInBytes')
        if node_value is not None:
            topic.size_in_bytes = int(node_value)
            invalid_topic = False

    if invalid_topic:
        raise WindowsAzureError(_ERROR_TOPIC_NOT_FOUND)

    # extract id, updated and name value from feed entry and set them of
    # topic.
    for name, value in _get_entry_properties(xmlstr, True).items():
        setattr(topic, name, value)

    return topic


def _convert_response_to_subscription(response):
    return _convert_xml_to_subscription(response.body)


def _convert_xml_to_subscription(xmlstr):
    '''Converts xml response to subscription

    The xml format for subscription:
    <entry xmlns='http://www.w3.org/2005/Atom'>
    <content type='application/xml'>
    <SubscriptionDescription
        xmlns:i="http://www.w3.org/2001/XMLSchema-instance"
        xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">
        <LockDuration>PT5M</LockDuration>
        <RequiresSession>false</RequiresSession>
        <DefaultMessageTimeToLive>P10675199DT2H48M5.4775807S</DefaultMessageTimeToLive>
        <DeadLetteringOnMessageExpiration>false</DeadLetteringOnMessageExpiration>
        <DeadLetteringOnFilterEvaluationExceptions>true</DeadLetteringOnFilterEvaluationExceptions>
    </SubscriptionDescription>
    </content>
    </entry>

    Note: unlike the queue/topic converters, no error is raised when no
    SubscriptionDescription element is found.
    '''
    xmldoc = minidom.parseString(xmlstr)
    subscription = Subscription()

    for desc in _get_children_from_path(xmldoc,
                                        'entry',
                                        'content',
                                        'SubscriptionDescription'):
        node_value = _get_first_child_node_value(desc, 'LockDuration')
        if node_value is not None:
            subscription.lock_duration = node_value

        node_value = _get_first_child_node_value(
            desc, 'RequiresSession')
        if node_value is not None:
            subscription.requires_session = _parse_bool(node_value)

        node_value = _get_first_child_node_value(
            desc, 'DefaultMessageTimeToLive')
        if node_value is not None:
            subscription.default_message_time_to_live = node_value

        node_value = _get_first_child_node_value(
            desc, 'DeadLetteringOnFilterEvaluationExceptions')
        if node_value is not None:
            subscription.dead_lettering_on_filter_evaluation_exceptions = \
                _parse_bool(node_value)

        node_value = _get_first_child_node_value(
            desc, 'DeadLetteringOnMessageExpiration')
        if node_value is not None:
            subscription.dead_lettering_on_message_expiration = \
                _parse_bool(node_value)

        node_value = _get_first_child_node_value(
            desc, 'EnableBatchedOperations')
        if node_value is not None:
            subscription.enable_batched_operations = _parse_bool(node_value)

        node_value = _get_first_child_node_value(
            desc, 'MaxDeliveryCount')
        if node_value is not None:
            subscription.max_delivery_count = int(node_value)

        node_value = _get_first_child_node_value(
            desc, 'MessageCount')
        if node_value is not None:
            subscription.message_count = int(node_value)

    # extract id, updated and name value from feed entry and set them on the
    # subscription.
    for name, value in _get_entry_properties(xmlstr,
                                             True,
                                             '/subscriptions').items():
        setattr(subscription, name, value)

    return subscription


def _convert_subscription_to_xml(subscription):
    '''
    Converts a subscription object to xml to send.  The order of each field
    of subscription in xml is very important so we can't simply call
    convert_class_to_xml.

    subscription: the subscription object to be converted.
    '''
    subscription_body = '<SubscriptionDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">'
    if subscription:
        # Each optional field is emitted only when set; booleans are
        # lowercased to match the service's xml schema.
        if subscription.lock_duration is not None:
            subscription_body += ''.join(
                ['<LockDuration>',
                 str(subscription.lock_duration),
                 '</LockDuration>'])

        if subscription.requires_session is not None:
            subscription_body += ''.join(
                ['<RequiresSession>',
                 str(subscription.requires_session).lower(),
                 '</RequiresSession>'])

        if subscription.default_message_time_to_live is not None:
            subscription_body += ''.join(
                ['<DefaultMessageTimeToLive>',
                 str(subscription.default_message_time_to_live),
                 '</DefaultMessageTimeToLive>'])

        if subscription.dead_lettering_on_message_expiration is not None:
            subscription_body += ''.join(
                ['<DeadLetteringOnMessageExpiration>',
                 str(subscription.dead_lettering_on_message_expiration).lower(),
                 '</DeadLetteringOnMessageExpiration>'])

        if subscription.dead_lettering_on_filter_evaluation_exceptions is not None:
            subscription_body += ''.join(
                ['<DeadLetteringOnFilterEvaluationExceptions>',
                 str(subscription.dead_lettering_on_filter_evaluation_exceptions).lower(),
                 '</DeadLetteringOnFilterEvaluationExceptions>'])

        if subscription.enable_batched_operations is not None:
            subscription_body += ''.join(
                ['<EnableBatchedOperations>',
                 str(subscription.enable_batched_operations).lower(),
                 '</EnableBatchedOperations>'])

        if subscription.max_delivery_count is not None:
            subscription_body += ''.join(
                ['<MaxDeliveryCount>',
                 str(subscription.max_delivery_count),
                 '</MaxDeliveryCount>'])

        if subscription.message_count is not None:
            subscription_body += ''.join(
                ['<MessageCount>',
                 str(subscription.message_count),
                 '</MessageCount>'])

    subscription_body += '</SubscriptionDescription>'
    return _create_entry(subscription_body)


def _convert_rule_to_xml(rule):
    '''
    Converts a rule object to xml to send.  The order of each field of rule
    in xml is very important so we cann't simple call convert_class_to_xml.

    rule: the rule object to be converted.
    '''
    rule_body = '<RuleDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">'
    if rule:
        if rule.filter_type:
            rule_body += ''.join(
                ['<Filter i:type="',
                 xml_escape(rule.filter_type),
                 '">'])
            if rule.filter_type == 'CorrelationFilter':
                rule_body += ''.join(
                    ['<CorrelationId>',
                     xml_escape(rule.filter_expression),
                     '</CorrelationId>'])
            else:
                # SQL-style filters additionally carry a compatibility level.
                rule_body += ''.join(
                    ['<SqlExpression>',
                     xml_escape(rule.filter_expression),
                     '</SqlExpression>'])
                rule_body += '<CompatibilityLevel>20</CompatibilityLevel>'
            rule_body += '</Filter>'
        if rule.action_type:
            rule_body += ''.join(
                ['<Action i:type="',
                 xml_escape(rule.action_type),
                 '">'])
            if rule.action_type == 'SqlRuleAction':
                rule_body += ''.join(
                    ['<SqlExpression>',
                     xml_escape(rule.action_expression),
                     '</SqlExpression>'])
                rule_body += '<CompatibilityLevel>20</CompatibilityLevel>'
            rule_body += '</Action>'
    rule_body += '</RuleDescription>'
    return _create_entry(rule_body)
topic: the topic object to be converted. ''' topic_body = '<TopicDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">' if topic: if topic.default_message_time_to_live is not None: topic_body += ''.join( ['<DefaultMessageTimeToLive>', str(topic.default_message_time_to_live), '</DefaultMessageTimeToLive>']) if topic.max_size_in_megabytes is not None: topic_body += ''.join( ['<MaxSizeInMegabytes>', str(topic.max_size_in_megabytes), '</MaxSizeInMegabytes>']) if topic.requires_duplicate_detection is not None: topic_body += ''.join( ['<RequiresDuplicateDetection>', str(topic.requires_duplicate_detection).lower(), '</RequiresDuplicateDetection>']) if topic.duplicate_detection_history_time_window is not None: topic_body += ''.join( ['<DuplicateDetectionHistoryTimeWindow>', str(topic.duplicate_detection_history_time_window), '</DuplicateDetectionHistoryTimeWindow>']) if topic.enable_batched_operations is not None: topic_body += ''.join( ['<EnableBatchedOperations>', str(topic.enable_batched_operations).lower(), '</EnableBatchedOperations>']) if topic.size_in_bytes is not None: topic_body += ''.join( ['<SizeInBytes>', str(topic.size_in_bytes), '</SizeInBytes>']) topic_body += '</TopicDescription>' return _create_entry(topic_body) def _convert_queue_to_xml(queue): ''' Converts a queue object to xml to send. The order of each field of queue in xml is very important so we cann't simple call convert_class_to_xml. queue: the queue object to be converted. 
''' queue_body = '<QueueDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">' if queue: if queue.lock_duration: queue_body += ''.join( ['<LockDuration>', str(queue.lock_duration), '</LockDuration>']) if queue.max_size_in_megabytes is not None: queue_body += ''.join( ['<MaxSizeInMegabytes>', str(queue.max_size_in_megabytes), '</MaxSizeInMegabytes>']) if queue.requires_duplicate_detection is not None: queue_body += ''.join( ['<RequiresDuplicateDetection>', str(queue.requires_duplicate_detection).lower(), '</RequiresDuplicateDetection>']) if queue.requires_session is not None: queue_body += ''.join( ['<RequiresSession>', str(queue.requires_session).lower(), '</RequiresSession>']) if queue.default_message_time_to_live is not None: queue_body += ''.join( ['<DefaultMessageTimeToLive>', str(queue.default_message_time_to_live), '</DefaultMessageTimeToLive>']) if queue.dead_lettering_on_message_expiration is not None: queue_body += ''.join( ['<DeadLetteringOnMessageExpiration>', str(queue.dead_lettering_on_message_expiration).lower(), '</DeadLetteringOnMessageExpiration>']) if queue.duplicate_detection_history_time_window is not None: queue_body += ''.join( ['<DuplicateDetectionHistoryTimeWindow>', str(queue.duplicate_detection_history_time_window), '</DuplicateDetectionHistoryTimeWindow>']) if queue.max_delivery_count is not None: queue_body += ''.join( ['<MaxDeliveryCount>', str(queue.max_delivery_count), '</MaxDeliveryCount>']) if queue.enable_batched_operations is not None: queue_body += ''.join( ['<EnableBatchedOperations>', str(queue.enable_batched_operations).lower(), '</EnableBatchedOperations>']) if queue.size_in_bytes is not None: queue_body += ''.join( ['<SizeInBytes>', str(queue.size_in_bytes), '</SizeInBytes>']) if queue.message_count is not None: queue_body += ''.join( ['<MessageCount>', str(queue.message_count), '</MessageCount>']) queue_body += '</QueueDescription>' 
return _create_entry(queue_body) def _service_bus_error_handler(http_error): ''' Simple error handler for service bus service. ''' return _general_error_handler(http_error) from azure.servicebus.servicebusservice import ServiceBusService ================================================ FILE: CustomScript/azure/servicebus/servicebusservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#-------------------------------------------------------------------------- import datetime import os import time from azure import ( WindowsAzureError, SERVICE_BUS_HOST_BASE, _convert_response_to_feeds, _dont_fail_not_exist, _dont_fail_on_exist, _encode_base64, _get_request_body, _get_request_body_bytes_only, _int_or_none, _sign_string, _str, _unicode_type, _update_request_uri_query, url_quote, url_unquote, _validate_not_none, ) from azure.http import ( HTTPError, HTTPRequest, ) from azure.http.httpclient import _HTTPClient from azure.servicebus import ( AZURE_SERVICEBUS_NAMESPACE, AZURE_SERVICEBUS_ACCESS_KEY, AZURE_SERVICEBUS_ISSUER, _convert_topic_to_xml, _convert_response_to_topic, _convert_queue_to_xml, _convert_response_to_queue, _convert_subscription_to_xml, _convert_response_to_subscription, _convert_rule_to_xml, _convert_response_to_rule, _convert_xml_to_queue, _convert_xml_to_topic, _convert_xml_to_subscription, _convert_xml_to_rule, _create_message, _service_bus_error_handler, ) class ServiceBusService(object): def __init__(self, service_namespace=None, account_key=None, issuer=None, x_ms_version='2011-06-01', host_base=SERVICE_BUS_HOST_BASE, shared_access_key_name=None, shared_access_key_value=None, authentication=None): ''' Initializes the service bus service for a namespace with the specified authentication settings (SAS or ACS). service_namespace: Service bus namespace, required for all operations. If None, the value is set to the AZURE_SERVICEBUS_NAMESPACE env variable. account_key: ACS authentication account key. If None, the value is set to the AZURE_SERVICEBUS_ACCESS_KEY env variable. Note that if both SAS and ACS settings are specified, SAS is used. issuer: ACS authentication issuer. If None, the value is set to the AZURE_SERVICEBUS_ISSUER env variable. Note that if both SAS and ACS settings are specified, SAS is used. x_ms_version: Unused. Kept for backwards compatibility. host_base: Optional. Live host base url. Defaults to Azure url. 
Override this for on-premise. shared_access_key_name: SAS authentication key name. Note that if both SAS and ACS settings are specified, SAS is used. shared_access_key_value: SAS authentication key value. Note that if both SAS and ACS settings are specified, SAS is used. authentication: Instance of authentication class. If this is specified, then ACS and SAS parameters are ignored. ''' self.requestid = None self.service_namespace = service_namespace self.host_base = host_base if not self.service_namespace: self.service_namespace = os.environ.get(AZURE_SERVICEBUS_NAMESPACE) if not self.service_namespace: raise WindowsAzureError('You need to provide servicebus namespace') if authentication: self.authentication = authentication else: if not account_key: account_key = os.environ.get(AZURE_SERVICEBUS_ACCESS_KEY) if not issuer: issuer = os.environ.get(AZURE_SERVICEBUS_ISSUER) if shared_access_key_name and shared_access_key_value: self.authentication = ServiceBusSASAuthentication( shared_access_key_name, shared_access_key_value) elif account_key and issuer: self.authentication = ServiceBusWrapTokenAuthentication( account_key, issuer) else: raise WindowsAzureError( 'You need to provide servicebus access key and Issuer OR shared access key and value') self._httpclient = _HTTPClient(service_instance=self) self._filter = self._httpclient.perform_request # Backwards compatibility: # account_key and issuer used to be stored on the service class, they are # now stored on the authentication class. @property def account_key(self): return self.authentication.account_key @account_key.setter def account_key(self, value): self.authentication.account_key = value @property def issuer(self): return self.authentication.issuer @issuer.setter def issuer(self, value): self.authentication.issuer = value def with_filter(self, filter): ''' Returns a new service which will process requests with the specified filter. Filtering operations can include logging, automatic retrying, etc... 
The filter is a lambda which receives the HTTPRequest and another lambda. The filter can perform any pre-processing on the request, pass it off to the next lambda, and then perform any post-processing on the response. ''' res = ServiceBusService( service_namespace=self.service_namespace, authentication=self.authentication) old_filter = self._filter def new_filter(request): return filter(request, old_filter) res._filter = new_filter return res def set_proxy(self, host, port, user=None, password=None): ''' Sets the proxy server host and port for the HTTP CONNECT Tunnelling. host: Address of the proxy. Ex: '192.168.0.100' port: Port of the proxy. Ex: 6000 user: User for proxy authorization. password: Password for proxy authorization. ''' self._httpclient.set_proxy(host, port, user, password) def create_queue(self, queue_name, queue=None, fail_on_exist=False): ''' Creates a new queue. Once created, this queue's resource manifest is immutable. queue_name: Name of the queue to create. queue: Queue object to create. fail_on_exist: Specify whether to throw an exception when the queue exists. ''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(queue_name) + '' request.body = _get_request_body(_convert_queue_to_xml(queue)) request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) if not fail_on_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_on_exist(ex) return False else: self._perform_request(request) return True def delete_queue(self, queue_name, fail_not_exist=False): ''' Deletes an existing queue. This operation will also remove all associated state including messages in the queue. queue_name: Name of the queue to delete. fail_not_exist: Specify whether to throw an exception if the queue doesn't exist. 
''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(queue_name) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) if not fail_not_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_not_exist(ex) return False else: self._perform_request(request) return True def get_queue(self, queue_name): ''' Retrieves an existing queue. queue_name: Name of the queue. ''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(queue_name) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _convert_response_to_queue(response) def list_queues(self): ''' Enumerates the queues in the service namespace. ''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/$Resources/Queues' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _convert_response_to_feeds(response, _convert_xml_to_queue) def create_topic(self, topic_name, topic=None, fail_on_exist=False): ''' Creates a new topic. Once created, this topic resource manifest is immutable. topic_name: Name of the topic to create. topic: Topic object to create. fail_on_exist: Specify whether to throw an exception when the topic exists. 
''' _validate_not_none('topic_name', topic_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(topic_name) + '' request.body = _get_request_body(_convert_topic_to_xml(topic)) request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) if not fail_on_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_on_exist(ex) return False else: self._perform_request(request) return True def delete_topic(self, topic_name, fail_not_exist=False): ''' Deletes an existing topic. This operation will also remove all associated state including associated subscriptions. topic_name: Name of the topic to delete. fail_not_exist: Specify whether throw exception when topic doesn't exist. ''' _validate_not_none('topic_name', topic_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(topic_name) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) if not fail_not_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_not_exist(ex) return False else: self._perform_request(request) return True def get_topic(self, topic_name): ''' Retrieves the description for the specified topic. topic_name: Name of the topic. ''' _validate_not_none('topic_name', topic_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(topic_name) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _convert_response_to_topic(response) def list_topics(self): ''' Retrieves the topics in the service namespace. 
''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/$Resources/Topics' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _convert_response_to_feeds(response, _convert_xml_to_topic) def create_rule(self, topic_name, subscription_name, rule_name, rule=None, fail_on_exist=False): ''' Creates a new rule. Once created, this rule's resource manifest is immutable. topic_name: Name of the topic. subscription_name: Name of the subscription. rule_name: Name of the rule. fail_on_exist: Specify whether to throw an exception when the rule exists. ''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) _validate_not_none('rule_name', rule_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(topic_name) + '/subscriptions/' + \ _str(subscription_name) + \ '/rules/' + _str(rule_name) + '' request.body = _get_request_body(_convert_rule_to_xml(rule)) request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) if not fail_on_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_on_exist(ex) return False else: self._perform_request(request) return True def delete_rule(self, topic_name, subscription_name, rule_name, fail_not_exist=False): ''' Deletes an existing rule. topic_name: Name of the topic. subscription_name: Name of the subscription. rule_name: Name of the rule to delete. DEFAULT_RULE_NAME=$Default. Use DEFAULT_RULE_NAME to delete default rule for the subscription. fail_not_exist: Specify whether throw exception when rule doesn't exist. 
''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) _validate_not_none('rule_name', rule_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(topic_name) + '/subscriptions/' + \ _str(subscription_name) + \ '/rules/' + _str(rule_name) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) if not fail_not_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_not_exist(ex) return False else: self._perform_request(request) return True def get_rule(self, topic_name, subscription_name, rule_name): ''' Retrieves the description for the specified rule. topic_name: Name of the topic. subscription_name: Name of the subscription. rule_name: Name of the rule. ''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) _validate_not_none('rule_name', rule_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(topic_name) + '/subscriptions/' + \ _str(subscription_name) + \ '/rules/' + _str(rule_name) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _convert_response_to_rule(response) def list_rules(self, topic_name, subscription_name): ''' Retrieves the rules that exist under the specified subscription. topic_name: Name of the topic. subscription_name: Name of the subscription. 
''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + \ _str(topic_name) + '/subscriptions/' + \ _str(subscription_name) + '/rules/' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _convert_response_to_feeds(response, _convert_xml_to_rule) def create_subscription(self, topic_name, subscription_name, subscription=None, fail_on_exist=False): ''' Creates a new subscription. Once created, this subscription resource manifest is immutable. topic_name: Name of the topic. subscription_name: Name of the subscription. fail_on_exist: Specify whether throw exception when subscription exists. ''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(topic_name) + '/subscriptions/' + _str(subscription_name) + '' request.body = _get_request_body( _convert_subscription_to_xml(subscription)) request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) if not fail_on_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_on_exist(ex) return False else: self._perform_request(request) return True def delete_subscription(self, topic_name, subscription_name, fail_not_exist=False): ''' Deletes an existing subscription. topic_name: Name of the topic. subscription_name: Name of the subscription to delete. fail_not_exist: Specify whether to throw an exception when the subscription doesn't exist. 
''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + \ _str(topic_name) + '/subscriptions/' + _str(subscription_name) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) if not fail_not_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_not_exist(ex) return False else: self._perform_request(request) return True def get_subscription(self, topic_name, subscription_name): ''' Gets an existing subscription. topic_name: Name of the topic. subscription_name: Name of the subscription. ''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + \ _str(topic_name) + '/subscriptions/' + _str(subscription_name) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _convert_response_to_subscription(response) def list_subscriptions(self, topic_name): ''' Retrieves the subscriptions in the specified topic. topic_name: Name of the topic. ''' _validate_not_none('topic_name', topic_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(topic_name) + '/subscriptions/' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _convert_response_to_feeds(response, _convert_xml_to_subscription) def send_topic_message(self, topic_name, message=None): ''' Enqueues a message into the specified topic. 
The limit to the number of messages which may be present in the topic is governed by the message size in MaxTopicSizeInBytes. If this message causes the topic to exceed its quota, a quota exceeded error is returned and the message will be rejected. topic_name: Name of the topic. message: Message object containing message body and properties. ''' _validate_not_none('topic_name', topic_name) _validate_not_none('message', message) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/' + _str(topic_name) + '/messages' request.headers = message.add_headers(request) request.body = _get_request_body_bytes_only( 'message.body', message.body) request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) self._perform_request(request) def peek_lock_subscription_message(self, topic_name, subscription_name, timeout='60'): ''' This operation is used to atomically retrieve and lock a message for processing. The message is guaranteed not to be delivered to other receivers during the lock duration period specified in buffer description. Once the lock expires, the message will be available to other receivers (on the same subscription only) during the lock duration period specified in the topic description. Once the lock expires, the message will be available to other receivers. In order to complete processing of the message, the receiver should issue a delete command with the lock ID received from this operation. To abandon processing of the message and unlock it for other receivers, an Unlock Message command should be issued, or the lock duration period can expire. topic_name: Name of the topic. subscription_name: Name of the subscription. timeout: Optional. The timeout parameter is expressed in seconds. 
''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/' + \ _str(topic_name) + '/subscriptions/' + \ _str(subscription_name) + '/messages/head' request.query = [('timeout', _int_or_none(timeout))] request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _create_message(response, self) def unlock_subscription_message(self, topic_name, subscription_name, sequence_number, lock_token): ''' Unlock a message for processing by other receivers on a given subscription. This operation deletes the lock object, causing the message to be unlocked. A message must have first been locked by a receiver before this operation is called. topic_name: Name of the topic. subscription_name: Name of the subscription. sequence_number: The sequence number of the message to be unlocked as returned in BrokerProperties['SequenceNumber'] by the Peek Message operation. lock_token: The ID of the lock as returned by the Peek Message operation in BrokerProperties['LockToken'] ''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) _validate_not_none('sequence_number', sequence_number) _validate_not_none('lock_token', lock_token) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(topic_name) + \ '/subscriptions/' + str(subscription_name) + \ '/messages/' + _str(sequence_number) + \ '/' + _str(lock_token) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) self._perform_request(request) def read_delete_subscription_message(self, topic_name, subscription_name, timeout='60'): ''' Read and delete a message from a subscription as an atomic operation. 
This operation should be used when a best-effort guarantee is sufficient for an application; that is, using this operation it is possible for messages to be lost if processing fails. topic_name: Name of the topic. subscription_name: Name of the subscription. timeout: Optional. The timeout parameter is expressed in seconds. ''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(topic_name) + \ '/subscriptions/' + _str(subscription_name) + \ '/messages/head' request.query = [('timeout', _int_or_none(timeout))] request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _create_message(response, self) def delete_subscription_message(self, topic_name, subscription_name, sequence_number, lock_token): ''' Completes processing on a locked message and delete it from the subscription. This operation should only be called after processing a previously locked message is successful to maintain At-Least-Once delivery assurances. topic_name: Name of the topic. subscription_name: Name of the subscription. sequence_number: The sequence number of the message to be deleted as returned in BrokerProperties['SequenceNumber'] by the Peek Message operation. 
lock_token: The ID of the lock as returned by the Peek Message operation in BrokerProperties['LockToken'] ''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) _validate_not_none('sequence_number', sequence_number) _validate_not_none('lock_token', lock_token) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(topic_name) + \ '/subscriptions/' + _str(subscription_name) + \ '/messages/' + _str(sequence_number) + \ '/' + _str(lock_token) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) self._perform_request(request) def send_queue_message(self, queue_name, message=None): ''' Sends a message into the specified queue. The limit to the number of messages which may be present in the topic is governed by the message size the MaxTopicSizeInMegaBytes. If this message will cause the queue to exceed its quota, a quota exceeded error is returned and the message will be rejected. queue_name: Name of the queue. message: Message object containing message body and properties. ''' _validate_not_none('queue_name', queue_name) _validate_not_none('message', message) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/' + _str(queue_name) + '/messages' request.headers = message.add_headers(request) request.body = _get_request_body_bytes_only('message.body', message.body) request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) self._perform_request(request) def peek_lock_queue_message(self, queue_name, timeout='60'): ''' Automically retrieves and locks a message from a queue for processing. The message is guaranteed not to be delivered to other receivers (on the same subscription only) during the lock duration period specified in the queue description. 
Once the lock expires, the message will be available to other receivers. In order to complete processing of the message, the receiver should issue a delete command with the lock ID received from this operation. To abandon processing of the message and unlock it for other receivers, an Unlock Message command should be issued, or the lock duration period can expire. queue_name: Name of the queue. timeout: Optional. The timeout parameter is expressed in seconds. ''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/' + _str(queue_name) + '/messages/head' request.query = [('timeout', _int_or_none(timeout))] request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _create_message(response, self) def unlock_queue_message(self, queue_name, sequence_number, lock_token): ''' Unlocks a message for processing by other receivers on a given subscription. This operation deletes the lock object, causing the message to be unlocked. A message must have first been locked by a receiver before this operation is called. queue_name: Name of the queue. sequence_number: The sequence number of the message to be unlocked as returned in BrokerProperties['SequenceNumber'] by the Peek Message operation. 
lock_token: The ID of the lock as returned by the Peek Message operation in BrokerProperties['LockToken'] ''' _validate_not_none('queue_name', queue_name) _validate_not_none('sequence_number', sequence_number) _validate_not_none('lock_token', lock_token) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(queue_name) + \ '/messages/' + _str(sequence_number) + \ '/' + _str(lock_token) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) self._perform_request(request) def read_delete_queue_message(self, queue_name, timeout='60'): ''' Reads and deletes a message from a queue as an atomic operation. This operation should be used when a best-effort guarantee is sufficient for an application; that is, using this operation it is possible for messages to be lost if processing fails. queue_name: Name of the queue. timeout: Optional. The timeout parameter is expressed in seconds. ''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(queue_name) + '/messages/head' request.query = [('timeout', _int_or_none(timeout))] request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _create_message(response, self) def delete_queue_message(self, queue_name, sequence_number, lock_token): ''' Completes processing on a locked message and delete it from the queue. This operation should only be called after processing a previously locked message is successful to maintain At-Least-Once delivery assurances. queue_name: Name of the queue. sequence_number: The sequence number of the message to be deleted as returned in BrokerProperties['SequenceNumber'] by the Peek Message operation. 
lock_token: The ID of the lock as returned by the Peek Message operation in BrokerProperties['LockToken'] ''' _validate_not_none('queue_name', queue_name) _validate_not_none('sequence_number', sequence_number) _validate_not_none('lock_token', lock_token) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(queue_name) + \ '/messages/' + _str(sequence_number) + \ '/' + _str(lock_token) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) self._perform_request(request) def receive_queue_message(self, queue_name, peek_lock=True, timeout=60): ''' Receive a message from a queue for processing. queue_name: Name of the queue. peek_lock: Optional. True to retrieve and lock the message. False to read and delete the message. Default is True (lock). timeout: Optional. The timeout parameter is expressed in seconds. ''' if peek_lock: return self.peek_lock_queue_message(queue_name, timeout) else: return self.read_delete_queue_message(queue_name, timeout) def receive_subscription_message(self, topic_name, subscription_name, peek_lock=True, timeout=60): ''' Receive a message from a subscription for processing. topic_name: Name of the topic. subscription_name: Name of the subscription. peek_lock: Optional. True to retrieve and lock the message. False to read and delete the message. Default is True (lock). timeout: Optional. The timeout parameter is expressed in seconds. 
''' if peek_lock: return self.peek_lock_subscription_message(topic_name, subscription_name, timeout) else: return self.read_delete_subscription_message(topic_name, subscription_name, timeout) def _get_host(self): return self.service_namespace + self.host_base def _perform_request(self, request): try: resp = self._filter(request) except HTTPError as ex: return _service_bus_error_handler(ex) return resp def _update_service_bus_header(self, request): ''' Add additional headers for service bus. ''' if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']: request.headers.append(('Content-Length', str(len(request.body)))) # if it is not GET or HEAD request, must set content-type. if not request.method in ['GET', 'HEAD']: for name, _ in request.headers: if 'content-type' == name.lower(): break else: request.headers.append( ('Content-Type', 'application/atom+xml;type=entry;charset=utf-8')) # Adds authorization header for authentication. self.authentication.sign_request(request, self._httpclient) return request.headers # Token cache for Authentication # Shared by the different instances of ServiceBusWrapTokenAuthentication _tokens = {} class ServiceBusWrapTokenAuthentication: def __init__(self, account_key, issuer): self.account_key = account_key self.issuer = issuer def sign_request(self, request, httpclient): request.headers.append( ('Authorization', self._get_authorization(request, httpclient))) def _get_authorization(self, request, httpclient): ''' return the signed string with token. ''' return 'WRAP access_token="' + \ self._get_token(request.host, request.path, httpclient) + '"' def _token_is_expired(self, token): ''' Check if token expires or not. ''' time_pos_begin = token.find('ExpiresOn=') + len('ExpiresOn=') time_pos_end = token.find('&', time_pos_begin) token_expire_time = int(token[time_pos_begin:time_pos_end]) time_now = time.mktime(time.localtime()) # Adding 30 seconds so the token wouldn't be expired when we send the # token to server. 
return (token_expire_time - time_now) < 30 def _get_token(self, host, path, httpclient): ''' Returns token for the request. host: the service bus service request. path: the service bus service request. ''' wrap_scope = 'http://' + host + path + self.issuer + self.account_key # Check whether has unexpired cache, return cached token if it is still # usable. if wrap_scope in _tokens: token = _tokens[wrap_scope] if not self._token_is_expired(token): return token # get token from accessconstrol server request = HTTPRequest() request.protocol_override = 'https' request.host = host.replace('.servicebus.', '-sb.accesscontrol.') request.method = 'POST' request.path = '/WRAPv0.9' request.body = ('wrap_name=' + url_quote(self.issuer) + '&wrap_password=' + url_quote(self.account_key) + '&wrap_scope=' + url_quote('http://' + host + path)).encode('utf-8') request.headers.append(('Content-Length', str(len(request.body)))) resp = httpclient.perform_request(request) token = resp.body.decode('utf-8') token = url_unquote(token[token.find('=') + 1:token.rfind('&')]) _tokens[wrap_scope] = token return token class ServiceBusSASAuthentication: def __init__(self, key_name, key_value): self.key_name = key_name self.key_value = key_value def sign_request(self, request, httpclient): request.headers.append( ('Authorization', self._get_authorization(request, httpclient))) def _get_authorization(self, request, httpclient): uri = httpclient.get_uri(request) uri = url_quote(uri, '').lower() expiry = str(self._get_expiry()) to_sign = uri + '\n' + expiry signature = url_quote(_sign_string(self.key_value, to_sign, False), '') auth_format = 'SharedAccessSignature sig={0}&se={1}&skn={2}&sr={3}' auth = auth_format.format(signature, expiry, self.key_name, uri) return auth def _get_expiry(self): '''Returns the UTC datetime, in seconds since Epoch, when this signed request expires (5 minutes from now).''' return int(round(time.time() + 300)) ================================================ FILE: 
CustomScript/azure/servicemanagement/__init__.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- from xml.dom import minidom from azure import ( WindowsAzureData, _Base64String, _create_entry, _dict_of, _encode_base64, _general_error_handler, _get_children_from_path, _get_first_child_node_value, _list_of, _scalar_list_of, _str, _xml_attribute, ) #----------------------------------------------------------------------------- # Constants for Azure app environment settings. AZURE_MANAGEMENT_CERTFILE = 'AZURE_MANAGEMENT_CERTFILE' AZURE_MANAGEMENT_SUBSCRIPTIONID = 'AZURE_MANAGEMENT_SUBSCRIPTIONID' # x-ms-version for service management. 
X_MS_VERSION = '2013-06-01'

#-----------------------------------------------------------------------------
# Data classes
#
# The plural classes below (StorageServices, Locations, ...) are iterable
# wrappers filled in by the XML deserializer; they forward iteration, len()
# and indexing to their inner _list_of(...) container.


class StorageServices(WindowsAzureData):
    # Collection of StorageService entries.
    def __init__(self):
        self.storage_services = _list_of(StorageService)

    def __iter__(self):
        return iter(self.storage_services)

    def __len__(self):
        return len(self.storage_services)

    def __getitem__(self, index):
        return self.storage_services[index]


class StorageService(WindowsAzureData):
    # A single storage account as returned by the management API.
    def __init__(self):
        self.url = ''
        self.service_name = ''
        self.storage_service_properties = StorageAccountProperties()
        self.storage_service_keys = StorageServiceKeys()
        self.extended_properties = _dict_of(
            'ExtendedProperty', 'Name', 'Value')
        self.capabilities = _scalar_list_of(str, 'Capability')


class StorageAccountProperties(WindowsAzureData):
    # Storage account details: location, status, geo-replication state.
    def __init__(self):
        self.description = u''
        self.affinity_group = u''
        self.location = u''
        self.label = _Base64String()
        self.status = u''
        self.endpoints = _scalar_list_of(str, 'Endpoint')
        self.geo_replication_enabled = False
        self.geo_primary_region = u''
        self.status_of_primary = u''
        self.geo_secondary_region = u''
        self.status_of_secondary = u''
        self.last_geo_failover_time = u''
        self.creation_time = u''


class StorageServiceKeys(WindowsAzureData):
    # Primary/secondary access keys of a storage account.
    def __init__(self):
        self.primary = u''
        self.secondary = u''


class Locations(WindowsAzureData):
    def __init__(self):
        self.locations = _list_of(Location)

    def __iter__(self):
        return iter(self.locations)

    def __len__(self):
        return len(self.locations)

    def __getitem__(self, index):
        return self.locations[index]


class Location(WindowsAzureData):
    # A datacenter location and the services available in it.
    def __init__(self):
        self.name = u''
        self.display_name = u''
        self.available_services = _scalar_list_of(str, 'AvailableService')


class AffinityGroup(WindowsAzureData):
    def __init__(self):
        self.name = ''
        self.label = _Base64String()
        self.description = u''
        self.location = u''
        self.hosted_services = HostedServices()
        self.storage_services = StorageServices()
        self.capabilities = _scalar_list_of(str, 'Capability')


class AffinityGroups(WindowsAzureData):
    def __init__(self):
        self.affinity_groups = _list_of(AffinityGroup)

    def __iter__(self):
        return iter(self.affinity_groups)

    def __len__(self):
        return len(self.affinity_groups)

    def __getitem__(self, index):
        return self.affinity_groups[index]


class HostedServices(WindowsAzureData):
    def __init__(self):
        self.hosted_services = _list_of(HostedService)

    def __iter__(self):
        return iter(self.hosted_services)

    def __len__(self):
        return len(self.hosted_services)

    def __getitem__(self, index):
        return self.hosted_services[index]


class HostedService(WindowsAzureData):
    # A cloud service with its properties and deployments.
    def __init__(self):
        self.url = u''
        self.service_name = u''
        self.hosted_service_properties = HostedServiceProperties()
        self.deployments = Deployments()


class HostedServiceProperties(WindowsAzureData):
    def __init__(self):
        self.description = u''
        self.location = u''
        self.affinity_group = u''
        self.label = _Base64String()
        self.status = u''
        self.date_created = u''
        self.date_last_modified = u''
        self.extended_properties = _dict_of(
            'ExtendedProperty', 'Name', 'Value')


class VirtualNetworkSites(WindowsAzureData):
    def __init__(self):
        self.virtual_network_sites = _list_of(VirtualNetworkSite)

    def __iter__(self):
        return iter(self.virtual_network_sites)

    def __len__(self):
        return len(self.virtual_network_sites)

    def __getitem__(self, index):
        return self.virtual_network_sites[index]


class VirtualNetworkSite(WindowsAzureData):
    def __init__(self):
        self.name = u''
        self.id = u''
        self.affinity_group = u''
        self.subnets = Subnets()


class Subnets(WindowsAzureData):
    def __init__(self):
        self.subnets = _list_of(Subnet)

    def __iter__(self):
        return iter(self.subnets)

    def __len__(self):
        return len(self.subnets)

    def __getitem__(self, index):
        return self.subnets[index]


class Subnet(WindowsAzureData):
    def __init__(self):
        self.name = u''
        self.address_prefix = u''


class Deployments(WindowsAzureData):
    def __init__(self):
        self.deployments = _list_of(Deployment)

    def __iter__(self):
        return iter(self.deployments)

    def __len__(self):
        return len(self.deployments)

    def __getitem__(self, index):
        return self.deployments[index]


class Deployment(WindowsAzureData):
    # A deployment within a cloud service, including its role instances.
    def __init__(self):
        self.name = u''
        self.deployment_slot = u''
        self.private_id = u''
        self.status = u''
        self.label = _Base64String()
        self.url = u''
        self.configuration = _Base64String()
        self.role_instance_list = RoleInstanceList()
        self.upgrade_status = UpgradeStatus()
        self.upgrade_domain_count = u''
        self.role_list = RoleList()
        self.sdk_version = u''
        self.input_endpoint_list = InputEndpoints()
        self.locked = False
        self.rollback_allowed = False
        self.persistent_vm_downtime_info = PersistentVMDowntimeInfo()
        self.created_time = u''
        self.virtual_network_name = u''
        self.last_modified_time = u''
        self.extended_properties = _dict_of(
            'ExtendedProperty', 'Name', 'Value')


class RoleInstanceList(WindowsAzureData):
    def __init__(self):
        self.role_instances = _list_of(RoleInstance)

    def __iter__(self):
        return iter(self.role_instances)

    def __len__(self):
        return len(self.role_instances)

    def __getitem__(self, index):
        return self.role_instances[index]


class RoleInstance(WindowsAzureData):
    # Runtime state of one role instance (VM) inside a deployment.
    def __init__(self):
        self.role_name = u''
        self.instance_name = u''
        self.instance_status = u''
        self.instance_upgrade_domain = 0
        self.instance_fault_domain = 0
        self.instance_size = u''
        self.instance_state_details = u''
        self.instance_error_code = u''
        self.ip_address = u''
        self.instance_endpoints = InstanceEndpoints()
        self.power_state = u''
        self.fqdn = u''
        self.host_name = u''


class InstanceEndpoints(WindowsAzureData):
    def __init__(self):
        self.instance_endpoints = _list_of(InstanceEndpoint)

    def __iter__(self):
        return iter(self.instance_endpoints)

    def __len__(self):
        return len(self.instance_endpoints)

    def __getitem__(self, index):
        return self.instance_endpoints[index]


class InstanceEndpoint(WindowsAzureData):
    def __init__(self):
        self.name = u''
        self.vip = u''
        self.public_port = u''
        self.local_port = u''
        self.protocol = u''


class UpgradeStatus(WindowsAzureData):
    def __init__(self):
        self.upgrade_type = u''
        self.current_upgrade_domain_state = u''
        self.current_upgrade_domain = u''


class InputEndpoints(WindowsAzureData):
    def __init__(self):
        self.input_endpoints = _list_of(InputEndpoint)

    def __iter__(self):
        return iter(self.input_endpoints)

    def __len__(self):
        return len(self.input_endpoints)

    def __getitem__(self, index):
        return self.input_endpoints[index]


class InputEndpoint(WindowsAzureData):
    def __init__(self):
        self.role_name = u''
        self.vip = u''
        self.port = u''


class RoleList(WindowsAzureData):
    def __init__(self):
        self.roles = _list_of(Role)

    def __iter__(self):
        return iter(self.roles)

    def __len__(self):
        return len(self.roles)

    def __getitem__(self, index):
        return self.roles[index]


class Role(WindowsAzureData):
    # Definition of a role (VM) inside a deployment.
    def __init__(self):
        self.role_name = u''
        self.role_type = u''
        self.os_version = u''
        self.configuration_sets = ConfigurationSets()
        self.availability_set_name = u''
        self.data_virtual_hard_disks = DataVirtualHardDisks()
        self.os_virtual_hard_disk = OSVirtualHardDisk()
        self.role_size = u''
        self.default_win_rm_certificate_thumbprint = u''


class PersistentVMDowntimeInfo(WindowsAzureData):
    def __init__(self):
        self.start_time = u''
        self.end_time = u''
        self.status = u''


class Certificates(WindowsAzureData):
    def __init__(self):
        self.certificates = _list_of(Certificate)

    def __iter__(self):
        return iter(self.certificates)

    def __len__(self):
        return len(self.certificates)

    def __getitem__(self, index):
        return self.certificates[index]


class Certificate(WindowsAzureData):
    def __init__(self):
        self.certificate_url = u''
        self.thumbprint = u''
        self.thumbprint_algorithm = u''
        self.data = u''


class OperationError(WindowsAzureData):
    def __init__(self):
        self.code = u''
        self.message = u''


class Operation(WindowsAzureData):
    # Status of an asynchronous management operation.
    def __init__(self):
        self.id = u''
        self.status = u''
        self.http_status_code = u''
        self.error = OperationError()


class OperatingSystem(WindowsAzureData):
    def __init__(self):
        self.version = u''
        self.label = _Base64String()
        self.is_default = True
        self.is_active = True
        self.family = 0
        self.family_label = _Base64String()


class OperatingSystems(WindowsAzureData):
    # Iterable collection of OperatingSystem entries.
    def __init__(self):
        self.operating_systems = _list_of(OperatingSystem)

    def __iter__(self):
        return iter(self.operating_systems)

    def __len__(self):
        return len(self.operating_systems)

    def __getitem__(self, index):
        return self.operating_systems[index]


class OperatingSystemFamily(WindowsAzureData):
    def __init__(self):
        self.name = u''
        self.label = _Base64String()
        self.operating_systems = OperatingSystems()


class OperatingSystemFamilies(WindowsAzureData):
    def __init__(self):
        self.operating_system_families = _list_of(OperatingSystemFamily)

    def __iter__(self):
        return iter(self.operating_system_families)

    def __len__(self):
        return len(self.operating_system_families)

    def __getitem__(self, index):
        return self.operating_system_families[index]


class Subscription(WindowsAzureData):
    # Subscription details with quota limits and current usage counters.
    def __init__(self):
        self.subscription_id = u''
        self.subscription_name = u''
        self.subscription_status = u''
        self.account_admin_live_email_id = u''
        self.service_admin_live_email_id = u''
        self.max_core_count = 0
        self.max_storage_accounts = 0
        self.max_hosted_services = 0
        self.current_core_count = 0
        self.current_hosted_services = 0
        self.current_storage_accounts = 0
        self.max_virtual_network_sites = 0
        self.max_local_network_sites = 0
        self.max_dns_servers = 0


class AvailabilityResponse(WindowsAzureData):
    # Result of a name-availability check.
    def __init__(self):
        self.result = False


class SubscriptionCertificates(WindowsAzureData):
    def __init__(self):
        self.subscription_certificates = _list_of(SubscriptionCertificate)

    def __iter__(self):
        return iter(self.subscription_certificates)

    def __len__(self):
        return len(self.subscription_certificates)

    def __getitem__(self, index):
        return self.subscription_certificates[index]


class SubscriptionCertificate(WindowsAzureData):
    # A management certificate attached to the subscription.
    def __init__(self):
        self.subscription_certificate_public_key = u''
        self.subscription_certificate_thumbprint = u''
        self.subscription_certificate_data = u''
        self.created = u''


class Images(WindowsAzureData):
    # Iterable collection of OSImage entries.
    def __init__(self):
        self.images = _list_of(OSImage)

    def __iter__(self):
        return iter(self.images)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.images[index]


class OSImage(WindowsAzureData):
    # A platform or user OS image available for provisioning.
    def __init__(self):
        self.affinity_group = u''
        self.category = u''
        self.location = u''
        self.logical_size_in_gb = 0
        self.label = u''
        self.media_link = u''
        self.name = u''
        self.os = u''
        self.eula = u''
        self.description = u''


class Disks(WindowsAzureData):
    def __init__(self):
        self.disks = _list_of(Disk)

    def __iter__(self):
        return iter(self.disks)

    def __len__(self):
        return len(self.disks)

    def __getitem__(self, index):
        return self.disks[index]


class Disk(WindowsAzureData):
    # A data/OS disk registered in the subscription's disk repository.
    def __init__(self):
        self.affinity_group = u''
        self.attached_to = AttachedTo()
        self.has_operating_system = u''
        self.is_corrupted = u''
        self.location = u''
        self.logical_disk_size_in_gb = 0
        self.label = u''
        self.media_link = u''
        self.name = u''
        self.os = u''
        self.source_image_name = u''


class AttachedTo(WindowsAzureData):
    # Identifies the service/deployment/role a disk is attached to.
    def __init__(self):
        self.hosted_service_name = u''
        self.deployment_name = u''
        self.role_name = u''


class PersistentVMRole(WindowsAzureData):
    def __init__(self):
        self.role_name = u''
        self.role_type = u''
        self.os_version = u''  # undocumented
        self.configuration_sets = ConfigurationSets()
        self.availability_set_name = u''
        self.data_virtual_hard_disks = DataVirtualHardDisks()
        self.os_virtual_hard_disk = OSVirtualHardDisk()
        self.role_size = u''
        self.default_win_rm_certificate_thumbprint = u''


class ConfigurationSets(WindowsAzureData):
    def __init__(self):
        self.configuration_sets = _list_of(ConfigurationSet)

    def __iter__(self):
        return iter(self.configuration_sets)

    def __len__(self):
        return len(self.configuration_sets)

    def __getitem__(self, index):
        return self.configuration_sets[index]


class ConfigurationSet(WindowsAzureData):
    def __init__(self):
        self.configuration_set_type = u'NetworkConfiguration'
        self.role_type = u''
        self.input_endpoints = ConfigurationSetInputEndpoints()
        self.subnet_names = _scalar_list_of(str, 'SubnetName')


class ConfigurationSetInputEndpoints(WindowsAzureData):
    # Iterable collection of ConfigurationSetInputEndpoint entries.
    def __init__(self):
        self.input_endpoints = _list_of(
            ConfigurationSetInputEndpoint, 'InputEndpoint')

    def __iter__(self):
        return iter(self.input_endpoints)

    def __len__(self):
        return len(self.input_endpoints)

    def __getitem__(self, index):
        return self.input_endpoints[index]


class ConfigurationSetInputEndpoint(WindowsAzureData):

    '''
    Initializes a network configuration input endpoint.

    name: Specifies the name for the external endpoint.
    protocol:
        Specifies the protocol to use to inspect the virtual machine
        availability status. Possible values are: HTTP, TCP.
    port: Specifies the external port to use for the endpoint.
    local_port:
        Specifies the internal port on which the virtual machine is
        listening to serve the endpoint.
    load_balanced_endpoint_set_name:
        Specifies a name for a set of load-balanced endpoints. Specifying
        this element for a given endpoint adds it to the set. If you are
        setting an endpoint to use to connect to the virtual machine via
        the Remote Desktop, do not set this property.
    enable_direct_server_return:
        Specifies whether direct server return load balancing is enabled.
    '''

    def __init__(self, name=u'', protocol=u'', port=u'', local_port=u'',
                 load_balanced_endpoint_set_name=u'',
                 enable_direct_server_return=False):
        self.enable_direct_server_return = enable_direct_server_return
        self.load_balanced_endpoint_set_name = load_balanced_endpoint_set_name
        self.local_port = local_port
        self.name = name
        self.port = port
        self.load_balancer_probe = LoadBalancerProbe()
        self.protocol = protocol


class WindowsConfigurationSet(WindowsAzureData):
    # Provisioning configuration for a Windows role instance.
    def __init__(self, computer_name=None, admin_password=None,
                 reset_password_on_first_logon=None,
                 enable_automatic_updates=None, time_zone=None,
                 admin_username=None):
        self.configuration_set_type = u'WindowsProvisioningConfiguration'
        self.computer_name = computer_name
        self.admin_password = admin_password
        self.admin_username = admin_username
        self.reset_password_on_first_logon = reset_password_on_first_logon
        self.enable_automatic_updates = enable_automatic_updates
        self.time_zone = time_zone
        self.domain_join = DomainJoin()
        self.stored_certificate_settings = StoredCertificateSettings()
        self.win_rm = WinRM()


class DomainJoin(WindowsAzureData):
    def __init__(self):
        self.credentials = Credentials()
        self.join_domain = u''
        self.machine_object_ou = u''


class Credentials(WindowsAzureData):
    def __init__(self):
        self.domain = u''
        self.username = u''
        self.password = u''


class StoredCertificateSettings(WindowsAzureData):
    def __init__(self):
        self.stored_certificate_settings = _list_of(CertificateSetting)

    def __iter__(self):
        return iter(self.stored_certificate_settings)

    def __len__(self):
        return len(self.stored_certificate_settings)

    def __getitem__(self, index):
        return self.stored_certificate_settings[index]


class CertificateSetting(WindowsAzureData):

    '''
    Initializes a certificate setting.

    thumbprint:
        Specifies the thumbprint of the certificate to be provisioned. The
        thumbprint must specify an existing service certificate.
    store_name:
        Specifies the name of the certificate store from which retrieve
        certificate.
    store_location:
        Specifies the target certificate store location on the virtual
        machine. The only supported value is LocalMachine.
    '''

    def __init__(self, thumbprint=u'', store_name=u'', store_location=u''):
        self.thumbprint = thumbprint
        self.store_name = store_name
        self.store_location = store_location


class WinRM(WindowsAzureData):

    '''
    Contains configuration settings for the Windows Remote Management
    service on the Virtual Machine.
    '''

    def __init__(self):
        self.listeners = Listeners()


class Listeners(WindowsAzureData):
    def __init__(self):
        self.listeners = _list_of(Listener)

    def __iter__(self):
        return iter(self.listeners)

    def __len__(self):
        return len(self.listeners)

    def __getitem__(self, index):
        return self.listeners[index]


class Listener(WindowsAzureData):

    '''
    Specifies the protocol and certificate information for the listener.

    protocol:
        Specifies the protocol of listener. Possible values are: Http,
        Https. The value is case sensitive.
    certificate_thumbprint:
        Optional. Specifies the certificate thumbprint for the secure
        connection. If this value is not specified, a self-signed
        certificate is generated and used for the Virtual Machine.
    '''

    def __init__(self, protocol=u'', certificate_thumbprint=u''):
        self.protocol = protocol
        self.certificate_thumbprint = certificate_thumbprint


class LinuxConfigurationSet(WindowsAzureData):
    # Provisioning configuration for a Linux role instance.
    def __init__(self, host_name=None, user_name=None, user_password=None,
                 disable_ssh_password_authentication=None):
        self.configuration_set_type = u'LinuxProvisioningConfiguration'
        self.host_name = host_name
        self.user_name = user_name
        self.user_password = user_password
        self.disable_ssh_password_authentication =\
            disable_ssh_password_authentication
        self.ssh = SSH()


class SSH(WindowsAzureData):
    def __init__(self):
        self.public_keys = PublicKeys()
        self.key_pairs = KeyPairs()


class PublicKeys(WindowsAzureData):
    def __init__(self):
        self.public_keys = _list_of(PublicKey)

    def __iter__(self):
        return iter(self.public_keys)

    def __len__(self):
        return len(self.public_keys)

    def __getitem__(self, index):
        return self.public_keys[index]


class PublicKey(WindowsAzureData):
    def __init__(self, fingerprint=u'', path=u''):
        self.fingerprint = fingerprint
        self.path = path


class KeyPairs(WindowsAzureData):
    def __init__(self):
        self.key_pairs = _list_of(KeyPair)

    def __iter__(self):
        return iter(self.key_pairs)

    def __len__(self):
        return len(self.key_pairs)

    def __getitem__(self, index):
        return self.key_pairs[index]


class KeyPair(WindowsAzureData):
    def __init__(self, fingerprint=u'', path=u''):
        self.fingerprint = fingerprint
        self.path = path


class LoadBalancerProbe(WindowsAzureData):
    def __init__(self):
        self.path = u''
        self.port = u''
        self.protocol = u''


class DataVirtualHardDisks(WindowsAzureData):
    def __init__(self):
        self.data_virtual_hard_disks = _list_of(DataVirtualHardDisk)

    def __iter__(self):
        return iter(self.data_virtual_hard_disks)

    def __len__(self):
        return len(self.data_virtual_hard_disks)

    def __getitem__(self, index):
        return self.data_virtual_hard_disks[index]


class DataVirtualHardDisk(WindowsAzureData):
    def __init__(self):
        self.host_caching = u''
        self.disk_label = u''
        self.disk_name = u''
        self.lun = 0
        self.logical_disk_size_in_gb = 0
        self.media_link = u''


class OSVirtualHardDisk(WindowsAzureData):
    # The OS disk attached to a role.
    def __init__(self, source_image_name=None, media_link=None,
                 host_caching=None, disk_label=None, disk_name=None):
        self.source_image_name = source_image_name
        self.media_link = media_link
        self.host_caching = host_caching
        self.disk_label = disk_label
        self.disk_name = disk_name
        self.os = u''  # undocumented, not used when adding a role


class AsynchronousOperationResult(WindowsAzureData):
    # Holds the x-ms-request-id of an async management operation.
    def __init__(self, request_id=None):
        self.request_id = request_id


class ServiceBusRegion(WindowsAzureData):
    def __init__(self):
        self.code = u''
        self.fullname = u''


class ServiceBusNamespace(WindowsAzureData):
    def __init__(self):
        self.name = u''
        self.region = u''
        self.default_key = u''
        self.status = u''
        self.created_at = u''
        self.acs_management_endpoint = u''
        self.servicebus_endpoint = u''
        self.connection_string = u''
        self.subscription_id = u''
        self.enabled = False


class WebSpaces(WindowsAzureData):
    # Iterable collection of WebSpace entries.
    def __init__(self):
        self.web_space = _list_of(WebSpace)

    def __iter__(self):
        return iter(self.web_space)

    def __len__(self):
        return len(self.web_space)

    def __getitem__(self, index):
        return self.web_space[index]


class WebSpace(WindowsAzureData):
    def __init__(self):
        self.availability_state = u''
        self.geo_location = u''
        self.geo_region = u''
        self.name = u''
        self.plan = u''
        self.status = u''
        self.subscription = u''


class Sites(WindowsAzureData):
    # Iterable collection of Site entries.
    def __init__(self):
        self.site = _list_of(Site)

    def __iter__(self):
        return iter(self.site)

    def __len__(self):
        return len(self.site)

    def __getitem__(self, index):
        return self.site[index]


class Site(WindowsAzureData):
    # A web site within a web space.
    def __init__(self):
        self.admin_enabled = False
        self.availability_state = ''
        self.compute_mode = ''
        self.enabled = False
        self.enabled_host_names = _scalar_list_of(str, 'a:string')
        self.host_name_ssl_states = HostNameSslStates()
        self.host_names = _scalar_list_of(str, 'a:string')
        self.last_modified_time_utc = ''
        self.name = ''
        self.repository_site_name =
''
        self.self_link = ''
        self.server_farm = ''
        self.site_mode = ''
        self.state = ''
        self.storage_recovery_default_state = ''
        self.usage_state = ''
        self.web_space = ''


class HostNameSslStates(WindowsAzureData):
    # Iterable collection of HostNameSslState entries.
    def __init__(self):
        self.host_name_ssl_state = _list_of(HostNameSslState)

    def __iter__(self):
        return iter(self.host_name_ssl_state)

    def __len__(self):
        return len(self.host_name_ssl_state)

    def __getitem__(self, index):
        return self.host_name_ssl_state[index]


class HostNameSslState(WindowsAzureData):
    def __init__(self):
        self.name = u''
        self.ssl_state = u''


class PublishData(WindowsAzureData):
    # Root of a .publishsettings document; deserialized from 'publishData'.
    _xml_name = 'publishData'

    def __init__(self):
        self.publish_profiles = _list_of(PublishProfile, 'publishProfile')


class PublishProfile(WindowsAzureData):
    # One publish profile; each field maps to an XML attribute.
    def __init__(self):
        self.profile_name = _xml_attribute('profileName')
        self.publish_method = _xml_attribute('publishMethod')
        self.publish_url = _xml_attribute('publishUrl')
        self.msdeploysite = _xml_attribute('msdeploySite')
        self.user_name = _xml_attribute('userName')
        self.user_pwd = _xml_attribute('userPWD')
        self.destination_app_url = _xml_attribute('destinationAppUrl')
        self.sql_server_db_connection_string = _xml_attribute(
            'SQLServerDBConnectionString')
        self.my_sqldb_connection_string = _xml_attribute(
            'mySQLDBConnectionString')
        self.hosting_provider_forum_link = _xml_attribute(
            'hostingProviderForumLink')
        self.control_panel_link = _xml_attribute('controlPanelLink')


class QueueDescription(WindowsAzureData):
    # Settings/state of a Service Bus queue as exposed by the management API.
    def __init__(self):
        self.lock_duration = u''
        self.max_size_in_megabytes = 0
        self.requires_duplicate_detection = False
        self.requires_session = False
        self.default_message_time_to_live = u''
        self.dead_lettering_on_message_expiration = False
        self.duplicate_detection_history_time_window = u''
        self.max_delivery_count = 0
        self.enable_batched_operations = False
        self.size_in_bytes = 0
        self.message_count = 0
        self.is_anonymous_accessible = False
        self.authorization_rules = AuthorizationRules()
        self.status = u''
        self.created_at = u''
self.updated_at = u'' self.accessed_at = u'' self.support_ordering = False self.auto_delete_on_idle = u'' self.count_details = CountDetails() self.entity_availability_status = u'' class TopicDescription(WindowsAzureData): def __init__(self): self.default_message_time_to_live = u'' self.max_size_in_megabytes = 0 self.requires_duplicate_detection = False self.duplicate_detection_history_time_window = u'' self.enable_batched_operations = False self.size_in_bytes = 0 self.filtering_messages_before_publishing = False self.is_anonymous_accessible = False self.authorization_rules = AuthorizationRules() self.status = u'' self.created_at = u'' self.updated_at = u'' self.accessed_at = u'' self.support_ordering = False self.count_details = CountDetails() self.subscription_count = 0 class CountDetails(WindowsAzureData): def __init__(self): self.active_message_count = 0 self.dead_letter_message_count = 0 self.scheduled_message_count = 0 self.transfer_message_count = 0 self.transfer_dead_letter_message_count = 0 class NotificationHubDescription(WindowsAzureData): def __init__(self): self.registration_ttl = u'' self.authorization_rules = AuthorizationRules() class AuthorizationRules(WindowsAzureData): def __init__(self): self.authorization_rule = _list_of(AuthorizationRule) def __iter__(self): return iter(self.authorization_rule) def __len__(self): return len(self.authorization_rule) def __getitem__(self, index): return self.authorization_rule[index] class AuthorizationRule(WindowsAzureData): def __init__(self): self.claim_type = u'' self.claim_value = u'' self.rights = _scalar_list_of(str, 'AccessRights') self.created_time = u'' self.modified_time = u'' self.key_name = u'' self.primary_key = u'' self.secondary_keu = u'' class RelayDescription(WindowsAzureData): def __init__(self): self.path = u'' self.listener_type = u'' self.listener_count = 0 self.created_at = u'' self.updated_at = u'' class MetricResponses(WindowsAzureData): def __init__(self): self.metric_response = 
_list_of(MetricResponse)

    def __iter__(self):
        return iter(self.metric_response)

    def __len__(self):
        return len(self.metric_response)

    def __getitem__(self, index):
        return self.metric_response[index]


class MetricResponse(WindowsAzureData):
    def __init__(self):
        self.code = u''
        self.data = Data()
        self.message = u''


class Data(WindowsAzureData):
    # One metric series: metadata plus its sampled values.
    def __init__(self):
        self.display_name = u''
        self.end_time = u''
        self.name = u''
        self.primary_aggregation_type = u''
        self.start_time = u''
        self.time_grain = u''
        self.unit = u''
        self.values = Values()


class Values(WindowsAzureData):
    def __init__(self):
        self.metric_sample = _list_of(MetricSample)

    def __iter__(self):
        return iter(self.metric_sample)

    def __len__(self):
        return len(self.metric_sample)

    def __getitem__(self, index):
        return self.metric_sample[index]


class MetricSample(WindowsAzureData):
    def __init__(self):
        self.count = 0
        self.time_created = u''
        self.total = 0


class MetricDefinitions(WindowsAzureData):
    def __init__(self):
        self.metric_definition = _list_of(MetricDefinition)

    def __iter__(self):
        return iter(self.metric_definition)

    def __len__(self):
        return len(self.metric_definition)

    def __getitem__(self, index):
        return self.metric_definition[index]


class MetricDefinition(WindowsAzureData):
    def __init__(self):
        self.display_name = u''
        self.metric_availabilities = MetricAvailabilities()
        self.name = u''
        self.primary_aggregation_type = u''
        self.unit = u''


class MetricAvailabilities(WindowsAzureData):
    def __init__(self):
        # NOTE(review): 'MetricAvailabilily' looks misspelled, but it is the
        # XML element name matched on the wire -- confirm against the
        # service's actual response before renaming.
        self.metric_availability = _list_of(MetricAvailability,
                                            'MetricAvailabilily')

    def __iter__(self):
        return iter(self.metric_availability)

    def __len__(self):
        return len(self.metric_availability)

    def __getitem__(self, index):
        return self.metric_availability[index]


class MetricAvailability(WindowsAzureData):
    def __init__(self):
        self.retention = u''
        self.time_grain = u''


class Servers(WindowsAzureData):
    # Iterable collection of Server entries.
    def __init__(self):
        self.server = _list_of(Server)

    def __iter__(self):
        return iter(self.server)

    def __len__(self):
        return
len(self.server)

    def __getitem__(self, index):
        return self.server[index]


class Server(WindowsAzureData):
    # A SQL Database server.
    def __init__(self):
        self.name = u''
        self.administrator_login = u''
        self.location = u''
        self.fully_qualified_domain_name = u''
        self.version = u''


class Database(WindowsAzureData):
    # A SQL Database hosted on a server.
    def __init__(self):
        self.name = u''
        self.type = u''
        self.state = u''
        self.self_link = u''
        self.parent_link = u''
        self.id = 0
        self.edition = u''
        self.collation_name = u''
        self.creation_date = u''
        self.is_federation_root = False
        self.is_system_object = False
        self.max_size_bytes = 0


def _update_management_header(request):
    ''' Add additional headers for management. '''

    if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']:
        request.headers.append(('Content-Length', str(len(request.body))))

    # append additional headers base on the service
    request.headers.append(('x-ms-version', X_MS_VERSION))

    # if it is not GET or HEAD request, must set content-type.
    # The for/else only appends the default when no content-type header is
    # already present.
    if not request.method in ['GET', 'HEAD']:
        for name, _ in request.headers:
            if 'content-type' == name.lower():
                break
        else:
            request.headers.append(
                ('Content-Type',
                 'application/atom+xml;type=entry;charset=utf-8'))

    return request.headers


def _parse_response_for_async_op(response):
    ''' Extracts request id from response header. '''
    if response is None:
        return None

    result = AsynchronousOperationResult()
    if response.headers:
        for name, value in response.headers:
            if name.lower() == 'x-ms-request-id':
                result.request_id = value

    return result


def _management_error_handler(http_error):
    ''' Simple error handler for management service.
''' return _general_error_handler(http_error) def _lower(text): return text.lower() class _XmlSerializer(object): @staticmethod def create_storage_service_input_to_xml(service_name, description, label, affinity_group, location, geo_replication_enabled, extended_properties): return _XmlSerializer.doc_from_data( 'CreateStorageServiceInput', [('ServiceName', service_name), ('Description', description), ('Label', label, _encode_base64), ('AffinityGroup', affinity_group), ('Location', location), ('GeoReplicationEnabled', geo_replication_enabled, _lower)], extended_properties) @staticmethod def update_storage_service_input_to_xml(description, label, geo_replication_enabled, extended_properties): return _XmlSerializer.doc_from_data( 'UpdateStorageServiceInput', [('Description', description), ('Label', label, _encode_base64), ('GeoReplicationEnabled', geo_replication_enabled, _lower)], extended_properties) @staticmethod def regenerate_keys_to_xml(key_type): return _XmlSerializer.doc_from_data('RegenerateKeys', [('KeyType', key_type)]) @staticmethod def update_hosted_service_to_xml(label, description, extended_properties): return _XmlSerializer.doc_from_data('UpdateHostedService', [('Label', label, _encode_base64), ('Description', description)], extended_properties) @staticmethod def create_hosted_service_to_xml(service_name, label, description, location, affinity_group, extended_properties): return _XmlSerializer.doc_from_data( 'CreateHostedService', [('ServiceName', service_name), ('Label', label, _encode_base64), ('Description', description), ('Location', location), ('AffinityGroup', affinity_group)], extended_properties) @staticmethod def create_deployment_to_xml(name, package_url, label, configuration, start_deployment, treat_warnings_as_error, extended_properties): return _XmlSerializer.doc_from_data( 'CreateDeployment', [('Name', name), ('PackageUrl', package_url), ('Label', label, _encode_base64), ('Configuration', configuration), ('StartDeployment', 
start_deployment, _lower), ('TreatWarningsAsError', treat_warnings_as_error, _lower)], extended_properties) @staticmethod def swap_deployment_to_xml(production, source_deployment): return _XmlSerializer.doc_from_data( 'Swap', [('Production', production), ('SourceDeployment', source_deployment)]) @staticmethod def update_deployment_status_to_xml(status): return _XmlSerializer.doc_from_data( 'UpdateDeploymentStatus', [('Status', status)]) @staticmethod def change_deployment_to_xml(configuration, treat_warnings_as_error, mode, extended_properties): return _XmlSerializer.doc_from_data( 'ChangeConfiguration', [('Configuration', configuration), ('TreatWarningsAsError', treat_warnings_as_error, _lower), ('Mode', mode)], extended_properties) @staticmethod def upgrade_deployment_to_xml(mode, package_url, configuration, label, role_to_upgrade, force, extended_properties): return _XmlSerializer.doc_from_data( 'UpgradeDeployment', [('Mode', mode), ('PackageUrl', package_url), ('Configuration', configuration), ('Label', label, _encode_base64), ('RoleToUpgrade', role_to_upgrade), ('Force', force, _lower)], extended_properties) @staticmethod def rollback_upgrade_to_xml(mode, force): return _XmlSerializer.doc_from_data( 'RollbackUpdateOrUpgrade', [('Mode', mode), ('Force', force, _lower)]) @staticmethod def walk_upgrade_domain_to_xml(upgrade_domain): return _XmlSerializer.doc_from_data( 'WalkUpgradeDomain', [('UpgradeDomain', upgrade_domain)]) @staticmethod def certificate_file_to_xml(data, certificate_format, password): return _XmlSerializer.doc_from_data( 'CertificateFile', [('Data', data), ('CertificateFormat', certificate_format), ('Password', password)]) @staticmethod def create_affinity_group_to_xml(name, label, description, location): return _XmlSerializer.doc_from_data( 'CreateAffinityGroup', [('Name', name), ('Label', label, _encode_base64), ('Description', description), ('Location', location)]) @staticmethod def update_affinity_group_to_xml(label, description): return 
_XmlSerializer.doc_from_data( 'UpdateAffinityGroup', [('Label', label, _encode_base64), ('Description', description)]) @staticmethod def subscription_certificate_to_xml(public_key, thumbprint, data): return _XmlSerializer.doc_from_data( 'SubscriptionCertificate', [('SubscriptionCertificatePublicKey', public_key), ('SubscriptionCertificateThumbprint', thumbprint), ('SubscriptionCertificateData', data)]) @staticmethod def os_image_to_xml(label, media_link, name, os): return _XmlSerializer.doc_from_data( 'OSImage', [('Label', label), ('MediaLink', media_link), ('Name', name), ('OS', os)]) @staticmethod def data_virtual_hard_disk_to_xml(host_caching, disk_label, disk_name, lun, logical_disk_size_in_gb, media_link, source_media_link): return _XmlSerializer.doc_from_data( 'DataVirtualHardDisk', [('HostCaching', host_caching), ('DiskLabel', disk_label), ('DiskName', disk_name), ('Lun', lun), ('LogicalDiskSizeInGB', logical_disk_size_in_gb), ('MediaLink', media_link), ('SourceMediaLink', source_media_link)]) @staticmethod def disk_to_xml(has_operating_system, label, media_link, name, os): return _XmlSerializer.doc_from_data( 'Disk', [('HasOperatingSystem', has_operating_system, _lower), ('Label', label), ('MediaLink', media_link), ('Name', name), ('OS', os)]) @staticmethod def restart_role_operation_to_xml(): return _XmlSerializer.doc_from_xml( 'RestartRoleOperation', '<OperationType>RestartRoleOperation</OperationType>') @staticmethod def shutdown_role_operation_to_xml(post_shutdown_action): xml = _XmlSerializer.data_to_xml( [('OperationType', 'ShutdownRoleOperation'), ('PostShutdownAction', post_shutdown_action)]) return _XmlSerializer.doc_from_xml('ShutdownRoleOperation', xml) @staticmethod def shutdown_roles_operation_to_xml(role_names, post_shutdown_action): xml = _XmlSerializer.data_to_xml( [('OperationType', 'ShutdownRolesOperation')]) xml += '<Roles>' for role_name in role_names: xml += _XmlSerializer.data_to_xml([('Name', role_name)]) xml += '</Roles>' xml += 
_XmlSerializer.data_to_xml( [('PostShutdownAction', post_shutdown_action)]) return _XmlSerializer.doc_from_xml('ShutdownRolesOperation', xml) @staticmethod def start_role_operation_to_xml(): return _XmlSerializer.doc_from_xml( 'StartRoleOperation', '<OperationType>StartRoleOperation</OperationType>') @staticmethod def start_roles_operation_to_xml(role_names): xml = _XmlSerializer.data_to_xml( [('OperationType', 'StartRolesOperation')]) xml += '<Roles>' for role_name in role_names: xml += _XmlSerializer.data_to_xml([('Name', role_name)]) xml += '</Roles>' return _XmlSerializer.doc_from_xml('StartRolesOperation', xml) @staticmethod def windows_configuration_to_xml(configuration): xml = _XmlSerializer.data_to_xml( [('ConfigurationSetType', configuration.configuration_set_type), ('ComputerName', configuration.computer_name), ('AdminPassword', configuration.admin_password), ('ResetPasswordOnFirstLogon', configuration.reset_password_on_first_logon, _lower), ('EnableAutomaticUpdates', configuration.enable_automatic_updates, _lower), ('TimeZone', configuration.time_zone)]) if configuration.domain_join is not None: xml += '<DomainJoin>' xml += '<Credentials>' xml += _XmlSerializer.data_to_xml( [('Domain', configuration.domain_join.credentials.domain), ('Username', configuration.domain_join.credentials.username), ('Password', configuration.domain_join.credentials.password)]) xml += '</Credentials>' xml += _XmlSerializer.data_to_xml( [('JoinDomain', configuration.domain_join.join_domain), ('MachineObjectOU', configuration.domain_join.machine_object_ou)]) xml += '</DomainJoin>' if configuration.stored_certificate_settings is not None: xml += '<StoredCertificateSettings>' for cert in configuration.stored_certificate_settings: xml += '<CertificateSetting>' xml += _XmlSerializer.data_to_xml( [('StoreLocation', cert.store_location), ('StoreName', cert.store_name), ('Thumbprint', cert.thumbprint)]) xml += '</CertificateSetting>' xml += '</StoredCertificateSettings>' if 
configuration.win_rm is not None: xml += '<WinRM><Listeners>' for listener in configuration.win_rm.listeners: xml += '<Listener>' xml += _XmlSerializer.data_to_xml( [('Protocol', listener.protocol), ('CertificateThumbprint', listener.certificate_thumbprint)]) xml += '</Listener>' xml += '</Listeners></WinRM>' xml += _XmlSerializer.data_to_xml( [('AdminUsername', configuration.admin_username)]) return xml @staticmethod def linux_configuration_to_xml(configuration): xml = _XmlSerializer.data_to_xml( [('ConfigurationSetType', configuration.configuration_set_type), ('HostName', configuration.host_name), ('UserName', configuration.user_name), ('UserPassword', configuration.user_password), ('DisableSshPasswordAuthentication', configuration.disable_ssh_password_authentication, _lower)]) if configuration.ssh is not None: xml += '<SSH>' xml += '<PublicKeys>' for key in configuration.ssh.public_keys: xml += '<PublicKey>' xml += _XmlSerializer.data_to_xml( [('Fingerprint', key.fingerprint), ('Path', key.path)]) xml += '</PublicKey>' xml += '</PublicKeys>' xml += '<KeyPairs>' for key in configuration.ssh.key_pairs: xml += '<KeyPair>' xml += _XmlSerializer.data_to_xml( [('Fingerprint', key.fingerprint), ('Path', key.path)]) xml += '</KeyPair>' xml += '</KeyPairs>' xml += '</SSH>' return xml @staticmethod def network_configuration_to_xml(configuration): xml = _XmlSerializer.data_to_xml( [('ConfigurationSetType', configuration.configuration_set_type)]) xml += '<InputEndpoints>' for endpoint in configuration.input_endpoints: xml += '<InputEndpoint>' xml += _XmlSerializer.data_to_xml( [('LoadBalancedEndpointSetName', endpoint.load_balanced_endpoint_set_name), ('LocalPort', endpoint.local_port), ('Name', endpoint.name), ('Port', endpoint.port)]) if endpoint.load_balancer_probe.path or\ endpoint.load_balancer_probe.port or\ endpoint.load_balancer_probe.protocol: xml += '<LoadBalancerProbe>' xml += _XmlSerializer.data_to_xml( [('Path', endpoint.load_balancer_probe.path), ('Port', 
endpoint.load_balancer_probe.port), ('Protocol', endpoint.load_balancer_probe.protocol)]) xml += '</LoadBalancerProbe>' xml += _XmlSerializer.data_to_xml( [('Protocol', endpoint.protocol), ('EnableDirectServerReturn', endpoint.enable_direct_server_return, _lower)]) xml += '</InputEndpoint>' xml += '</InputEndpoints>' xml += '<SubnetNames>' for name in configuration.subnet_names: xml += _XmlSerializer.data_to_xml([('SubnetName', name)]) xml += '</SubnetNames>' return xml @staticmethod def role_to_xml(availability_set_name, data_virtual_hard_disks, network_configuration_set, os_virtual_hard_disk, role_name, role_size, role_type, system_configuration_set): xml = _XmlSerializer.data_to_xml([('RoleName', role_name), ('RoleType', role_type)]) xml += '<ConfigurationSets>' if system_configuration_set is not None: xml += '<ConfigurationSet>' if isinstance(system_configuration_set, WindowsConfigurationSet): xml += _XmlSerializer.windows_configuration_to_xml( system_configuration_set) elif isinstance(system_configuration_set, LinuxConfigurationSet): xml += _XmlSerializer.linux_configuration_to_xml( system_configuration_set) xml += '</ConfigurationSet>' if network_configuration_set is not None: xml += '<ConfigurationSet>' xml += _XmlSerializer.network_configuration_to_xml( network_configuration_set) xml += '</ConfigurationSet>' xml += '</ConfigurationSets>' if availability_set_name is not None: xml += _XmlSerializer.data_to_xml( [('AvailabilitySetName', availability_set_name)]) if data_virtual_hard_disks is not None: xml += '<DataVirtualHardDisks>' for hd in data_virtual_hard_disks: xml += '<DataVirtualHardDisk>' xml += _XmlSerializer.data_to_xml( [('HostCaching', hd.host_caching), ('DiskLabel', hd.disk_label), ('DiskName', hd.disk_name), ('Lun', hd.lun), ('LogicalDiskSizeInGB', hd.logical_disk_size_in_gb), ('MediaLink', hd.media_link)]) xml += '</DataVirtualHardDisk>' xml += '</DataVirtualHardDisks>' if os_virtual_hard_disk is not None: xml += '<OSVirtualHardDisk>' xml += 
_XmlSerializer.data_to_xml( [('HostCaching', os_virtual_hard_disk.host_caching), ('DiskLabel', os_virtual_hard_disk.disk_label), ('DiskName', os_virtual_hard_disk.disk_name), ('MediaLink', os_virtual_hard_disk.media_link), ('SourceImageName', os_virtual_hard_disk.source_image_name)]) xml += '</OSVirtualHardDisk>' if role_size is not None: xml += _XmlSerializer.data_to_xml([('RoleSize', role_size)]) return xml @staticmethod def add_role_to_xml(role_name, system_configuration_set, os_virtual_hard_disk, role_type, network_configuration_set, availability_set_name, data_virtual_hard_disks, role_size): xml = _XmlSerializer.role_to_xml( availability_set_name, data_virtual_hard_disks, network_configuration_set, os_virtual_hard_disk, role_name, role_size, role_type, system_configuration_set) return _XmlSerializer.doc_from_xml('PersistentVMRole', xml) @staticmethod def update_role_to_xml(role_name, os_virtual_hard_disk, role_type, network_configuration_set, availability_set_name, data_virtual_hard_disks, role_size): xml = _XmlSerializer.role_to_xml( availability_set_name, data_virtual_hard_disks, network_configuration_set, os_virtual_hard_disk, role_name, role_size, role_type, None) return _XmlSerializer.doc_from_xml('PersistentVMRole', xml) @staticmethod def capture_role_to_xml(post_capture_action, target_image_name, target_image_label, provisioning_configuration): xml = _XmlSerializer.data_to_xml( [('OperationType', 'CaptureRoleOperation'), ('PostCaptureAction', post_capture_action)]) if provisioning_configuration is not None: xml += '<ProvisioningConfiguration>' if isinstance(provisioning_configuration, WindowsConfigurationSet): xml += _XmlSerializer.windows_configuration_to_xml( provisioning_configuration) elif isinstance(provisioning_configuration, LinuxConfigurationSet): xml += _XmlSerializer.linux_configuration_to_xml( provisioning_configuration) xml += '</ProvisioningConfiguration>' xml += _XmlSerializer.data_to_xml( [('TargetImageLabel', target_image_label), 
('TargetImageName', target_image_name)]) return _XmlSerializer.doc_from_xml('CaptureRoleOperation', xml) @staticmethod def virtual_machine_deployment_to_xml(deployment_name, deployment_slot, label, role_name, system_configuration_set, os_virtual_hard_disk, role_type, network_configuration_set, availability_set_name, data_virtual_hard_disks, role_size, virtual_network_name): xml = _XmlSerializer.data_to_xml([('Name', deployment_name), ('DeploymentSlot', deployment_slot), ('Label', label)]) xml += '<RoleList>' xml += '<Role>' xml += _XmlSerializer.role_to_xml( availability_set_name, data_virtual_hard_disks, network_configuration_set, os_virtual_hard_disk, role_name, role_size, role_type, system_configuration_set) xml += '</Role>' xml += '</RoleList>' if virtual_network_name is not None: xml += _XmlSerializer.data_to_xml( [('VirtualNetworkName', virtual_network_name)]) return _XmlSerializer.doc_from_xml('Deployment', xml) @staticmethod def create_website_to_xml(webspace_name, website_name, geo_region, plan, host_names, compute_mode, server_farm, site_mode): xml = '<HostNames xmlns:a="http://schemas.microsoft.com/2003/10/Serialization/Arrays">' for host_name in host_names: xml += '<a:string>{0}</a:string>'.format(host_name) xml += '</HostNames>' xml += _XmlSerializer.data_to_xml( [('Name', website_name), ('ComputeMode', compute_mode), ('ServerFarm', server_farm), ('SiteMode', site_mode)]) xml += '<WebSpaceToCreate>' xml += _XmlSerializer.data_to_xml( [('GeoRegion', geo_region), ('Name', webspace_name), ('Plan', plan)]) xml += '</WebSpaceToCreate>' return _XmlSerializer.doc_from_xml('Site', xml) @staticmethod def data_to_xml(data): '''Creates an xml fragment from the specified data. 
data: Array of tuples, where first: xml element name second: xml element text third: conversion function ''' xml = '' for element in data: name = element[0] val = element[1] if len(element) > 2: converter = element[2] else: converter = None if val is not None: if converter is not None: text = _str(converter(_str(val))) else: text = _str(val) xml += ''.join(['<', name, '>', text, '</', name, '>']) return xml @staticmethod def doc_from_xml(document_element_name, inner_xml): '''Wraps the specified xml in an xml root element with default azure namespaces''' xml = ''.join(['<', document_element_name, ' xmlns:i="http://www.w3.org/2001/XMLSchema-instance"', ' xmlns="http://schemas.microsoft.com/windowsazure">']) xml += inner_xml xml += ''.join(['</', document_element_name, '>']) return xml @staticmethod def doc_from_data(document_element_name, data, extended_properties=None): xml = _XmlSerializer.data_to_xml(data) if extended_properties is not None: xml += _XmlSerializer.extended_properties_dict_to_xml_fragment( extended_properties) return _XmlSerializer.doc_from_xml(document_element_name, xml) @staticmethod def extended_properties_dict_to_xml_fragment(extended_properties): xml = '' if extended_properties is not None and len(extended_properties) > 0: xml += '<ExtendedProperties>' for key, val in extended_properties.items(): xml += ''.join(['<ExtendedProperty>', '<Name>', _str(key), '</Name>', '<Value>', _str(val), '</Value>', '</ExtendedProperty>']) xml += '</ExtendedProperties>' return xml def _parse_bool(value): if value.lower() == 'true': return True return False class _ServiceBusManagementXmlSerializer(object): @staticmethod def namespace_to_xml(region): '''Converts a service bus namespace description to xml The xml format: <?xml version="1.0" encoding="utf-8" standalone="yes"?> <entry xmlns="http://www.w3.org/2005/Atom"> <content type="application/xml"> <NamespaceDescription xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect"> <Region>West 
US</Region> </NamespaceDescription> </content> </entry> ''' body = '<NamespaceDescription xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">' body += ''.join(['<Region>', region, '</Region>']) body += '</NamespaceDescription>' return _create_entry(body) @staticmethod def xml_to_namespace(xmlstr): '''Converts xml response to service bus namespace The xml format for namespace: <entry> <id>uuid:00000000-0000-0000-0000-000000000000;id=0000000</id> <title type="text">myunittests 2012-08-22T16:48:10Z myunittests West US 0000000000000000000000000000000000000000000= Active 2012-08-22T16:48:10.217Z https://myunittests-sb.accesscontrol.windows.net/ https://myunittests.servicebus.windows.net/ Endpoint=sb://myunittests.servicebus.windows.net/;SharedSecretIssuer=owner;SharedSecretValue=0000000000000000000000000000000000000000000= 00000000000000000000000000000000 true ''' xmldoc = minidom.parseString(xmlstr) namespace = ServiceBusNamespace() mappings = ( ('Name', 'name', None), ('Region', 'region', None), ('DefaultKey', 'default_key', None), ('Status', 'status', None), ('CreatedAt', 'created_at', None), ('AcsManagementEndpoint', 'acs_management_endpoint', None), ('ServiceBusEndpoint', 'servicebus_endpoint', None), ('ConnectionString', 'connection_string', None), ('SubscriptionId', 'subscription_id', None), ('Enabled', 'enabled', _parse_bool), ) for desc in _get_children_from_path(xmldoc, 'entry', 'content', 'NamespaceDescription'): for xml_name, field_name, conversion_func in mappings: node_value = _get_first_child_node_value(desc, xml_name) if node_value is not None: if conversion_func is not None: node_value = conversion_func(node_value) setattr(namespace, field_name, node_value) return namespace @staticmethod def xml_to_region(xmlstr): '''Converts xml response to service bus region The xml format for region: uuid:157c311f-081f-4b4a-a0ba-a8f990ffd2a3;id=1756759 2013-04-10T18:25:29Z East Asia East Asia ''' xmldoc = minidom.parseString(xmlstr) region = 
ServiceBusRegion() for desc in _get_children_from_path(xmldoc, 'entry', 'content', 'RegionCodeDescription'): node_value = _get_first_child_node_value(desc, 'Code') if node_value is not None: region.code = node_value node_value = _get_first_child_node_value(desc, 'FullName') if node_value is not None: region.fullname = node_value return region @staticmethod def xml_to_namespace_availability(xmlstr): '''Converts xml response to service bus namespace availability The xml format: uuid:9fc7c652-1856-47ab-8d74-cd31502ea8e6;id=3683292 2013-04-16T03:03:37Z false ''' xmldoc = minidom.parseString(xmlstr) availability = AvailabilityResponse() for desc in _get_children_from_path(xmldoc, 'entry', 'content', 'NamespaceAvailability'): node_value = _get_first_child_node_value(desc, 'Result') if node_value is not None: availability.result = _parse_bool(node_value) return availability from azure.servicemanagement.servicemanagementservice import ( ServiceManagementService) from azure.servicemanagement.servicebusmanagementservice import ( ServiceBusManagementService) from azure.servicemanagement.websitemanagementservice import ( WebsiteManagementService) ================================================ FILE: CustomScript/azure/servicemanagement/servicebusmanagementservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#--------------------------------------------------------------------------
from azure import (
    MANAGEMENT_HOST,
    _convert_response_to_feeds,
    _str,
    _validate_not_none,
    )
from azure.servicemanagement import (
    _ServiceBusManagementXmlSerializer,
    QueueDescription,
    TopicDescription,
    NotificationHubDescription,
    RelayDescription,
    )
from azure.servicemanagement.servicemanagementclient import (
    _ServiceManagementClient,
    )


class ServiceBusManagementService(_ServiceManagementClient):
    '''Client for the classic (Service Management) Service Bus operations:
    regions, namespaces, and per-namespace entity listings.'''

    def __init__(self, subscription_id=None, cert_file=None,
                 host=MANAGEMENT_HOST):
        super(ServiceBusManagementService, self).__init__(
            subscription_id, cert_file, host)

    #--Operations for service bus ----------------------------------------
    def get_regions(self):
        '''
        Get list of available service bus regions.
        '''
        response = self._perform_get(
            self._get_path('services/serviceBus/Regions/', None),
            None)

        return _convert_response_to_feeds(
            response,
            _ServiceBusManagementXmlSerializer.xml_to_region)

    def list_namespaces(self):
        '''
        List the service bus namespaces defined on the account.
        '''
        response = self._perform_get(
            self._get_path('services/serviceBus/Namespaces/', None),
            None)

        return _convert_response_to_feeds(
            response,
            _ServiceBusManagementXmlSerializer.xml_to_namespace)

    def get_namespace(self, name):
        '''
        Get details about a specific namespace.

        name: Name of the service bus namespace.
        '''
        response = self._perform_get(
            self._get_path('services/serviceBus/Namespaces', name),
            None)

        return _ServiceBusManagementXmlSerializer.xml_to_namespace(
            response.body)

    def create_namespace(self, name, region):
        '''
        Create a new service bus namespace.

        name: Name of the service bus namespace to create.
        region: Region to create the namespace in.
        '''
        _validate_not_none('name', name)

        return self._perform_put(
            self._get_path('services/serviceBus/Namespaces', name),
            _ServiceBusManagementXmlSerializer.namespace_to_xml(region))

    def delete_namespace(self, name):
        '''
        Delete a service bus namespace.

        name: Name of the service bus namespace to delete.
        '''
        _validate_not_none('name', name)

        return self._perform_delete(
            self._get_path('services/serviceBus/Namespaces', name),
            None)

    def check_namespace_availability(self, name):
        '''
        Checks to see if the specified service bus namespace is available, or
        if it has already been taken.

        name: Name of the service bus namespace to validate.
        '''
        _validate_not_none('name', name)

        response = self._perform_get(
            self._get_path('services/serviceBus/CheckNamespaceAvailability',
                           None) + '/?namespace=' + _str(name), None)

        return _ServiceBusManagementXmlSerializer.xml_to_namespace_availability(
            response.body)

    def list_queues(self, name):
        '''
        Enumerates the queues in the service namespace.

        name: Name of the service bus namespace.
        '''
        _validate_not_none('name', name)

        response = self._perform_get(
            self._get_list_queues_path(name),
            None)

        return _convert_response_to_feeds(response, QueueDescription)

    def list_topics(self, name):
        '''
        Retrieves the topics in the service namespace.

        name: Name of the service bus namespace.
        '''
        response = self._perform_get(
            self._get_list_topics_path(name),
            None)

        return _convert_response_to_feeds(response, TopicDescription)

    def list_notification_hubs(self, name):
        '''
        Retrieves the notification hubs in the service namespace.

        name: Name of the service bus namespace.
        '''
        response = self._perform_get(
            self._get_list_notification_hubs_path(name),
            None)

        return _convert_response_to_feeds(response, NotificationHubDescription)

    def list_relays(self, name):
        '''
        Retrieves the relays in the service namespace.

        name: Name of the service bus namespace.
        '''
        response = self._perform_get(
            self._get_list_relays_path(name),
            None)

        return _convert_response_to_feeds(response, RelayDescription)

    #--Helper functions --------------------------------------------------
    def _get_list_queues_path(self, namespace_name):
        # Builds the subscription-scoped path for queue listing.
        return self._get_path('services/serviceBus/Namespaces/',
                              namespace_name) + '/Queues'

    def _get_list_topics_path(self, namespace_name):
        # Builds the subscription-scoped path for topic listing.
        return self._get_path('services/serviceBus/Namespaces/',
                              namespace_name) + '/Topics'

    def _get_list_notification_hubs_path(self, namespace_name):
        # Builds the subscription-scoped path for notification hub listing.
        return self._get_path('services/serviceBus/Namespaces/',
                              namespace_name) + '/NotificationHubs'

    def _get_list_relays_path(self, namespace_name):
        # Builds the subscription-scoped path for relay listing.
        return self._get_path('services/serviceBus/Namespaces/',
                              namespace_name) + '/Relays'



================================================
FILE: CustomScript/azure/servicemanagement/servicemanagementclient.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
import os

from azure import (
    WindowsAzureError,
    MANAGEMENT_HOST,
    _get_request_body,
    _parse_response,
    _str,
    _update_request_uri_query,
    )
from azure.http import (
    HTTPError,
    HTTPRequest,
    )
from azure.http.httpclient import _HTTPClient
from azure.servicemanagement import (
    AZURE_MANAGEMENT_CERTFILE,
    AZURE_MANAGEMENT_SUBSCRIPTIONID,
    _management_error_handler,
    _parse_response_for_async_op,
    _update_management_header,
    )

# NOTE(review): this module is Python-2-era vendored code — several methods
# below use 'async' as a parameter name, which is a reserved word in
# Python 3. Renaming it would break existing callers (e.g. 'async=True' in
# servicemanagementservice.py), so it is left as-is here.


class _ServiceManagementClient(object):
    '''Shared base for the classic management clients: resolves credentials
    (explicit arguments or environment variables), owns the certificate-based
    HTTP client, and provides the GET/PUT/POST/DELETE request helpers.'''

    def __init__(self, subscription_id=None, cert_file=None,
                 host=MANAGEMENT_HOST):
        self.requestid = None
        self.subscription_id = subscription_id
        self.cert_file = cert_file
        self.host = host

        # Fall back to environment variables when credentials are not passed
        # explicitly; both must be resolved or construction fails.
        if not self.cert_file:
            if AZURE_MANAGEMENT_CERTFILE in os.environ:
                self.cert_file = os.environ[AZURE_MANAGEMENT_CERTFILE]

        if not self.subscription_id:
            if AZURE_MANAGEMENT_SUBSCRIPTIONID in os.environ:
                self.subscription_id = os.environ[
                    AZURE_MANAGEMENT_SUBSCRIPTIONID]

        if not self.cert_file or not self.subscription_id:
            raise WindowsAzureError(
                'You need to provide subscription id and certificate file')

        self._httpclient = _HTTPClient(
            service_instance=self, cert_file=self.cert_file)
        # _filter is the request pipeline entry point; with_filter() wraps it.
        self._filter = self._httpclient.perform_request

    def with_filter(self, filter):
        '''Returns a new service which will process requests with the
        specified filter.  Filtering operations can include logging, automatic
        retrying, etc...  The filter is a lambda which receives the HTTPRequest
        and another lambda.  The filter can perform any pre-processing on the
        request, pass it off to the next lambda, and then perform any
        post-processing on the response.'''
        res = type(self)(self.subscription_id, self.cert_file, self.host)
        old_filter = self._filter

        def new_filter(request):
            return filter(request, old_filter)

        res._filter = new_filter
        return res

    def set_proxy(self, host, port, user=None, password=None):
        '''
        Sets the proxy server host and port for the HTTP CONNECT Tunnelling.

        host: Address of the proxy. Ex: '192.168.0.100'
        port: Port of the proxy. Ex: 6000
        user: User for proxy authorization.
        password: Password for proxy authorization.
        '''
        self._httpclient.set_proxy(host, port, user, password)

    #--Helper functions --------------------------------------------------
    def _perform_request(self, request):
        # Runs the request through the filter chain; HTTP errors are
        # translated into WindowsAzure errors by the management handler.
        try:
            resp = self._filter(request)
        except HTTPError as ex:
            return _management_error_handler(ex)

        return resp

    def _perform_get(self, path, response_type):
        # GET helper; parses the body into response_type when given.
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self.host
        request.path = path
        request.path, request.query = _update_request_uri_query(request)
        request.headers = _update_management_header(request)
        response = self._perform_request(request)

        if response_type is not None:
            return _parse_response(response, response_type)

        return response

    def _perform_put(self, path, body, async=False):
        # PUT helper; when async is truthy, returns the async-operation
        # result (request id) instead of None.
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self.host
        request.path = path
        request.body = _get_request_body(body)
        request.path, request.query = _update_request_uri_query(request)
        request.headers = _update_management_header(request)
        response = self._perform_request(request)

        if async:
            return _parse_response_for_async_op(response)

        return None

    def _perform_post(self, path, body, response_type=None, async=False):
        # POST helper; response_type parsing takes precedence over the
        # async-operation result.
        request = HTTPRequest()
        request.method = 'POST'
        request.host = self.host
        request.path = path
        request.body = _get_request_body(body)
        request.path, request.query = _update_request_uri_query(request)
        request.headers = _update_management_header(request)
        response = self._perform_request(request)

        if response_type is not None:
            return _parse_response(response, response_type)

        if async:
            return _parse_response_for_async_op(response)

        return None

    def _perform_delete(self, path, async=False):
        # DELETE helper; mirrors _perform_put's async handling.
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self.host
        request.path = path
        request.path, request.query = _update_request_uri_query(request)
        request.headers = _update_management_header(request)
        response = self._perform_request(request)

        if async:
            return _parse_response_for_async_op(response)

        return None

    def _get_path(self, resource, name):
        # Builds '/<subscription_id>/<resource>[/<name>]'.
        path = '/' + self.subscription_id + '/' + resource
        if name is not None:
            path += '/' + _str(name)
        return path



================================================
FILE: CustomScript/azure/servicemanagement/servicemanagementservice.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
    WindowsAzureError,
    MANAGEMENT_HOST,
    _str,
    _validate_not_none,
    )
from azure.servicemanagement import (
    AffinityGroups,
    AffinityGroup,
    AvailabilityResponse,
    Certificate,
    Certificates,
    DataVirtualHardDisk,
    Deployment,
    Disk,
    Disks,
    Locations,
    Operation,
    HostedService,
    HostedServices,
    Images,
    OperatingSystems,
    OperatingSystemFamilies,
    OSImage,
    PersistentVMRole,
    StorageService,
    StorageServices,
    Subscription,
    SubscriptionCertificate,
    SubscriptionCertificates,
    VirtualNetworkSites,
    _XmlSerializer,
    )
from azure.servicemanagement.servicemanagementclient import (
    _ServiceManagementClient,
    )


class ServiceManagementService(_ServiceManagementClient):
    '''Client for the classic (Service Management) compute/storage API;
    this section covers the storage account operations.'''

    def __init__(self, subscription_id=None, cert_file=None,
                 host=MANAGEMENT_HOST):
        super(ServiceManagementService, self).__init__(
            subscription_id, cert_file, host)

    #--Operations for storage accounts -----------------------------------
    def list_storage_accounts(self):
        '''
        Lists the storage accounts available under the current subscription.
        '''
        return self._perform_get(self._get_storage_service_path(),
                                 StorageServices)

    def get_storage_account_properties(self, service_name):
        '''
        Returns system properties for the specified storage account.

        service_name: Name of the storage service account.
        '''
        _validate_not_none('service_name', service_name)
        return self._perform_get(self._get_storage_service_path(service_name),
                                 StorageService)

    def get_storage_account_keys(self, service_name):
        '''
        Returns the primary and secondary access keys for the specified
        storage account.

        service_name: Name of the storage service account.
        '''
        _validate_not_none('service_name', service_name)
        return self._perform_get(
            self._get_storage_service_path(service_name) + '/keys',
            StorageService)

    def regenerate_storage_account_keys(self, service_name, key_type):
        '''
        Regenerates the primary or secondary access key for the specified
        storage account.

        service_name: Name of the storage service account.
        key_type:
            Specifies which key to regenerate. Valid values are:
            Primary, Secondary
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('key_type', key_type)
        return self._perform_post(
            self._get_storage_service_path(
                service_name) + '/keys?action=regenerate',
            _XmlSerializer.regenerate_keys_to_xml(
                key_type),
            StorageService)

    def create_storage_account(self, service_name, description, label,
                               affinity_group=None, location=None,
                               geo_replication_enabled=True,
                               extended_properties=None):
        '''
        Creates a new storage account in Windows Azure.

        service_name:
            A name for the storage account that is unique within Windows
            Azure. Storage account names must be between 3 and 24 characters
            in length and use numbers and lower-case letters only.
        description:
            A description for the storage account. The description may be up
            to 1024 characters in length.
        label:
            A name for the storage account. The name may be up to 100
            characters in length. The name can be used to identify the
            storage account for your tracking purposes.
        affinity_group:
            The name of an existing affinity group in the specified
            subscription. You can specify either a location or affinity_group,
            but not both.
        location:
            The location where the storage account is created. You can specify
            either a location or affinity_group, but not both.
        geo_replication_enabled:
            Specifies whether the storage account is created with the
            geo-replication enabled. If the element is not included in the
            request body, the default value is true. If set to true, the data
            in the storage account is replicated across more than one
            geographic location so as to enable resilience in the face of
            catastrophic service loss.
        extended_properties:
            Dictionary containing name/value pairs of storage account
            properties. You can have a maximum of 50 extended property
            name/value pairs. The maximum length of the Name element is 64
            characters, only alphanumeric characters and underscores are valid
            in the Name, and the name must start with a letter. The value has
            a maximum length of 255 characters.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('description', description)
        _validate_not_none('label', label)
        # Exactly one of location / affinity_group must be supplied.
        if affinity_group is None and location is None:
            raise WindowsAzureError(
                'location or affinity_group must be specified')
        if affinity_group is not None and location is not None:
            raise WindowsAzureError(
                'Only one of location or affinity_group needs to be specified')
        return self._perform_post(
            self._get_storage_service_path(),
            _XmlSerializer.create_storage_service_input_to_xml(
                service_name,
                description,
                label,
                affinity_group,
                location,
                geo_replication_enabled,
                extended_properties),
            async=True)

    def update_storage_account(self, service_name, description=None,
                               label=None, geo_replication_enabled=None,
                               extended_properties=None):
        '''
        Updates the label, the description, and enables or disables the
        geo-replication status for a storage account in Windows Azure.

        service_name: Name of the storage service account.
        description:
            A description for the storage account. The description may be up
            to 1024 characters in length.
        label:
            A name for the storage account. The name may be up to 100
            characters in length. The name can be used to identify the
            storage account for your tracking purposes.
        geo_replication_enabled:
            Specifies whether the storage account is created with the
            geo-replication enabled. If the element is not included in the
            request body, the default value is true. If set to true, the data
            in the storage account is replicated across more than one
            geographic location so as to enable resilience in the face of
            catastrophic service loss.
        extended_properties:
            Dictionary containing name/value pairs of storage account
            properties. You can have a maximum of 50 extended property
            name/value pairs.
The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. ''' _validate_not_none('service_name', service_name) return self._perform_put( self._get_storage_service_path(service_name), _XmlSerializer.update_storage_service_input_to_xml( description, label, geo_replication_enabled, extended_properties)) def delete_storage_account(self, service_name): ''' Deletes the specified storage account from Windows Azure. service_name: Name of the storage service account. ''' _validate_not_none('service_name', service_name) return self._perform_delete( self._get_storage_service_path(service_name)) def check_storage_account_name_availability(self, service_name): ''' Checks to see if the specified storage account name is available, or if it has already been taken. service_name: Name of the storage service account. ''' _validate_not_none('service_name', service_name) return self._perform_get( self._get_storage_service_path() + '/operations/isavailable/' + _str(service_name) + '', AvailabilityResponse) #--Operations for hosted services ------------------------------------ def list_hosted_services(self): ''' Lists the hosted services available under the current subscription. ''' return self._perform_get(self._get_hosted_service_path(), HostedServices) def get_hosted_service_properties(self, service_name, embed_detail=False): ''' Retrieves system properties for the specified hosted service. These properties include the service name and service type; the name of the affinity group to which the service belongs, or its location if it is not part of an affinity group; and optionally, information on the service's deployments. service_name: Name of the hosted service. embed_detail: When True, the management service returns properties for all deployments of the service, as well as for the service itself. 
''' _validate_not_none('service_name', service_name) _validate_not_none('embed_detail', embed_detail) return self._perform_get( self._get_hosted_service_path(service_name) + '?embed-detail=' + _str(embed_detail).lower(), HostedService) def create_hosted_service(self, service_name, label, description=None, location=None, affinity_group=None, extended_properties=None): ''' Creates a new hosted service in Windows Azure. service_name: A name for the hosted service that is unique within Windows Azure. This name is the DNS prefix name and can be used to access the hosted service. label: A name for the hosted service. The name can be up to 100 characters in length. The name can be used to identify the storage account for your tracking purposes. description: A description for the hosted service. The description can be up to 1024 characters in length. location: The location where the hosted service will be created. You can specify either a location or affinity_group, but not both. affinity_group: The name of an existing affinity group associated with this subscription. This name is a GUID and can be retrieved by examining the name element of the response body returned by list_affinity_groups. You can specify either a location or affinity_group, but not both. extended_properties: Dictionary containing name/value pairs of storage account properties. You can have a maximum of 50 extended property name/value pairs. The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. 
''' _validate_not_none('service_name', service_name) _validate_not_none('label', label) if affinity_group is None and location is None: raise WindowsAzureError( 'location or affinity_group must be specified') if affinity_group is not None and location is not None: raise WindowsAzureError( 'Only one of location or affinity_group needs to be specified') return self._perform_post(self._get_hosted_service_path(), _XmlSerializer.create_hosted_service_to_xml( service_name, label, description, location, affinity_group, extended_properties)) def update_hosted_service(self, service_name, label=None, description=None, extended_properties=None): ''' Updates the label and/or the description for a hosted service in Windows Azure. service_name: Name of the hosted service. label: A name for the hosted service. The name may be up to 100 characters in length. You must specify a value for either Label or Description, or for both. It is recommended that the label be unique within the subscription. The name can be used identify the hosted service for your tracking purposes. description: A description for the hosted service. The description may be up to 1024 characters in length. You must specify a value for either Label or Description, or for both. extended_properties: Dictionary containing name/value pairs of storage account properties. You can have a maximum of 50 extended property name/value pairs. The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. ''' _validate_not_none('service_name', service_name) return self._perform_put(self._get_hosted_service_path(service_name), _XmlSerializer.update_hosted_service_to_xml( label, description, extended_properties)) def delete_hosted_service(self, service_name): ''' Deletes the specified hosted service from Windows Azure. service_name: Name of the hosted service. 
''' _validate_not_none('service_name', service_name) return self._perform_delete(self._get_hosted_service_path(service_name)) def get_deployment_by_slot(self, service_name, deployment_slot): ''' Returns configuration information, status, and system properties for a deployment. service_name: Name of the hosted service. deployment_slot: The environment to which the hosted service is deployed. Valid values are: staging, production ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_slot', deployment_slot) return self._perform_get( self._get_deployment_path_using_slot( service_name, deployment_slot), Deployment) def get_deployment_by_name(self, service_name, deployment_name): ''' Returns configuration information, status, and system properties for a deployment. service_name: Name of the hosted service. deployment_name: The name of the deployment. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) return self._perform_get( self._get_deployment_path_using_name( service_name, deployment_name), Deployment) def create_deployment(self, service_name, deployment_slot, name, package_url, label, configuration, start_deployment=False, treat_warnings_as_error=False, extended_properties=None): ''' Uploads a new service package and creates a new deployment on staging or production. service_name: Name of the hosted service. deployment_slot: The environment to which the hosted service is deployed. Valid values are: staging, production name: The name for the deployment. The deployment name must be unique among other deployments for the hosted service. package_url: A URL that refers to the location of the service package in the Blob service. The service package can be located either in a storage account beneath the same subscription or a Shared Access Signature (SAS) URI from any storage account. label: A name for the hosted service. The name can be up to 100 characters in length. 
It is recommended that the label be unique within the subscription. The name can be used to identify the hosted service for your tracking purposes. configuration: The base-64 encoded service configuration file for the deployment. start_deployment: Indicates whether to start the deployment immediately after it is created. If false, the service model is still deployed to the virtual machines but the code is not run immediately. Instead, the service is Suspended until you call Update Deployment Status and set the status to Running, at which time the service will be started. A deployed service still incurs charges, even if it is suspended. treat_warnings_as_error: Indicates whether to treat package validation warnings as errors. If set to true, the Created Deployment operation fails if there are validation warnings on the service package. extended_properties: Dictionary containing name/value pairs of storage account properties. You can have a maximum of 50 extended property name/value pairs. The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_slot', deployment_slot) _validate_not_none('name', name) _validate_not_none('package_url', package_url) _validate_not_none('label', label) _validate_not_none('configuration', configuration) return self._perform_post( self._get_deployment_path_using_slot( service_name, deployment_slot), _XmlSerializer.create_deployment_to_xml( name, package_url, label, configuration, start_deployment, treat_warnings_as_error, extended_properties), async=True) def delete_deployment(self, service_name, deployment_name): ''' Deletes the specified deployment. service_name: Name of the hosted service. deployment_name: The name of the deployment. 
''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) return self._perform_delete( self._get_deployment_path_using_name( service_name, deployment_name), async=True) def swap_deployment(self, service_name, production, source_deployment): ''' Initiates a virtual IP swap between the staging and production deployment environments for a service. If the service is currently running in the staging environment, it will be swapped to the production environment. If it is running in the production environment, it will be swapped to staging. service_name: Name of the hosted service. production: The name of the production deployment. source_deployment: The name of the source deployment. ''' _validate_not_none('service_name', service_name) _validate_not_none('production', production) _validate_not_none('source_deployment', source_deployment) return self._perform_post(self._get_hosted_service_path(service_name), _XmlSerializer.swap_deployment_to_xml( production, source_deployment), async=True) def change_deployment_configuration(self, service_name, deployment_name, configuration, treat_warnings_as_error=False, mode='Auto', extended_properties=None): ''' Initiates a change to the deployment configuration. service_name: Name of the hosted service. deployment_name: The name of the deployment. configuration: The base-64 encoded service configuration file for the deployment. treat_warnings_as_error: Indicates whether to treat package validation warnings as errors. If set to true, the Created Deployment operation fails if there are validation warnings on the service package. mode: If set to Manual, WalkUpgradeDomain must be called to apply the update. If set to Auto, the Windows Azure platform will automatically apply the update To each upgrade domain for the service. Possible values are: Auto, Manual extended_properties: Dictionary containing name/value pairs of storage account properties. 
You can have a maximum of 50 extended property name/value pairs. The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('configuration', configuration) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + '/?comp=config', _XmlSerializer.change_deployment_to_xml( configuration, treat_warnings_as_error, mode, extended_properties), async=True) def update_deployment_status(self, service_name, deployment_name, status): ''' Initiates a change in deployment status. service_name: Name of the hosted service. deployment_name: The name of the deployment. status: The change to initiate to the deployment status. Possible values include: Running, Suspended ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('status', status) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + '/?comp=status', _XmlSerializer.update_deployment_status_to_xml( status), async=True) def upgrade_deployment(self, service_name, deployment_name, mode, package_url, configuration, label, force, role_to_upgrade=None, extended_properties=None): ''' Initiates an upgrade. service_name: Name of the hosted service. deployment_name: The name of the deployment. mode: If set to Manual, WalkUpgradeDomain must be called to apply the update. If set to Auto, the Windows Azure platform will automatically apply the update To each upgrade domain for the service. Possible values are: Auto, Manual package_url: A URL that refers to the location of the service package in the Blob service. 
The service package can be located either in a storage account beneath the same subscription or a Shared Access Signature (SAS) URI from any storage account. configuration: The base-64 encoded service configuration file for the deployment. label: A name for the hosted service. The name can be up to 100 characters in length. It is recommended that the label be unique within the subscription. The name can be used to identify the hosted service for your tracking purposes. force: Specifies whether the rollback should proceed even when it will cause local data to be lost from some role instances. True if the rollback should proceed; otherwise false if the rollback should fail. role_to_upgrade: The name of the specific role to upgrade. extended_properties: Dictionary containing name/value pairs of storage account properties. You can have a maximum of 50 extended property name/value pairs. The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('mode', mode) _validate_not_none('package_url', package_url) _validate_not_none('configuration', configuration) _validate_not_none('label', label) _validate_not_none('force', force) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + '/?comp=upgrade', _XmlSerializer.upgrade_deployment_to_xml( mode, package_url, configuration, label, role_to_upgrade, force, extended_properties), async=True) def walk_upgrade_domain(self, service_name, deployment_name, upgrade_domain): ''' Specifies the next upgrade domain to be walked during manual in-place upgrade or configuration change. service_name: Name of the hosted service. deployment_name: The name of the deployment. 
upgrade_domain: An integer value that identifies the upgrade domain to walk. Upgrade domains are identified with a zero-based index: the first upgrade domain has an ID of 0, the second has an ID of 1, and so on. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('upgrade_domain', upgrade_domain) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + '/?comp=walkupgradedomain', _XmlSerializer.walk_upgrade_domain_to_xml( upgrade_domain), async=True) def rollback_update_or_upgrade(self, service_name, deployment_name, mode, force): ''' Cancels an in progress configuration change (update) or upgrade and returns the deployment to its state before the upgrade or configuration change was started. service_name: Name of the hosted service. deployment_name: The name of the deployment. mode: Specifies whether the rollback should proceed automatically. auto - The rollback proceeds without further user input. manual - You must call the Walk Upgrade Domain operation to apply the rollback to each upgrade domain. force: Specifies whether the rollback should proceed even when it will cause local data to be lost from some role instances. True if the rollback should proceed; otherwise false if the rollback should fail. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('mode', mode) _validate_not_none('force', force) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + '/?comp=rollback', _XmlSerializer.rollback_upgrade_to_xml( mode, force), async=True) def reboot_role_instance(self, service_name, deployment_name, role_instance_name): ''' Requests a reboot of a role instance that is running in a deployment. service_name: Name of the hosted service. deployment_name: The name of the deployment. role_instance_name: The name of the role instance. 
''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_instance_name', role_instance_name) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + \ '/roleinstances/' + _str(role_instance_name) + \ '?comp=reboot', '', async=True) def reimage_role_instance(self, service_name, deployment_name, role_instance_name): ''' Requests a reimage of a role instance that is running in a deployment. service_name: Name of the hosted service. deployment_name: The name of the deployment. role_instance_name: The name of the role instance. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_instance_name', role_instance_name) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + \ '/roleinstances/' + _str(role_instance_name) + \ '?comp=reimage', '', async=True) def check_hosted_service_name_availability(self, service_name): ''' Checks to see if the specified hosted service name is available, or if it has already been taken. service_name: Name of the hosted service. ''' _validate_not_none('service_name', service_name) return self._perform_get( '/' + self.subscription_id + '/services/hostedservices/operations/isavailable/' + _str(service_name) + '', AvailabilityResponse) #--Operations for service certificates ------------------------------- def list_service_certificates(self, service_name): ''' Lists all of the service certificates associated with the specified hosted service. service_name: Name of the hosted service. 
''' _validate_not_none('service_name', service_name) return self._perform_get( '/' + self.subscription_id + '/services/hostedservices/' + _str(service_name) + '/certificates', Certificates) def get_service_certificate(self, service_name, thumbalgorithm, thumbprint): ''' Returns the public data for the specified X.509 certificate associated with a hosted service. service_name: Name of the hosted service. thumbalgorithm: The algorithm for the certificate's thumbprint. thumbprint: The hexadecimal representation of the thumbprint. ''' _validate_not_none('service_name', service_name) _validate_not_none('thumbalgorithm', thumbalgorithm) _validate_not_none('thumbprint', thumbprint) return self._perform_get( '/' + self.subscription_id + '/services/hostedservices/' + _str(service_name) + '/certificates/' + _str(thumbalgorithm) + '-' + _str(thumbprint) + '', Certificate) def add_service_certificate(self, service_name, data, certificate_format, password): ''' Adds a certificate to a hosted service. service_name: Name of the hosted service. data: The base-64 encoded form of the pfx file. certificate_format: The service certificate format. The only supported value is pfx. password: The certificate password. ''' _validate_not_none('service_name', service_name) _validate_not_none('data', data) _validate_not_none('certificate_format', certificate_format) _validate_not_none('password', password) return self._perform_post( '/' + self.subscription_id + '/services/hostedservices/' + _str(service_name) + '/certificates', _XmlSerializer.certificate_file_to_xml( data, certificate_format, password), async=True) def delete_service_certificate(self, service_name, thumbalgorithm, thumbprint): ''' Deletes a service certificate from the certificate store of a hosted service. service_name: Name of the hosted service. thumbalgorithm: The algorithm for the certificate's thumbprint. thumbprint: The hexadecimal representation of the thumbprint. 
''' _validate_not_none('service_name', service_name) _validate_not_none('thumbalgorithm', thumbalgorithm) _validate_not_none('thumbprint', thumbprint) return self._perform_delete( '/' + self.subscription_id + '/services/hostedservices/' + _str(service_name) + '/certificates/' + _str(thumbalgorithm) + '-' + _str(thumbprint), async=True) #--Operations for management certificates ---------------------------- def list_management_certificates(self): ''' The List Management Certificates operation lists and returns basic information about all of the management certificates associated with the specified subscription. Management certificates, which are also known as subscription certificates, authenticate clients attempting to connect to resources associated with your Windows Azure subscription. ''' return self._perform_get('/' + self.subscription_id + '/certificates', SubscriptionCertificates) def get_management_certificate(self, thumbprint): ''' The Get Management Certificate operation retrieves information about the management certificate with the specified thumbprint. Management certificates, which are also known as subscription certificates, authenticate clients attempting to connect to resources associated with your Windows Azure subscription. thumbprint: The thumbprint value of the certificate. ''' _validate_not_none('thumbprint', thumbprint) return self._perform_get( '/' + self.subscription_id + '/certificates/' + _str(thumbprint), SubscriptionCertificate) def add_management_certificate(self, public_key, thumbprint, data): ''' The Add Management Certificate operation adds a certificate to the list of management certificates. Management certificates, which are also known as subscription certificates, authenticate clients attempting to connect to resources associated with your Windows Azure subscription. public_key: A base64 representation of the management certificate public key. thumbprint: The thumb print that uniquely identifies the management certificate. 
data: The certificate's raw data in base-64 encoded .cer format. ''' _validate_not_none('public_key', public_key) _validate_not_none('thumbprint', thumbprint) _validate_not_none('data', data) return self._perform_post( '/' + self.subscription_id + '/certificates', _XmlSerializer.subscription_certificate_to_xml( public_key, thumbprint, data)) def delete_management_certificate(self, thumbprint): ''' The Delete Management Certificate operation deletes a certificate from the list of management certificates. Management certificates, which are also known as subscription certificates, authenticate clients attempting to connect to resources associated with your Windows Azure subscription. thumbprint: The thumb print that uniquely identifies the management certificate. ''' _validate_not_none('thumbprint', thumbprint) return self._perform_delete( '/' + self.subscription_id + '/certificates/' + _str(thumbprint)) #--Operations for affinity groups ------------------------------------ def list_affinity_groups(self): ''' Lists the affinity groups associated with the specified subscription. ''' return self._perform_get( '/' + self.subscription_id + '/affinitygroups', AffinityGroups) def get_affinity_group_properties(self, affinity_group_name): ''' Returns the system properties associated with the specified affinity group. affinity_group_name: The name of the affinity group. ''' _validate_not_none('affinity_group_name', affinity_group_name) return self._perform_get( '/' + self.subscription_id + '/affinitygroups/' + _str(affinity_group_name) + '', AffinityGroup) def create_affinity_group(self, name, label, location, description=None): ''' Creates a new affinity group for the specified subscription. name: A name for the affinity group that is unique to the subscription. label: A name for the affinity group. The name can be up to 100 characters in length. location: The data center location where the affinity group will be created. 
To list available locations, use the list_location function. description: A description for the affinity group. The description can be up to 1024 characters in length. ''' _validate_not_none('name', name) _validate_not_none('label', label) _validate_not_none('location', location) return self._perform_post( '/' + self.subscription_id + '/affinitygroups', _XmlSerializer.create_affinity_group_to_xml(name, label, description, location)) def update_affinity_group(self, affinity_group_name, label, description=None): ''' Updates the label and/or the description for an affinity group for the specified subscription. affinity_group_name: The name of the affinity group. label: A name for the affinity group. The name can be up to 100 characters in length. description: A description for the affinity group. The description can be up to 1024 characters in length. ''' _validate_not_none('affinity_group_name', affinity_group_name) _validate_not_none('label', label) return self._perform_put( '/' + self.subscription_id + '/affinitygroups/' + _str(affinity_group_name), _XmlSerializer.update_affinity_group_to_xml(label, description)) def delete_affinity_group(self, affinity_group_name): ''' Deletes an affinity group in the specified subscription. affinity_group_name: The name of the affinity group. ''' _validate_not_none('affinity_group_name', affinity_group_name) return self._perform_delete('/' + self.subscription_id + \ '/affinitygroups/' + \ _str(affinity_group_name)) #--Operations for locations ------------------------------------------ def list_locations(self): ''' Lists all of the data center locations that are valid for your subscription. ''' return self._perform_get('/' + self.subscription_id + '/locations', Locations) #--Operations for tracking asynchronous requests --------------------- def get_operation_status(self, request_id): ''' Returns the status of the specified operation. 
After calling an asynchronous operation, you can call Get Operation Status to determine whether the operation has succeeded, failed, or is still in progress. request_id: The request ID for the request you wish to track. ''' _validate_not_none('request_id', request_id) return self._perform_get( '/' + self.subscription_id + '/operations/' + _str(request_id), Operation) #--Operations for retrieving operating system information ------------ def list_operating_systems(self): ''' Lists the versions of the guest operating system that are currently available in Windows Azure. ''' return self._perform_get( '/' + self.subscription_id + '/operatingsystems', OperatingSystems) def list_operating_system_families(self): ''' Lists the guest operating system families available in Windows Azure, and also lists the operating system versions available for each family. ''' return self._perform_get( '/' + self.subscription_id + '/operatingsystemfamilies', OperatingSystemFamilies) #--Operations for retrieving subscription history -------------------- def get_subscription(self): ''' Returns account and resource allocation information on the specified subscription. ''' return self._perform_get('/' + self.subscription_id + '', Subscription) #--Operations for virtual machines ----------------------------------- def get_role(self, service_name, deployment_name, role_name): ''' Retrieves the specified virtual machine. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. 
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        return self._perform_get(
            self._get_role_path(service_name, deployment_name, role_name),
            PersistentVMRole)

    def create_virtual_machine_deployment(self, service_name, deployment_name,
                                          deployment_slot, label, role_name,
                                          system_config, os_virtual_hard_disk,
                                          network_config=None,
                                          availability_set_name=None,
                                          data_virtual_hard_disks=None,
                                          role_size=None,
                                          role_type='PersistentVMRole',
                                          virtual_network_name=None):
        '''
        Provisions a virtual machine based on the supplied configuration.

        service_name:
            Name of the hosted service.
        deployment_name:
            The name for the deployment. The deployment name must be
            unique among other deployments for the hosted service.
        deployment_slot:
            The environment to which the hosted service is deployed.
            Valid values are: staging, production
        label:
            Specifies an identifier for the deployment. The label can be
            up to 100 characters long. The label can be used for
            tracking purposes.
        role_name:
            The name of the role.
        system_config:
            Contains the metadata required to provision a virtual
            machine from a Windows or Linux OS image. Use an instance of
            WindowsConfigurationSet or LinuxConfigurationSet.
        os_virtual_hard_disk:
            Contains the parameters Windows Azure uses to create the
            operating system disk for the virtual machine.
        network_config:
            Encapsulates the metadata required to create the virtual
            network configuration for a virtual machine. If you do not
            include a network configuration set you will not be able to
            access the VM through VIPs over the internet. If your
            virtual machine belongs to a virtual network you can not
            specify which subnet address space it resides under.
        availability_set_name:
            Specifies the name of an availability set to which to add
            the virtual machine. This value controls the virtual machine
            allocation in the Windows Azure environment. Virtual
            machines specified in the same availability set are
            allocated to different nodes to maximize availability.
        data_virtual_hard_disks:
            Contains the parameters Windows Azure uses to create a data
            disk for a virtual machine.
        role_size:
            The size of the virtual machine to allocate. The default
            value is Small. Possible values are: ExtraSmall, Small,
            Medium, Large, ExtraLarge. The specified value must be
            compatible with the disk selected in the OSVirtualHardDisk
            values.
        role_type:
            The type of the role for the virtual machine. The only
            supported value is PersistentVMRole.
        virtual_network_name:
            Specifies the name of an existing virtual network to which
            the deployment will belong.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('deployment_slot', deployment_slot)
        _validate_not_none('label', label)
        _validate_not_none('role_name', role_name)
        _validate_not_none('system_config', system_config)
        _validate_not_none('os_virtual_hard_disk', os_virtual_hard_disk)
        # NOTE(review): 'async' became a reserved word in Python 3.7;
        # this keyword argument makes the module Python 2 only.
        return self._perform_post(
            self._get_deployment_path_using_name(service_name),
            _XmlSerializer.virtual_machine_deployment_to_xml(
                deployment_name,
                deployment_slot,
                label,
                role_name,
                system_config,
                os_virtual_hard_disk,
                role_type,
                network_config,
                availability_set_name,
                data_virtual_hard_disks,
                role_size,
                virtual_network_name),
            async=True)

    def add_role(self, service_name, deployment_name, role_name,
                 system_config, os_virtual_hard_disk, network_config=None,
                 availability_set_name=None, data_virtual_hard_disks=None,
                 role_size=None, role_type='PersistentVMRole'):
        '''
        Adds a virtual machine to an existing deployment.

        service_name:
            The name of the service.
        deployment_name:
            The name of the deployment.
        role_name:
            The name of the role.
        system_config:
            Contains the metadata required to provision a virtual
            machine from a Windows or Linux OS image. Use an instance of
            WindowsConfigurationSet or LinuxConfigurationSet.
        os_virtual_hard_disk:
            Contains the parameters Windows Azure uses to create the
            operating system disk for the virtual machine.
        network_config:
            Encapsulates the metadata required to create the virtual
            network configuration for a virtual machine. If you do not
            include a network configuration set you will not be able to
            access the VM through VIPs over the internet. If your
            virtual machine belongs to a virtual network you can not
            specify which subnet address space it resides under.
        availability_set_name:
            Specifies the name of an availability set to which to add
            the virtual machine. This value controls the virtual machine
            allocation in the Windows Azure environment. Virtual
            machines specified in the same availability set are
            allocated to different nodes to maximize availability.
        data_virtual_hard_disks:
            Contains the parameters Windows Azure uses to create a data
            disk for a virtual machine.
        role_size:
            The size of the virtual machine to allocate. The default
            value is Small. Possible values are: ExtraSmall, Small,
            Medium, Large, ExtraLarge. The specified value must be
            compatible with the disk selected in the OSVirtualHardDisk
            values.
        role_type:
            The type of the role for the virtual machine. The only
            supported value is PersistentVMRole.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        _validate_not_none('system_config', system_config)
        _validate_not_none('os_virtual_hard_disk', os_virtual_hard_disk)
        return self._perform_post(
            self._get_role_path(service_name, deployment_name),
            _XmlSerializer.add_role_to_xml(
                role_name,
                system_config,
                os_virtual_hard_disk,
                role_type,
                network_config,
                availability_set_name,
                data_virtual_hard_disks,
                role_size),
            async=True)

    def update_role(self, service_name, deployment_name, role_name,
                    os_virtual_hard_disk=None, network_config=None,
                    availability_set_name=None, data_virtual_hard_disks=None,
                    role_size=None, role_type='PersistentVMRole'):
        '''
        Updates the specified virtual machine.

        service_name:
            The name of the service.
        deployment_name:
            The name of the deployment.
        role_name:
            The name of the role.
        os_virtual_hard_disk:
            Contains the parameters Windows Azure uses to create the
            operating system disk for the virtual machine.
        network_config:
            Encapsulates the metadata required to create the virtual
            network configuration for a virtual machine. If you do not
            include a network configuration set you will not be able to
            access the VM through VIPs over the internet. If your
            virtual machine belongs to a virtual network you can not
            specify which subnet address space it resides under.
        availability_set_name:
            Specifies the name of an availability set to which to add
            the virtual machine. This value controls the virtual machine
            allocation in the Windows Azure environment. Virtual
            machines specified in the same availability set are
            allocated to different nodes to maximize availability.
        data_virtual_hard_disks:
            Contains the parameters Windows Azure uses to create a data
            disk for a virtual machine.
        role_size:
            The size of the virtual machine to allocate. The default
            value is Small. Possible values are: ExtraSmall, Small,
            Medium, Large, ExtraLarge. The specified value must be
            compatible with the disk selected in the OSVirtualHardDisk
            values.
        role_type:
            The type of the role for the virtual machine. The only
            supported value is PersistentVMRole.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        return self._perform_put(
            self._get_role_path(service_name, deployment_name, role_name),
            _XmlSerializer.update_role_to_xml(
                role_name,
                os_virtual_hard_disk,
                role_type,
                network_config,
                availability_set_name,
                data_virtual_hard_disks,
                role_size),
            async=True)

    def delete_role(self, service_name, deployment_name, role_name):
        '''
        Deletes the specified virtual machine.

        service_name:
            The name of the service.
        deployment_name:
            The name of the deployment.
        role_name:
            The name of the role.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        return self._perform_delete(
            self._get_role_path(service_name, deployment_name, role_name),
            async=True)

    def capture_role(self, service_name, deployment_name, role_name,
                     post_capture_action, target_image_name,
                     target_image_label, provisioning_configuration=None):
        '''
        The Capture Role operation captures a virtual machine image to
        your image gallery. From the captured image, you can create
        additional customized virtual machines.

        service_name:
            The name of the service.
        deployment_name:
            The name of the deployment.
        role_name:
            The name of the role.
        post_capture_action:
            Specifies the action after capture operation completes.
            Possible values are: Delete, Reprovision.
        target_image_name:
            Specifies the image name of the captured virtual machine.
        target_image_label:
            Specifies the friendly name of the captured virtual machine.
        provisioning_configuration:
            Use an instance of WindowsConfigurationSet or
            LinuxConfigurationSet.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        _validate_not_none('post_capture_action', post_capture_action)
        _validate_not_none('target_image_name', target_image_name)
        _validate_not_none('target_image_label', target_image_label)
        return self._perform_post(
            self._get_role_instance_operations_path(
                service_name, deployment_name, role_name),
            _XmlSerializer.capture_role_to_xml(
                post_capture_action,
                target_image_name,
                target_image_label,
                provisioning_configuration),
            async=True)

    def start_role(self, service_name, deployment_name, role_name):
        '''
        Starts the specified virtual machine.

        service_name:
            The name of the service.
        deployment_name:
            The name of the deployment.
        role_name:
            The name of the role.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        return self._perform_post(
            self._get_role_instance_operations_path(
                service_name, deployment_name, role_name),
            _XmlSerializer.start_role_operation_to_xml(),
            async=True)

    def start_roles(self, service_name, deployment_name, role_names):
        '''
        Starts the specified virtual machines.

        service_name:
            The name of the service.
        deployment_name:
            The name of the deployment.
        role_names:
            The names of the roles, as an enumerable of strings.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_names', role_names)
        # Batch variant: posts one operation for all roles at once.
        return self._perform_post(
            self._get_roles_operations_path(service_name, deployment_name),
            _XmlSerializer.start_roles_operation_to_xml(role_names),
            async=True)

    def restart_role(self, service_name, deployment_name, role_name):
        '''
        Restarts the specified virtual machine.

        service_name:
            The name of the service.
        deployment_name:
            The name of the deployment.
        role_name:
            The name of the role.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        return self._perform_post(
            self._get_role_instance_operations_path(
                service_name, deployment_name, role_name),
            _XmlSerializer.restart_role_operation_to_xml(
            ),
            async=True)

    def shutdown_role(self, service_name, deployment_name, role_name,
                      post_shutdown_action='Stopped'):
        '''
        Shuts down the specified virtual machine.

        service_name:
            The name of the service.
        deployment_name:
            The name of the deployment.
        role_name:
            The name of the role.
        post_shutdown_action:
            Specifies how the Virtual Machine should be shut down.
            Values are:
                Stopped
                    Shuts down the Virtual Machine but retains the
                    compute resources. You will continue to be billed
                    for the resources that the stopped machine uses.
                StoppedDeallocated
                    Shuts down the Virtual Machine and releases the
                    compute resources. You are not billed for the
                    compute resources that this Virtual Machine uses. If
                    a static Virtual Network IP address is assigned to
                    the Virtual Machine, it is reserved.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        _validate_not_none('post_shutdown_action', post_shutdown_action)
        return self._perform_post(
            self._get_role_instance_operations_path(
                service_name, deployment_name, role_name),
            _XmlSerializer.shutdown_role_operation_to_xml(post_shutdown_action),
            async=True)

    def shutdown_roles(self, service_name, deployment_name, role_names,
                       post_shutdown_action='Stopped'):
        '''
        Shuts down the specified virtual machines.

        service_name:
            The name of the service.
        deployment_name:
            The name of the deployment.
        role_names:
            The names of the roles, as an enumerable of strings.
        post_shutdown_action:
            Specifies how the Virtual Machine should be shut down.
            Values are:
                Stopped
                    Shuts down the Virtual Machine but retains the
                    compute resources.
                    You will continue to be billed for the resources
                    that the stopped machine uses.
                StoppedDeallocated
                    Shuts down the Virtual Machine and releases the
                    compute resources. You are not billed for the
                    compute resources that this Virtual Machine uses. If
                    a static Virtual Network IP address is assigned to
                    the Virtual Machine, it is reserved.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_names', role_names)
        _validate_not_none('post_shutdown_action', post_shutdown_action)
        return self._perform_post(
            self._get_roles_operations_path(service_name, deployment_name),
            _XmlSerializer.shutdown_roles_operation_to_xml(
                role_names, post_shutdown_action),
            async=True)

    #--Operations for virtual machine images -----------------------------
    def list_os_images(self):
        '''
        Retrieves a list of the OS images from the image repository.
        '''
        return self._perform_get(self._get_image_path(),
                                 Images)

    def get_os_image(self, image_name):
        '''
        Retrieves an OS image from the image repository.
        '''
        return self._perform_get(self._get_image_path(image_name),
                                 OSImage)

    def add_os_image(self, label, media_link, name, os):
        '''
        Adds an OS image that is currently stored in a storage account
        in your subscription to the image repository.

        label:
            Specifies the friendly name of the image.
        media_link:
            Specifies the location of the blob in Windows Azure blob
            store where the media for the image is located. The blob
            location must belong to a storage account in the
            subscription specified by the value in the operation call.
            Example:
            http://example.blob.core.windows.net/disks/mydisk.vhd
        name:
            Specifies a name for the OS image that Windows Azure uses to
            identify the image when creating one or more virtual
            machines.
        os:
            The operating system type of the OS image.
            Possible values are: Linux, Windows
        '''
        _validate_not_none('label', label)
        _validate_not_none('media_link', media_link)
        _validate_not_none('name', name)
        _validate_not_none('os', os)
        return self._perform_post(self._get_image_path(),
                                  _XmlSerializer.os_image_to_xml(
                                      label, media_link, name, os),
                                  async=True)

    def update_os_image(self, image_name, label, media_link, name, os):
        '''
        Updates an OS image in your image repository.

        image_name:
            The name of the image to update.
        label:
            Specifies the friendly name of the image to be updated. You
            cannot use this operation to update images provided by the
            Windows Azure platform.
        media_link:
            Specifies the location of the blob in Windows Azure blob
            store where the media for the image is located. The blob
            location must belong to a storage account in the
            subscription specified by the value in the operation call.
            Example:
            http://example.blob.core.windows.net/disks/mydisk.vhd
        name:
            Specifies a name for the OS image that Windows Azure uses to
            identify the image when creating one or more VM Roles.
        os:
            The operating system type of the OS image.
            Possible values are: Linux, Windows
        '''
        _validate_not_none('image_name', image_name)
        _validate_not_none('label', label)
        _validate_not_none('media_link', media_link)
        _validate_not_none('name', name)
        _validate_not_none('os', os)
        return self._perform_put(self._get_image_path(image_name),
                                 _XmlSerializer.os_image_to_xml(
                                     label, media_link, name, os),
                                 async=True)

    def delete_os_image(self, image_name, delete_vhd=False):
        '''
        Deletes the specified OS image from your image repository.

        image_name:
            The name of the image.
        delete_vhd:
            Deletes the underlying vhd blob in Azure storage.
        '''
        _validate_not_none('image_name', image_name)
        path = self._get_image_path(image_name)
        if delete_vhd:
            # ?comp=media instructs the service to also delete the
            # backing blob, not just the image registration.
            path += '?comp=media'
        return self._perform_delete(path, async=True)

    #--Operations for virtual machine disks ------------------------------
    def get_data_disk(self, service_name, deployment_name, role_name, lun):
        '''
        Retrieves the specified data disk from a virtual machine.

        service_name:
            The name of the service.
        deployment_name:
            The name of the deployment.
        role_name:
            The name of the role.
        lun:
            The Logical Unit Number (LUN) for the disk.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        _validate_not_none('lun', lun)
        return self._perform_get(
            self._get_data_disk_path(
                service_name, deployment_name, role_name, lun),
            DataVirtualHardDisk)

    def add_data_disk(self, service_name, deployment_name, role_name, lun,
                      host_caching=None, media_link=None, disk_label=None,
                      disk_name=None, logical_disk_size_in_gb=None,
                      source_media_link=None):
        '''
        Adds a data disk to a virtual machine.

        service_name:
            The name of the service.
        deployment_name:
            The name of the deployment.
        role_name:
            The name of the role.
        lun:
            Specifies the Logical Unit Number (LUN) for the disk. The
            LUN specifies the slot in which the data drive appears when
            mounted for usage by the virtual machine. Valid LUN values
            are 0 through 15.
        host_caching:
            Specifies the platform caching behavior of data disk blob
            for read/write efficiency. The default value is ReadOnly.
            Possible values are: None, ReadOnly, ReadWrite
        media_link:
            Specifies the location of the blob in Windows Azure blob
            store where the media for the disk is located. The blob
            location must belong to the storage account in the
            subscription specified by the value in the operation call.
            Example:
            http://example.blob.core.windows.net/disks/mydisk.vhd
        disk_label:
            Specifies the description of the data disk.
            When you attach a disk, either by directly referencing a
            media using the MediaLink element or specifying the target
            disk size, you can use the DiskLabel element to customize
            the name property of the target data disk.
        disk_name:
            Specifies the name of the disk. Windows Azure uses the
            specified disk to create the data disk for the machine and
            populates this field with the disk name.
        logical_disk_size_in_gb:
            Specifies the size, in GB, of an empty disk to be attached
            to the role. The disk can be created as part of disk attach
            or create VM role call by specifying the value for this
            property. Windows Azure creates the empty disk based on size
            preference and attaches the newly created disk to the Role.
        source_media_link:
            Specifies the location of a blob in account storage which is
            mounted as a data disk when the virtual machine is created.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        _validate_not_none('lun', lun)
        return self._perform_post(
            self._get_data_disk_path(service_name, deployment_name,
                                     role_name),
            _XmlSerializer.data_virtual_hard_disk_to_xml(
                host_caching,
                disk_label,
                disk_name,
                lun,
                logical_disk_size_in_gb,
                media_link,
                source_media_link),
            async=True)

    def update_data_disk(self, service_name, deployment_name, role_name, lun,
                         host_caching=None, media_link=None, updated_lun=None,
                         disk_label=None, disk_name=None,
                         logical_disk_size_in_gb=None):
        '''
        Updates the specified data disk attached to the specified
        virtual machine.

        service_name:
            The name of the service.
        deployment_name:
            The name of the deployment.
        role_name:
            The name of the role.
        lun:
            Specifies the Logical Unit Number (LUN) for the disk. The
            LUN specifies the slot in which the data drive appears when
            mounted for usage by the virtual machine. Valid LUN values
            are 0 through 15.
        host_caching:
            Specifies the platform caching behavior of data disk blob
            for read/write efficiency. The default value is ReadOnly.
            Possible values are: None, ReadOnly, ReadWrite
        media_link:
            Specifies the location of the blob in Windows Azure blob
            store where the media for the disk is located. The blob
            location must belong to the storage account in the
            subscription specified by the value in the operation call.
            Example:
            http://example.blob.core.windows.net/disks/mydisk.vhd
        updated_lun:
            Specifies the Logical Unit Number (LUN) for the disk. The
            LUN specifies the slot in which the data drive appears when
            mounted for usage by the virtual machine. Valid LUN values
            are 0 through 15.
        disk_label:
            Specifies the description of the data disk. When you attach
            a disk, either by directly referencing a media using the
            MediaLink element or specifying the target disk size, you
            can use the DiskLabel element to customize the name property
            of the target data disk.
        disk_name:
            Specifies the name of the disk. Windows Azure uses the
            specified disk to create the data disk for the machine and
            populates this field with the disk name.
        logical_disk_size_in_gb:
            Specifies the size, in GB, of an empty disk to be attached
            to the role. The disk can be created as part of disk attach
            or create VM role call by specifying the value for this
            property. Windows Azure creates the empty disk based on size
            preference and attaches the newly created disk to the Role.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        _validate_not_none('lun', lun)
        # The request path addresses the disk by its current lun; the
        # serialized body carries updated_lun (and no source media link).
        return self._perform_put(
            self._get_data_disk_path(
                service_name, deployment_name, role_name, lun),
            _XmlSerializer.data_virtual_hard_disk_to_xml(
                host_caching,
                disk_label,
                disk_name,
                updated_lun,
                logical_disk_size_in_gb,
                media_link,
                None),
            async=True)

    def delete_data_disk(self, service_name, deployment_name, role_name, lun,
                         delete_vhd=False):
        '''
        Removes the specified data disk from a virtual machine.

        service_name:
            The name of the service.
        deployment_name:
            The name of the deployment.
        role_name:
            The name of the role.
        lun:
            The Logical Unit Number (LUN) for the disk.
        delete_vhd:
            Deletes the underlying vhd blob in Azure storage.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        _validate_not_none('lun', lun)
        path = self._get_data_disk_path(service_name, deployment_name,
                                        role_name, lun)
        if delete_vhd:
            # ?comp=media also deletes the backing blob.
            path += '?comp=media'
        return self._perform_delete(path, async=True)

    #--Operations for virtual machine disks ------------------------------
    def list_disks(self):
        '''
        Retrieves a list of the disks in your image repository.
        '''
        return self._perform_get(self._get_disk_path(),
                                 Disks)

    def get_disk(self, disk_name):
        '''
        Retrieves a disk from your image repository.
        '''
        return self._perform_get(self._get_disk_path(disk_name),
                                 Disk)

    def add_disk(self, has_operating_system, label, media_link, name, os):
        '''
        Adds a disk to the user image repository. The disk can be an OS
        disk or a data disk.

        has_operating_system:
            Specifies whether the disk contains an operating system.
            Only a disk with an operating system installed can be
            mounted as OS Drive.
        label:
            Specifies the description of the disk.
        media_link:
            Specifies the location of the blob in Windows Azure blob
            store where the media for the disk is located. The blob
            location must belong to the storage account in the current
            subscription specified by the value in the operation call.
            Example:
            http://example.blob.core.windows.net/disks/mydisk.vhd
        name:
            Specifies a name for the disk. Windows Azure uses the name
            to identify the disk when creating virtual machines from the
            disk.
        os:
            The OS type of the disk.
            Possible values are: Linux, Windows
        '''
        _validate_not_none('has_operating_system', has_operating_system)
        _validate_not_none('label', label)
        _validate_not_none('media_link', media_link)
        _validate_not_none('name', name)
        _validate_not_none('os', os)
        return self._perform_post(self._get_disk_path(),
                                  _XmlSerializer.disk_to_xml(
                                      has_operating_system,
                                      label,
                                      media_link,
                                      name,
                                      os))

    def update_disk(self, disk_name, has_operating_system, label, media_link,
                    name, os):
        '''
        Updates an existing disk in your image repository.

        disk_name:
            The name of the disk to update.
        has_operating_system:
            Specifies whether the disk contains an operating system.
            Only a disk with an operating system installed can be
            mounted as OS Drive.
        label:
            Specifies the description of the disk.
        media_link:
            Specifies the location of the blob in Windows Azure blob
            store where the media for the disk is located. The blob
            location must belong to the storage account in the current
            subscription specified by the value in the operation call.
            Example:
            http://example.blob.core.windows.net/disks/mydisk.vhd
        name:
            Specifies a name for the disk. Windows Azure uses the name
            to identify the disk when creating virtual machines from the
            disk.
        os:
            The OS type of the disk.
            Possible values are: Linux, Windows
        '''
        _validate_not_none('disk_name', disk_name)
        _validate_not_none('has_operating_system', has_operating_system)
        _validate_not_none('label', label)
        _validate_not_none('media_link', media_link)
        _validate_not_none('name', name)
        _validate_not_none('os', os)
        return self._perform_put(self._get_disk_path(disk_name),
                                 _XmlSerializer.disk_to_xml(
                                     has_operating_system,
                                     label,
                                     media_link,
                                     name,
                                     os))

    def delete_disk(self, disk_name, delete_vhd=False):
        '''
        Deletes the specified data or operating system disk from your
        image repository.

        disk_name:
            The name of the disk to delete.
        delete_vhd:
            Deletes the underlying vhd blob in Azure storage.
        '''
        _validate_not_none('disk_name', disk_name)
        path = self._get_disk_path(disk_name)
        if delete_vhd:
            # ?comp=media also deletes the backing blob.
            path += '?comp=media'
        return self._perform_delete(path)

    #--Operations for virtual networks ------------------------------
    def list_virtual_network_sites(self):
        '''
        Retrieves a list of the virtual networks.
        '''
        return self._perform_get(self._get_virtual_network_site_path(),
                                 VirtualNetworkSites)

    #--Helper functions --------------------------------------------------
    # Builders for the management REST endpoint URLs. NOTE(review):
    # deployment_name is passed through _str() in _get_data_disk_path
    # but not in _get_role_path / _get_role_instance_operations_path /
    # _get_roles_operations_path — presumably an oversight; non-str
    # deployment names would raise there. Confirm before changing.
    def _get_virtual_network_site_path(self):
        return self._get_path('services/networking/virtualnetwork', None)

    def _get_storage_service_path(self, service_name=None):
        return self._get_path('services/storageservices', service_name)

    def _get_hosted_service_path(self, service_name=None):
        return self._get_path('services/hostedservices', service_name)

    def _get_deployment_path_using_slot(self, service_name, slot=None):
        return self._get_path('services/hostedservices/' + _str(service_name) +
                              '/deploymentslots', slot)

    def _get_deployment_path_using_name(self, service_name,
                                        deployment_name=None):
        return self._get_path('services/hostedservices/' + _str(service_name) +
                              '/deployments', deployment_name)

    def _get_role_path(self, service_name, deployment_name, role_name=None):
        return self._get_path('services/hostedservices/' + _str(service_name) +
                              '/deployments/' + deployment_name +
                              '/roles', role_name)

    def _get_role_instance_operations_path(self, service_name,
                                           deployment_name, role_name=None):
        return self._get_path('services/hostedservices/' + _str(service_name) +
                              '/deployments/' + deployment_name +
                              '/roleinstances', role_name) + '/Operations'

    def _get_roles_operations_path(self, service_name, deployment_name):
        return self._get_path('services/hostedservices/' + _str(service_name) +
                              '/deployments/' + deployment_name +
                              '/roles/Operations', None)

    def _get_data_disk_path(self, service_name, deployment_name, role_name,
                            lun=None):
        return self._get_path('services/hostedservices/' +
                              _str(service_name) + '/deployments/' +
                              _str(deployment_name) +
                              '/roles/' + _str(role_name) + '/DataDisks', lun)

    def _get_disk_path(self, disk_name=None):
        return self._get_path('services/disks', disk_name)

    def _get_image_path(self, image_name=None):
        return self._get_path('services/images', image_name)



================================================
FILE: CustomScript/azure/servicemanagement/sqldatabasemanagementservice.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
    MANAGEMENT_HOST,
    _parse_service_resources_response,
    )
from azure.servicemanagement import (
    Servers,
    Database,
    )
from azure.servicemanagement.servicemanagementclient import (
    _ServiceManagementClient,
    )


class SqlDatabaseManagementService(_ServiceManagementClient):
    '''
    Note that this class is a preliminary work on SQL Database
    management. Since it lacks a lot of features, the final version can
    be slightly different from the current one.
    '''

    def __init__(self, subscription_id=None, cert_file=None,
                 host=MANAGEMENT_HOST):
        super(SqlDatabaseManagementService, self).__init__(
            subscription_id, cert_file, host)

    #--Operations for sql servers ----------------------------------------
    def list_servers(self):
        '''
        List the SQL servers defined on the account.
        '''
        return self._perform_get(self._get_list_servers_path(), Servers)

    #--Operations for sql databases ----------------------------------------
    def list_databases(self, name):
        '''
        List the SQL databases defined on the specified server name
        '''
        response = self._perform_get(self._get_list_databases_path(name),
                                     None)
        return _parse_service_resources_response(response, Database)

    #--Helper functions --------------------------------------------------
    def _get_list_servers_path(self):
        return self._get_path('services/sqlservers/servers', None)

    def _get_list_databases_path(self, name):
        # *contentview=generic is mandatory*
        return self._get_path('services/sqlservers/servers/', name) + \
            '/databases?contentview=generic'



================================================
FILE: CustomScript/azure/servicemanagement/websitemanagementservice.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
    MANAGEMENT_HOST,
    _str,
    )
from azure.servicemanagement import (
    WebSpaces,
    WebSpace,
    Sites,
    Site,
    MetricResponses,
    MetricDefinitions,
    PublishData,
    _XmlSerializer,
    )
from azure.servicemanagement.servicemanagementclient import (
    _ServiceManagementClient,
    )


class WebsiteManagementService(_ServiceManagementClient):
    '''
    Note that this class is a preliminary work on WebSite management.
    Since it lacks a lot of features, the final version can be slightly
    different from the current one.
    '''

    def __init__(self, subscription_id=None, cert_file=None,
                 host=MANAGEMENT_HOST):
        super(WebsiteManagementService, self).__init__(
            subscription_id, cert_file, host)

    #--Operations for web sites ----------------------------------------
    def list_webspaces(self):
        '''
        List the webspaces defined on the account.
        '''
        return self._perform_get(self._get_list_webspaces_path(),
                                 WebSpaces)

    def get_webspace(self, webspace_name):
        '''
        Get details of a specific webspace.

        webspace_name:
            The name of the webspace.
        '''
        return self._perform_get(self._get_webspace_details_path(webspace_name),
                                 WebSpace)

    def list_sites(self, webspace_name):
        '''
        List the web sites defined on this webspace.

        webspace_name:
            The name of the webspace.
        '''
        return self._perform_get(self._get_sites_path(webspace_name),
                                 Sites)

    def get_site(self, webspace_name, website_name):
        '''
        List the web sites defined on this webspace.

        webspace_name:
            The name of the webspace.
        website_name:
            The name of the website.
        '''
        return self._perform_get(self._get_sites_details_path(webspace_name,
                                                              website_name),
                                 Site)

    def create_site(self, webspace_name, website_name, geo_region, host_names,
                    plan='VirtualDedicatedPlan', compute_mode='Shared',
                    server_farm=None, site_mode=None):
        '''
        Create a website.

        webspace_name:
            The name of the webspace.
        website_name:
            The name of the website.
        geo_region:
            The geographical region of the webspace that will be
            created.
        host_names:
            An array of fully qualified domain names for website. Only
            one hostname can be specified in the azurewebsites.net
            domain. The hostname should match the name of the website.
            Custom domains can only be specified for Shared or Standard
            websites.
        plan:
            This value must be 'VirtualDedicatedPlan'.
        compute_mode:
            This value should be 'Shared' for the Free or Paid Shared
            offerings, or 'Dedicated' for the Standard offering. The
            default value is 'Shared'. If you set it to 'Dedicated', you
            must specify a value for the server_farm parameter.
        server_farm:
            The name of the Server Farm associated with this website.
            This is a required value for Standard mode.
        site_mode:
            Can be None, 'Limited' or 'Basic'. This value is 'Limited'
            for the Free offering, and 'Basic' for the Paid Shared
            offering. Standard mode does not use the site_mode
            parameter; it uses the compute_mode parameter.
        '''
        xml = _XmlSerializer.create_website_to_xml(webspace_name, website_name,
                                                   geo_region, plan,
                                                   host_names, compute_mode,
                                                   server_farm, site_mode)
        return self._perform_post(
            self._get_sites_path(webspace_name),
            xml,
            Site)

    def delete_site(self, webspace_name, website_name,
                    delete_empty_server_farm=False, delete_metrics=False):
        '''
        Delete a website.

        webspace_name:
            The name of the webspace.
        website_name:
            The name of the website.
        delete_empty_server_farm:
            If the site being deleted is the last web site in a server
            farm, you can delete the server farm by setting this to
            True.
        delete_metrics:
            To also delete the metrics for the site that you are
            deleting, you can set this to True.
        '''
        path = self._get_sites_details_path(webspace_name, website_name)
        query = ''
        if delete_empty_server_farm:
            query += '&deleteEmptyServerFarm=true'
        if delete_metrics:
            query += '&deleteMetrics=true'
        if query:
            # Drop the leading '&' from the first appended flag.
            path = path + '?' + query.lstrip('&')
        return self._perform_delete(path)

    def restart_site(self, webspace_name, website_name):
        '''
        Restart a web site.

        webspace_name:
            The name of the webspace.
        website_name:
            The name of the website.
''' return self._perform_post( self._get_restart_path(webspace_name, website_name), '') def get_historical_usage_metrics(self, webspace_name, website_name, metrics = None, start_time=None, end_time=None, time_grain=None): ''' Get historical usage metrics. webspace_name: The name of the webspace. website_name: The name of the website. metrics: Optional. List of metrics name. Otherwise, all metrics returned. start_time: Optional. An ISO8601 date. Otherwise, current hour is used. end_time: Optional. An ISO8601 date. Otherwise, current time is used. time_grain: Optional. A rollup name, as P1D. OTherwise, default rollup for the metrics is used. More information and metrics name at: http://msdn.microsoft.com/en-us/library/azure/dn166964.aspx ''' metrics = ('names='+','.join(metrics)) if metrics else '' start_time = ('StartTime='+start_time) if start_time else '' end_time = ('EndTime='+end_time) if end_time else '' time_grain = ('TimeGrain='+time_grain) if time_grain else '' parameters = ('&'.join(v for v in (metrics, start_time, end_time, time_grain) if v)) parameters = '?'+parameters if parameters else '' return self._perform_get(self._get_historical_usage_metrics_path(webspace_name, website_name) + parameters, MetricResponses) def get_metric_definitions(self, webspace_name, website_name): ''' Get metric definitions of metrics available of this web site. webspace_name: The name of the webspace. website_name: The name of the website. ''' return self._perform_get(self._get_metric_definitions_path(webspace_name, website_name), MetricDefinitions) def get_publish_profile_xml(self, webspace_name, website_name): ''' Get a site's publish profile as a string webspace_name: The name of the webspace. website_name: The name of the website. 
''' return self._perform_get(self._get_publishxml_path(webspace_name, website_name), None).body.decode("utf-8") def get_publish_profile(self, webspace_name, website_name): ''' Get a site's publish profile as an object webspace_name: The name of the webspace. website_name: The name of the website. ''' return self._perform_get(self._get_publishxml_path(webspace_name, website_name), PublishData) #--Helper functions -------------------------------------------------- def _get_list_webspaces_path(self): return self._get_path('services/webspaces', None) def _get_webspace_details_path(self, webspace_name): return self._get_path('services/webspaces/', webspace_name) def _get_sites_path(self, webspace_name): return self._get_path('services/webspaces/', webspace_name) + '/sites' def _get_sites_details_path(self, webspace_name, website_name): return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) def _get_restart_path(self, webspace_name, website_name): return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) + '/restart/' def _get_historical_usage_metrics_path(self, webspace_name, website_name): return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) + '/metrics/' def _get_metric_definitions_path(self, webspace_name, website_name): return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) + '/metricdefinitions/' def _get_publishxml_path(self, webspace_name, website_name): return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) + '/publishxml/' ================================================ FILE: CustomScript/azure/storage/__init__.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- import sys import types from datetime import datetime from xml.dom import minidom from azure import (WindowsAzureData, WindowsAzureError, METADATA_NS, xml_escape, _create_entry, _decode_base64_to_text, _decode_base64_to_bytes, _encode_base64, _fill_data_minidom, _fill_instance_element, _get_child_nodes, _get_child_nodesNS, _get_children_from_path, _get_entry_properties, _general_error_handler, _list_of, _parse_response_for_dict, _sign_string, _unicode_type, _ERROR_CANNOT_SERIALIZE_VALUE_TO_ENTITY, ) # x-ms-version for storage service. X_MS_VERSION = '2012-02-12' class EnumResultsBase(object): ''' base class for EnumResults. ''' def __init__(self): self.prefix = u'' self.marker = u'' self.max_results = 0 self.next_marker = u'' class ContainerEnumResults(EnumResultsBase): ''' Blob Container list. ''' def __init__(self): EnumResultsBase.__init__(self) self.containers = _list_of(Container) def __iter__(self): return iter(self.containers) def __len__(self): return len(self.containers) def __getitem__(self, index): return self.containers[index] class Container(WindowsAzureData): ''' Blob container class. ''' def __init__(self): self.name = u'' self.url = u'' self.properties = Properties() self.metadata = {} class Properties(WindowsAzureData): ''' Blob container's properties class. 
''' def __init__(self): self.last_modified = u'' self.etag = u'' class RetentionPolicy(WindowsAzureData): ''' RetentionPolicy in service properties. ''' def __init__(self): self.enabled = False self.__dict__['days'] = None def get_days(self): # convert days to int value return int(self.__dict__['days']) def set_days(self, value): ''' set default days if days is set to empty. ''' self.__dict__['days'] = value days = property(fget=get_days, fset=set_days) class Logging(WindowsAzureData): ''' Logging class in service properties. ''' def __init__(self): self.version = u'1.0' self.delete = False self.read = False self.write = False self.retention_policy = RetentionPolicy() class Metrics(WindowsAzureData): ''' Metrics class in service properties. ''' def __init__(self): self.version = u'1.0' self.enabled = False self.include_apis = None self.retention_policy = RetentionPolicy() class StorageServiceProperties(WindowsAzureData): ''' Storage Service Propeties class. ''' def __init__(self): self.logging = Logging() self.metrics = Metrics() class AccessPolicy(WindowsAzureData): ''' Access Policy class in service properties. ''' def __init__(self, start=u'', expiry=u'', permission='u'): self.start = start self.expiry = expiry self.permission = permission class SignedIdentifier(WindowsAzureData): ''' Signed Identifier class for service properties. ''' def __init__(self): self.id = u'' self.access_policy = AccessPolicy() class SignedIdentifiers(WindowsAzureData): ''' SignedIdentifier list. 
''' def __init__(self): self.signed_identifiers = _list_of(SignedIdentifier) def __iter__(self): return iter(self.signed_identifiers) def __len__(self): return len(self.signed_identifiers) def __getitem__(self, index): return self.signed_identifiers[index] class BlobEnumResults(EnumResultsBase): ''' Blob list.''' def __init__(self): EnumResultsBase.__init__(self) self.blobs = _list_of(Blob) self.prefixes = _list_of(BlobPrefix) self.delimiter = '' def __iter__(self): return iter(self.blobs) def __len__(self): return len(self.blobs) def __getitem__(self, index): return self.blobs[index] class BlobResult(bytes): def __new__(cls, blob, properties): return bytes.__new__(cls, blob if blob else b'') def __init__(self, blob, properties): self.properties = properties class Blob(WindowsAzureData): ''' Blob class. ''' def __init__(self): self.name = u'' self.snapshot = u'' self.url = u'' self.properties = BlobProperties() self.metadata = {} class BlobProperties(WindowsAzureData): ''' Blob Properties ''' def __init__(self): self.last_modified = u'' self.etag = u'' self.content_length = 0 self.content_type = u'' self.content_encoding = u'' self.content_language = u'' self.content_md5 = u'' self.xms_blob_sequence_number = 0 self.blob_type = u'' self.lease_status = u'' self.lease_state = u'' self.lease_duration = u'' self.copy_id = u'' self.copy_source = u'' self.copy_status = u'' self.copy_progress = u'' self.copy_completion_time = u'' self.copy_status_description = u'' class BlobPrefix(WindowsAzureData): ''' BlobPrefix in Blob. ''' def __init__(self): self.name = '' class BlobBlock(WindowsAzureData): ''' BlobBlock class ''' def __init__(self, id=None, size=None): self.id = id self.size = size class BlobBlockList(WindowsAzureData): ''' BlobBlockList class ''' def __init__(self): self.committed_blocks = [] self.uncommitted_blocks = [] class PageRange(WindowsAzureData): ''' Page Range for page blob. 
''' def __init__(self): self.start = 0 self.end = 0 class PageList(object): ''' Page list for page blob. ''' def __init__(self): self.page_ranges = _list_of(PageRange) def __iter__(self): return iter(self.page_ranges) def __len__(self): return len(self.page_ranges) def __getitem__(self, index): return self.page_ranges[index] class QueueEnumResults(EnumResultsBase): ''' Queue list''' def __init__(self): EnumResultsBase.__init__(self) self.queues = _list_of(Queue) def __iter__(self): return iter(self.queues) def __len__(self): return len(self.queues) def __getitem__(self, index): return self.queues[index] class Queue(WindowsAzureData): ''' Queue class ''' def __init__(self): self.name = u'' self.url = u'' self.metadata = {} class QueueMessagesList(WindowsAzureData): ''' Queue message list. ''' def __init__(self): self.queue_messages = _list_of(QueueMessage) def __iter__(self): return iter(self.queue_messages) def __len__(self): return len(self.queue_messages) def __getitem__(self, index): return self.queue_messages[index] class QueueMessage(WindowsAzureData): ''' Queue message class. ''' def __init__(self): self.message_id = u'' self.insertion_time = u'' self.expiration_time = u'' self.pop_receipt = u'' self.time_next_visible = u'' self.dequeue_count = u'' self.message_text = u'' class Entity(WindowsAzureData): ''' Entity class. The attributes of entity will be created dynamically. ''' pass class EntityProperty(WindowsAzureData): ''' Entity property. contains type and value. ''' def __init__(self, type=None, value=None): self.type = type self.value = value class Table(WindowsAzureData): ''' Only for intellicens and telling user the return type. 
    '''
    # Intentionally empty: attributes are attached dynamically by the parser.
    pass


def _parse_blob_enum_results_list(response):
    # Parses a List Blobs XML response into a BlobEnumResults: fills the
    # 'blobs' and 'prefixes' collections from the <Blobs> element, then
    # fills the remaining enumeration fields (prefix, marker, max_results,
    # next_marker, delimiter) by name from the same element.
    respbody = response.body
    return_obj = BlobEnumResults()
    doc = minidom.parseString(respbody)

    for enum_results in _get_child_nodes(doc, 'EnumerationResults'):
        for child in _get_children_from_path(enum_results, 'Blobs', 'Blob'):
            return_obj.blobs.append(_fill_instance_element(child, Blob))

        for child in _get_children_from_path(enum_results,
                                             'Blobs',
                                             'BlobPrefix'):
            return_obj.prefixes.append(
                _fill_instance_element(child, BlobPrefix))

        # Fill every other attribute of the result object by matching its
        # name against child nodes of <EnumerationResults>; the two list
        # attributes were already handled above.
        for name, value in vars(return_obj).items():
            if name == 'blobs' or name == 'prefixes':
                continue
            value = _fill_data_minidom(enum_results, name, value)
            if value is not None:
                setattr(return_obj, name, value)

    return return_obj


def _update_storage_header(request):
    ''' add additional headers for storage request. '''
    if request.body:
        # callers must pass bytes; signing and Content-Length depend on it
        assert isinstance(request.body, bytes)

    # if it is PUT, POST, MERGE, DELETE, need to add content-length to header.
    if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']:
        request.headers.append(('Content-Length', str(len(request.body))))

    # append additional headers based on the service
    request.headers.append(('x-ms-version', X_MS_VERSION))

    # Expand the pseudo-header 'x-ms-meta-name-values' (a dict) into one
    # 'x-ms-meta-<name>' header per entry, then remove the pseudo-header.
    # The break means at most one such pseudo-header is processed, which
    # also keeps the removal-during-iteration safe.
    for name, value in request.headers:
        if 'x-ms-meta-name-values' in name and value:
            for meta_name, meta_value in value.items():
                request.headers.append(('x-ms-meta-' + meta_name, meta_value))
            request.headers.remove((name, value))
            break
    return request


def _update_storage_blob_header(request, account_name, account_key):
    ''' add additional headers for storage blob request.
''' request = _update_storage_header(request) current_time = datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT') request.headers.append(('x-ms-date', current_time)) request.headers.append( ('Content-Type', 'application/octet-stream Charset=UTF-8')) request.headers.append(('Authorization', _sign_storage_blob_request(request, account_name, account_key))) return request.headers def _update_storage_queue_header(request, account_name, account_key): ''' add additional headers for storage queue request. ''' return _update_storage_blob_header(request, account_name, account_key) def _update_storage_table_header(request): ''' add additional headers for storage table request. ''' request = _update_storage_header(request) for name, _ in request.headers: if name.lower() == 'content-type': break else: request.headers.append(('Content-Type', 'application/atom+xml')) request.headers.append(('DataServiceVersion', '2.0;NetFx')) request.headers.append(('MaxDataServiceVersion', '2.0;NetFx')) current_time = datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT') request.headers.append(('x-ms-date', current_time)) request.headers.append(('Date', current_time)) return request.headers def _sign_storage_blob_request(request, account_name, account_key): ''' Returns the signed string for blob request which is used to set Authorization header. This is also used to sign queue request. 
''' uri_path = request.path.split('?')[0] # method to sign string_to_sign = request.method + '\n' # get headers to sign headers_to_sign = [ 'content-encoding', 'content-language', 'content-length', 'content-md5', 'content-type', 'date', 'if-modified-since', 'if-match', 'if-none-match', 'if-unmodified-since', 'range'] request_header_dict = dict((name.lower(), value) for name, value in request.headers if value) string_to_sign += '\n'.join(request_header_dict.get(x, '') for x in headers_to_sign) + '\n' # get x-ms header to sign x_ms_headers = [] for name, value in request.headers: if 'x-ms' in name: x_ms_headers.append((name.lower(), value)) x_ms_headers.sort() for name, value in x_ms_headers: if value: string_to_sign += ''.join([name, ':', value, '\n']) # get account_name and uri path to sign string_to_sign += '/' + account_name + uri_path # get query string to sign if it is not table service query_to_sign = request.query query_to_sign.sort() current_name = '' for name, value in query_to_sign: if value: if current_name != name: string_to_sign += '\n' + name + ':' + value else: string_to_sign += '\n' + ',' + value # sign the request auth_string = 'SharedKey ' + account_name + ':' + \ _sign_string(account_key, string_to_sign) return auth_string def _sign_storage_table_request(request, account_name, account_key): uri_path = request.path.split('?')[0] string_to_sign = request.method + '\n' headers_to_sign = ['content-md5', 'content-type', 'date'] request_header_dict = dict((name.lower(), value) for name, value in request.headers if value) string_to_sign += '\n'.join(request_header_dict.get(x, '') for x in headers_to_sign) + '\n' # get account_name and uri path to sign string_to_sign += ''.join(['/', account_name, uri_path]) for name, value in request.query: if name == 'comp' and uri_path == '/': string_to_sign += '?comp=' + value break # sign the request auth_string = 'SharedKey ' + account_name + ':' + \ _sign_string(account_key, string_to_sign) return auth_string def 
_to_python_bool(value): if value.lower() == 'true': return True return False def _to_entity_int(data): int_max = (2 << 30) - 1 if data > (int_max) or data < (int_max + 1) * (-1): return 'Edm.Int64', str(data) else: return 'Edm.Int32', str(data) def _to_entity_bool(value): if value: return 'Edm.Boolean', 'true' return 'Edm.Boolean', 'false' def _to_entity_datetime(value): return 'Edm.DateTime', value.strftime('%Y-%m-%dT%H:%M:%S') def _to_entity_float(value): return 'Edm.Double', str(value) def _to_entity_property(value): if value.type == 'Edm.Binary': return value.type, _encode_base64(value.value) return value.type, str(value.value) def _to_entity_none(value): return None, None def _to_entity_str(value): return 'Edm.String', value # Tables of conversions to and from entity types. We support specific # datatypes, and beyond that the user can use an EntityProperty to get # custom data type support. def _from_entity_binary(value): return EntityProperty('Edm.Binary', _decode_base64_to_bytes(value)) def _from_entity_int(value): return int(value) def _from_entity_datetime(value): format = '%Y-%m-%dT%H:%M:%S' if '.' in value: format = format + '.%f' if value.endswith('Z'): format = format + 'Z' return datetime.strptime(value, format) _ENTITY_TO_PYTHON_CONVERSIONS = { 'Edm.Binary': _from_entity_binary, 'Edm.Int32': _from_entity_int, 'Edm.Int64': _from_entity_int, 'Edm.Double': float, 'Edm.Boolean': _to_python_bool, 'Edm.DateTime': _from_entity_datetime, } # Conversion from Python type to a function which returns a tuple of the # type string and content string. 
_PYTHON_TO_ENTITY_CONVERSIONS = { int: _to_entity_int, bool: _to_entity_bool, datetime: _to_entity_datetime, float: _to_entity_float, EntityProperty: _to_entity_property, str: _to_entity_str, } if sys.version_info < (3,): _PYTHON_TO_ENTITY_CONVERSIONS.update({ long: _to_entity_int, types.NoneType: _to_entity_none, unicode: _to_entity_str, }) def _convert_entity_to_xml(source): ''' Converts an entity object to xml to send. The entity format is: <updated>2008-09-18T23:46:19.3857256Z</updated> <author> <name /> </author> <id /> <content type="application/xml"> <m:properties> <d:Address>Mountain View</d:Address> <d:Age m:type="Edm.Int32">23</d:Age> <d:AmountDue m:type="Edm.Double">200.23</d:AmountDue> <d:BinaryData m:type="Edm.Binary" m:null="true" /> <d:CustomerCode m:type="Edm.Guid">c9da6455-213d-42c9-9a79-3e9149a57833</d:CustomerCode> <d:CustomerSince m:type="Edm.DateTime">2008-07-10T00:00:00</d:CustomerSince> <d:IsActive m:type="Edm.Boolean">true</d:IsActive> <d:NumOfOrders m:type="Edm.Int64">255</d:NumOfOrders> <d:PartitionKey>mypartitionkey</d:PartitionKey> <d:RowKey>myrowkey1</d:RowKey> <d:Timestamp m:type="Edm.DateTime">0001-01-01T00:00:00</d:Timestamp> </m:properties> </content> </entry> ''' # construct the entity body included in <m:properties> and </m:properties> entity_body = '<m:properties xml:space="preserve">{properties}</m:properties>' if isinstance(source, WindowsAzureData): source = vars(source) properties_str = '' # set properties type for types we know if value has no type info. 
# if value has type info, then set the type to value.type for name, value in source.items(): mtype = '' conv = _PYTHON_TO_ENTITY_CONVERSIONS.get(type(value)) if conv is None and sys.version_info >= (3,) and value is None: conv = _to_entity_none if conv is None: raise WindowsAzureError( _ERROR_CANNOT_SERIALIZE_VALUE_TO_ENTITY.format( type(value).__name__)) mtype, value = conv(value) # form the property node properties_str += ''.join(['<d:', name]) if value is None: properties_str += ' m:null="true" />' else: if mtype: properties_str += ''.join([' m:type="', mtype, '"']) properties_str += ''.join(['>', xml_escape(value), '</d:', name, '>']) if sys.version_info < (3,): if isinstance(properties_str, unicode): properties_str = properties_str.encode('utf-8') # generate the entity_body entity_body = entity_body.format(properties=properties_str) xmlstr = _create_entry(entity_body) return xmlstr def _convert_table_to_xml(table_name): ''' Create xml to send for a given table name. Since xml format for table is the same as entity and the only difference is that table has only one property 'TableName', so we just call _convert_entity_to_xml. table_name: the name of the table ''' return _convert_entity_to_xml({'TableName': table_name}) def _convert_block_list_to_xml(block_id_list): ''' Convert a block list to xml to send. block_id_list: a str list containing the block ids that are used in put_block_list. Only get block from latest blocks. ''' if block_id_list is None: return '' xml = '<?xml version="1.0" encoding="utf-8"?><BlockList>' for value in block_id_list: xml += '<Latest>{0}</Latest>'.format(_encode_base64(value)) return xml + '</BlockList>' def _create_blob_result(response): blob_properties = _parse_response_for_dict(response) return BlobResult(response.body, blob_properties) def _convert_response_to_block_list(response): ''' Converts xml response to block list class. 
''' blob_block_list = BlobBlockList() xmldoc = minidom.parseString(response.body) for xml_block in _get_children_from_path(xmldoc, 'BlockList', 'CommittedBlocks', 'Block'): xml_block_id = _decode_base64_to_text( _get_child_nodes(xml_block, 'Name')[0].firstChild.nodeValue) xml_block_size = int( _get_child_nodes(xml_block, 'Size')[0].firstChild.nodeValue) blob_block_list.committed_blocks.append( BlobBlock(xml_block_id, xml_block_size)) for xml_block in _get_children_from_path(xmldoc, 'BlockList', 'UncommittedBlocks', 'Block'): xml_block_id = _decode_base64_to_text( _get_child_nodes(xml_block, 'Name')[0].firstChild.nodeValue) xml_block_size = int( _get_child_nodes(xml_block, 'Size')[0].firstChild.nodeValue) blob_block_list.uncommitted_blocks.append( BlobBlock(xml_block_id, xml_block_size)) return blob_block_list def _remove_prefix(name): colon = name.find(':') if colon != -1: return name[colon + 1:] return name def _convert_response_to_entity(response): if response is None: return response return _convert_xml_to_entity(response.body) def _convert_xml_to_entity(xmlstr): ''' Convert xml response to entity. 
The format of entity: <entry xmlns:d="http://schemas.microsoft.com/ado/2007/08/dataservices" xmlns:m="http://schemas.microsoft.com/ado/2007/08/dataservices/metadata" xmlns="http://www.w3.org/2005/Atom"> <title /> <updated>2008-09-18T23:46:19.3857256Z</updated> <author> <name /> </author> <id /> <content type="application/xml"> <m:properties> <d:Address>Mountain View</d:Address> <d:Age m:type="Edm.Int32">23</d:Age> <d:AmountDue m:type="Edm.Double">200.23</d:AmountDue> <d:BinaryData m:type="Edm.Binary" m:null="true" /> <d:CustomerCode m:type="Edm.Guid">c9da6455-213d-42c9-9a79-3e9149a57833</d:CustomerCode> <d:CustomerSince m:type="Edm.DateTime">2008-07-10T00:00:00</d:CustomerSince> <d:IsActive m:type="Edm.Boolean">true</d:IsActive> <d:NumOfOrders m:type="Edm.Int64">255</d:NumOfOrders> <d:PartitionKey>mypartitionkey</d:PartitionKey> <d:RowKey>myrowkey1</d:RowKey> <d:Timestamp m:type="Edm.DateTime">0001-01-01T00:00:00</d:Timestamp> </m:properties> </content> </entry> ''' xmldoc = minidom.parseString(xmlstr) xml_properties = None for entry in _get_child_nodes(xmldoc, 'entry'): for content in _get_child_nodes(entry, 'content'): # TODO: Namespace xml_properties = _get_child_nodesNS( content, METADATA_NS, 'properties') if not xml_properties: return None entity = Entity() # extract each property node and get the type from attribute and node value for xml_property in xml_properties[0].childNodes: name = _remove_prefix(xml_property.nodeName) # exclude the Timestamp since it is auto added by azure when # inserting entity. We don't want this to mix with real properties if name in ['Timestamp']: continue if xml_property.firstChild: value = xml_property.firstChild.nodeValue else: value = '' isnull = xml_property.getAttributeNS(METADATA_NS, 'null') mtype = xml_property.getAttributeNS(METADATA_NS, 'type') # if not isnull and no type info, then it is a string and we just # need the str type to hold the property. 
if not isnull and not mtype: _set_entity_attr(entity, name, value) elif isnull == 'true': if mtype: property = EntityProperty(mtype, None) else: property = EntityProperty('Edm.String', None) else: # need an object to hold the property conv = _ENTITY_TO_PYTHON_CONVERSIONS.get(mtype) if conv is not None: property = conv(value) else: property = EntityProperty(mtype, value) _set_entity_attr(entity, name, property) # extract id, updated and name value from feed entry and set them of # rule. for name, value in _get_entry_properties(xmlstr, True).items(): if name in ['etag']: _set_entity_attr(entity, name, value) return entity def _set_entity_attr(entity, name, value): try: setattr(entity, name, value) except UnicodeEncodeError: # Python 2 doesn't support unicode attribute names, so we'll # add them and access them directly through the dictionary entity.__dict__[name] = value def _convert_xml_to_table(xmlstr): ''' Converts the xml response to table class. Simply call convert_xml_to_entity and extract the table name, and add updated and author info ''' table = Table() entity = _convert_xml_to_entity(xmlstr) setattr(table, 'name', entity.TableName) for name, value in _get_entry_properties(xmlstr, False).items(): setattr(table, name, value) return table def _storage_error_handler(http_error): ''' Simple error handler for storage service. ''' return _general_error_handler(http_error) # make these available just from storage. 
from azure.storage.blobservice import BlobService from azure.storage.queueservice import QueueService from azure.storage.tableservice import TableService from azure.storage.cloudstorageaccount import CloudStorageAccount from azure.storage.sharedaccesssignature import ( SharedAccessSignature, SharedAccessPolicy, Permission, WebResource, ) ================================================ FILE: CustomScript/azure/storage/blobservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#-------------------------------------------------------------------------- from azure import ( WindowsAzureError, BLOB_SERVICE_HOST_BASE, DEV_BLOB_HOST, _ERROR_VALUE_NEGATIVE, _ERROR_PAGE_BLOB_SIZE_ALIGNMENT, _convert_class_to_xml, _dont_fail_not_exist, _dont_fail_on_exist, _encode_base64, _get_request_body, _get_request_body_bytes_only, _int_or_none, _parse_enum_results_list, _parse_response, _parse_response_for_dict, _parse_response_for_dict_filter, _parse_response_for_dict_prefix, _parse_simple_list, _str, _str_or_none, _update_request_uri_query_local_storage, _validate_type_bytes, _validate_not_none, ) from azure.http import HTTPRequest from azure.storage import ( Container, ContainerEnumResults, PageList, PageRange, SignedIdentifiers, StorageServiceProperties, _convert_block_list_to_xml, _convert_response_to_block_list, _create_blob_result, _parse_blob_enum_results_list, _update_storage_blob_header, ) from azure.storage.storageclient import _StorageClient from os import path import sys if sys.version_info >= (3,): from io import BytesIO else: from cStringIO import StringIO as BytesIO # Keep this value sync with _ERROR_PAGE_BLOB_SIZE_ALIGNMENT _PAGE_SIZE = 512 class BlobService(_StorageClient): ''' This is the main class managing Blob resources. ''' def __init__(self, account_name=None, account_key=None, protocol='https', host_base=BLOB_SERVICE_HOST_BASE, dev_host=DEV_BLOB_HOST): ''' account_name: your storage account name, required for all operations. account_key: your storage account key, required for all operations. protocol: Optional. Protocol. Defaults to https. host_base: Optional. Live host base url. Defaults to Azure url. Override this for on-premise. dev_host: Optional. Dev host url. Defaults to localhost. 
''' self._BLOB_MAX_DATA_SIZE = 64 * 1024 * 1024 self._BLOB_MAX_CHUNK_DATA_SIZE = 4 * 1024 * 1024 super(BlobService, self).__init__( account_name, account_key, protocol, host_base, dev_host) def make_blob_url(self, container_name, blob_name, account_name=None, protocol=None, host_base=None): ''' Creates the url to access a blob. container_name: Name of container. blob_name: Name of blob. account_name: Name of the storage account. If not specified, uses the account specified when BlobService was initialized. protocol: Protocol to use: 'http' or 'https'. If not specified, uses the protocol specified when BlobService was initialized. host_base: Live host base url. If not specified, uses the host base specified when BlobService was initialized. ''' if not account_name: account_name = self.account_name if not protocol: protocol = self.protocol if not host_base: host_base = self.host_base return '{0}://{1}{2}/{3}/{4}'.format(protocol, account_name, host_base, container_name, blob_name) def list_containers(self, prefix=None, marker=None, maxresults=None, include=None): ''' The List Containers operation returns a list of the containers under the specified account. prefix: Optional. Filters the results to return only containers whose names begin with the specified prefix. marker: Optional. A string value that identifies the portion of the list to be returned with the next list operation. maxresults: Optional. Specifies the maximum number of containers to return. include: Optional. Include this parameter to specify that the container's metadata be returned as part of the response body. set this parameter to string 'metadata' to get container's metadata. 
''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/?comp=list' request.query = [ ('prefix', _str_or_none(prefix)), ('marker', _str_or_none(marker)), ('maxresults', _int_or_none(maxresults)), ('include', _str_or_none(include)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_enum_results_list(response, ContainerEnumResults, "Containers", Container) def create_container(self, container_name, x_ms_meta_name_values=None, x_ms_blob_public_access=None, fail_on_exist=False): ''' Creates a new container under the specified account. If the container with the same name already exists, the operation fails. container_name: Name of container to create. x_ms_meta_name_values: Optional. A dict with name_value pairs to associate with the container as metadata. Example:{'Category':'test'} x_ms_blob_public_access: Optional. Possible values include: container, blob fail_on_exist: specify whether to throw an exception when the container exists. 
''' _validate_not_none('container_name', container_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(container_name) + '?restype=container' request.headers = [ ('x-ms-meta-name-values', x_ms_meta_name_values), ('x-ms-blob-public-access', _str_or_none(x_ms_blob_public_access)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) if not fail_on_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_on_exist(ex) return False else: self._perform_request(request) return True def get_container_properties(self, container_name, x_ms_lease_id=None): ''' Returns all user-defined metadata and system properties for the specified container. container_name: Name of existing container. x_ms_lease_id: If specified, get_container_properties only succeeds if the container's lease is active and matches this ID. ''' _validate_not_none('container_name', container_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(container_name) + '?restype=container' request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict(response) def get_container_metadata(self, container_name, x_ms_lease_id=None): ''' Returns all user-defined metadata for the specified container. The metadata will be in returned dictionary['x-ms-meta-(name)']. container_name: Name of existing container. x_ms_lease_id: If specified, get_container_metadata only succeeds if the container's lease is active and matches this ID. 
''' _validate_not_none('container_name', container_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '?restype=container&comp=metadata' request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict_prefix(response, prefixes=['x-ms-meta']) def set_container_metadata(self, container_name, x_ms_meta_name_values=None, x_ms_lease_id=None): ''' Sets one or more user-defined name-value pairs for the specified container. container_name: Name of existing container. x_ms_meta_name_values: A dict containing name, value for metadata. Example: {'category':'test'} x_ms_lease_id: If specified, set_container_metadata only succeeds if the container's lease is active and matches this ID. ''' _validate_not_none('container_name', container_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '?restype=container&comp=metadata' request.headers = [ ('x-ms-meta-name-values', x_ms_meta_name_values), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def get_container_acl(self, container_name, x_ms_lease_id=None): ''' Gets the permissions for the specified container. container_name: Name of existing container. x_ms_lease_id: If specified, get_container_acl only succeeds if the container's lease is active and matches this ID. 
''' _validate_not_none('container_name', container_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '?restype=container&comp=acl' request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response(response, SignedIdentifiers) def set_container_acl(self, container_name, signed_identifiers=None, x_ms_blob_public_access=None, x_ms_lease_id=None): ''' Sets the permissions for the specified container. container_name: Name of existing container. signed_identifiers: SignedIdentifers instance x_ms_blob_public_access: Optional. Possible values include: container, blob x_ms_lease_id: If specified, set_container_acl only succeeds if the container's lease is active and matches this ID. ''' _validate_not_none('container_name', container_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '?restype=container&comp=acl' request.headers = [ ('x-ms-blob-public-access', _str_or_none(x_ms_blob_public_access)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ] request.body = _get_request_body( _convert_class_to_xml(signed_identifiers)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def delete_container(self, container_name, fail_not_exist=False, x_ms_lease_id=None): ''' Marks the specified container for deletion. container_name: Name of container to delete. fail_not_exist: Specify whether to throw an exception when the container doesn't exist. 
x_ms_lease_id: Required if the container has an active lease. ''' _validate_not_none('container_name', container_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(container_name) + '?restype=container' request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) if not fail_not_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_not_exist(ex) return False else: self._perform_request(request) return True def lease_container(self, container_name, x_ms_lease_action, x_ms_lease_id=None, x_ms_lease_duration=60, x_ms_lease_break_period=None, x_ms_proposed_lease_id=None): ''' Establishes and manages a lock on a container for delete operations. The lock duration can be 15 to 60 seconds, or can be infinite. container_name: Name of existing container. x_ms_lease_action: Required. Possible values: acquire|renew|release|break|change x_ms_lease_id: Required if the container has an active lease. x_ms_lease_duration: Specifies the duration of the lease, in seconds, or negative one (-1) for a lease that never expires. A non-infinite lease can be between 15 and 60 seconds. A lease duration cannot be changed using renew or change. For backwards compatibility, the default is 60, and the value is only used on an acquire operation. x_ms_lease_break_period: Optional. For a break operation, this is the proposed duration of seconds that the lease should continue before it is broken, between 0 and 60 seconds. This break period is only used if it is shorter than the time remaining on the lease. If longer, the time remaining on the lease is used. A new lease will not be available before the break period has expired, but the lease may be held for longer than the break period. 
If this header does not appear with a break operation, a fixed-duration lease breaks after the remaining lease period elapses, and an infinite lease breaks immediately. x_ms_proposed_lease_id: Optional for acquire, required for change. Proposed lease ID, in a GUID string format. ''' _validate_not_none('container_name', container_name) _validate_not_none('x_ms_lease_action', x_ms_lease_action) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '?restype=container&comp=lease' request.headers = [ ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-lease-action', _str_or_none(x_ms_lease_action)), ('x-ms-lease-duration', _str_or_none( x_ms_lease_duration if x_ms_lease_action == 'acquire'\ else None)), ('x-ms-lease-break-period', _str_or_none(x_ms_lease_break_period)), ('x-ms-proposed-lease-id', _str_or_none(x_ms_proposed_lease_id)), ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict_filter( response, filter=['x-ms-lease-id', 'x-ms-lease-time']) def list_blobs(self, container_name, prefix=None, marker=None, maxresults=None, include=None, delimiter=None): ''' Returns the list of blobs under the specified container. container_name: Name of existing container. prefix: Optional. Filters the results to return only blobs whose names begin with the specified prefix. marker: Optional. A string value that identifies the portion of the list to be returned with the next list operation. The operation returns a marker value within the response body if the list returned was not complete. The marker value may then be used in a subsequent call to request the next set of list items. The marker value is opaque to the client. maxresults: Optional. 
Specifies the maximum number of blobs to return, including all BlobPrefix elements. If the request does not specify maxresults or specifies a value greater than 5,000, the server will return up to 5,000 items. Setting maxresults to a value less than or equal to zero results in error response code 400 (Bad Request). include: Optional. Specifies one or more datasets to include in the response. To specify more than one of these options on the URI, you must separate each option with a comma. Valid values are: snapshots: Specifies that snapshots should be included in the enumeration. Snapshots are listed from oldest to newest in the response. metadata: Specifies that blob metadata be returned in the response. uncommittedblobs: Specifies that blobs for which blocks have been uploaded, but which have not been committed using Put Block List (REST API), be included in the response. copy: Version 2012-02-12 and newer. Specifies that metadata related to any current or previous Copy Blob operation should be included in the response. delimiter: Optional. When the request includes this parameter, the operation returns a BlobPrefix element in the response body that acts as a placeholder for all blobs whose names begin with the same substring up to the appearance of the delimiter character. The delimiter may be a single character or a string. 
''' _validate_not_none('container_name', container_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '?restype=container&comp=list' request.query = [ ('prefix', _str_or_none(prefix)), ('delimiter', _str_or_none(delimiter)), ('marker', _str_or_none(marker)), ('maxresults', _int_or_none(maxresults)), ('include', _str_or_none(include)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_blob_enum_results_list(response) def set_blob_service_properties(self, storage_service_properties, timeout=None): ''' Sets the properties of a storage account's Blob service, including Windows Azure Storage Analytics. You can also use this operation to set the default request version for all incoming requests that do not have a version specified. storage_service_properties: a StorageServiceProperties object. timeout: Optional. The timeout parameter is expressed in seconds. ''' _validate_not_none('storage_service_properties', storage_service_properties) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/?restype=service&comp=properties' request.query = [('timeout', _int_or_none(timeout))] request.body = _get_request_body( _convert_class_to_xml(storage_service_properties)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def get_blob_service_properties(self, timeout=None): ''' Gets the properties of a storage account's Blob service, including Windows Azure Storage Analytics. timeout: Optional. The timeout parameter is expressed in seconds. 
''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/?restype=service&comp=properties' request.query = [('timeout', _int_or_none(timeout))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response(response, StorageServiceProperties) def get_blob_properties(self, container_name, blob_name, x_ms_lease_id=None): ''' Returns all user-defined metadata, standard HTTP properties, and system properties for the blob. container_name: Name of existing container. blob_name: Name of existing blob. x_ms_lease_id: Required if the blob has an active lease. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'HEAD' request.host = self._get_host() request.path = '/' + _str(container_name) + '/' + _str(blob_name) + '' request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict(response) def set_blob_properties(self, container_name, blob_name, x_ms_blob_cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_md5=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_lease_id=None): ''' Sets system properties on the blob. container_name: Name of existing container. blob_name: Name of existing blob. x_ms_blob_cache_control: Optional. Modifies the cache control string for the blob. x_ms_blob_content_type: Optional. Sets the blob's content type. x_ms_blob_content_md5: Optional. Sets the blob's MD5 hash. x_ms_blob_content_encoding: Optional. 
Sets the blob's content encoding. x_ms_blob_content_language: Optional. Sets the blob's content language. x_ms_lease_id: Required if the blob has an active lease. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=properties' request.headers = [ ('x-ms-blob-cache-control', _str_or_none(x_ms_blob_cache_control)), ('x-ms-blob-content-type', _str_or_none(x_ms_blob_content_type)), ('x-ms-blob-content-md5', _str_or_none(x_ms_blob_content_md5)), ('x-ms-blob-content-encoding', _str_or_none(x_ms_blob_content_encoding)), ('x-ms-blob-content-language', _str_or_none(x_ms_blob_content_language)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def put_blob(self, container_name, blob_name, blob, x_ms_blob_type, content_encoding=None, content_language=None, content_md5=None, cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_blob_content_md5=None, x_ms_blob_cache_control=None, x_ms_meta_name_values=None, x_ms_lease_id=None, x_ms_blob_content_length=None, x_ms_blob_sequence_number=None): ''' Creates a new block blob or page blob, or updates the content of an existing block blob. See put_block_blob_from_* and put_page_blob_from_* for high level functions that handle the creation and upload of large blobs with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of blob to create or update. blob: For BlockBlob: Content of blob as bytes (size < 64MB). For larger size, you must call put_block and put_block_list to set content of blob. 
For PageBlob: Use None and call put_page to set content of blob. x_ms_blob_type: Required. Could be BlockBlob or PageBlob. content_encoding: Optional. Specifies which content encodings have been applied to the blob. This value is returned to the client when the Get Blob (REST API) operation is performed on the blob resource. The client can use this value when returned to decode the blob content. content_language: Optional. Specifies the natural languages used by this resource. content_md5: Optional. An MD5 hash of the blob content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. If the two hashes do not match, the operation will fail with error code 400 (Bad Request). cache_control: Optional. The Blob service stores this value but does not use or modify it. x_ms_blob_content_type: Optional. Set the blob's content type. x_ms_blob_content_encoding: Optional. Set the blob's content encoding. x_ms_blob_content_language: Optional. Set the blob's content language. x_ms_blob_content_md5: Optional. Set the blob's MD5 hash. x_ms_blob_cache_control: Optional. Sets the blob's cache control. x_ms_meta_name_values: A dict containing name, value for metadata. x_ms_lease_id: Required if the blob has an active lease. x_ms_blob_content_length: Required for page blobs. This header specifies the maximum size for the page blob, up to 1 TB. The page blob size must be aligned to a 512-byte boundary. x_ms_blob_sequence_number: Optional. Set for page blobs only. The sequence number is a user-controlled value that you can use to track requests. The value of the sequence number must be between 0 and 2^63 - 1. The default value is 0. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('x_ms_blob_type', x_ms_blob_type) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(container_name) + '/' + _str(blob_name) + '' request.headers = [ ('x-ms-blob-type', _str_or_none(x_ms_blob_type)), ('Content-Encoding', _str_or_none(content_encoding)), ('Content-Language', _str_or_none(content_language)), ('Content-MD5', _str_or_none(content_md5)), ('Cache-Control', _str_or_none(cache_control)), ('x-ms-blob-content-type', _str_or_none(x_ms_blob_content_type)), ('x-ms-blob-content-encoding', _str_or_none(x_ms_blob_content_encoding)), ('x-ms-blob-content-language', _str_or_none(x_ms_blob_content_language)), ('x-ms-blob-content-md5', _str_or_none(x_ms_blob_content_md5)), ('x-ms-blob-cache-control', _str_or_none(x_ms_blob_cache_control)), ('x-ms-meta-name-values', x_ms_meta_name_values), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-blob-content-length', _str_or_none(x_ms_blob_content_length)), ('x-ms-blob-sequence-number', _str_or_none(x_ms_blob_sequence_number)) ] request.body = _get_request_body_bytes_only('blob', blob) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def put_block_blob_from_path(self, container_name, blob_name, file_path, content_encoding=None, content_language=None, content_md5=None, cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_blob_content_md5=None, x_ms_blob_cache_control=None, x_ms_meta_name_values=None, x_ms_lease_id=None, progress_callback=None): ''' Creates a new block blob from a file path, or updates the content of an existing block blob, with automatic chunking and progress notifications. 
container_name: Name of existing container. blob_name: Name of blob to create or update. file_path: Path of the file to upload as the blob content. content_encoding: Optional. Specifies which content encodings have been applied to the blob. This value is returned to the client when the Get Blob (REST API) operation is performed on the blob resource. The client can use this value when returned to decode the blob content. content_language: Optional. Specifies the natural languages used by this resource. content_md5: Optional. An MD5 hash of the blob content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. If the two hashes do not match, the operation will fail with error code 400 (Bad Request). cache_control: Optional. The Blob service stores this value but does not use or modify it. x_ms_blob_content_type: Optional. Set the blob's content type. x_ms_blob_content_encoding: Optional. Set the blob's content encoding. x_ms_blob_content_language: Optional. Set the blob's content language. x_ms_blob_content_md5: Optional. Set the blob's MD5 hash. x_ms_blob_cache_control: Optional. Sets the blob's cache control. x_ms_meta_name_values: A dict containing name, value for metadata. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob, or None if the total size is unknown. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('file_path', file_path) count = path.getsize(file_path) with open(file_path, 'rb') as stream: self.put_block_blob_from_file(container_name, blob_name, stream, count, content_encoding, content_language, content_md5, cache_control, x_ms_blob_content_type, x_ms_blob_content_encoding, x_ms_blob_content_language, x_ms_blob_content_md5, x_ms_blob_cache_control, x_ms_meta_name_values, x_ms_lease_id, progress_callback) def put_block_blob_from_file(self, container_name, blob_name, stream, count=None, content_encoding=None, content_language=None, content_md5=None, cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_blob_content_md5=None, x_ms_blob_cache_control=None, x_ms_meta_name_values=None, x_ms_lease_id=None, progress_callback=None): ''' Creates a new block blob from a file/stream, or updates the content of an existing block blob, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of blob to create or update. stream: Opened file/stream to upload as the blob content. count: Number of bytes to read from the stream. This is optional, but should be supplied for optimal performance. content_encoding: Optional. Specifies which content encodings have been applied to the blob. This value is returned to the client when the Get Blob (REST API) operation is performed on the blob resource. The client can use this value when returned to decode the blob content. content_language: Optional. Specifies the natural languages used by this resource. content_md5: Optional. An MD5 hash of the blob content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. 
            If the two hashes do not match, the operation will fail with
            error code 400 (Bad Request).
        cache_control:
            Optional. The Blob service stores this value but does not use or
            modify it.
        x_ms_blob_content_type: Optional. Set the blob's content type.
        x_ms_blob_content_encoding: Optional. Set the blob's content encoding.
        x_ms_blob_content_language: Optional. Set the blob's content language.
        x_ms_blob_content_md5: Optional. Set the blob's MD5 hash.
        x_ms_blob_cache_control: Optional. Sets the blob's cache control.
        x_ms_meta_name_values: A dict containing name, value for metadata.
        x_ms_lease_id: Required if the blob has an active lease.
        progress_callback:
            Callback for progress with signature function(current, total)
            where current is the number of bytes transfered so far, and total
            is the size of the blob, or None if the total size is unknown.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('stream', stream)
        # Small payload (known size below the 64MB single-request limit):
        # upload with one Put Blob call.  NOTE(review): count == 0 is falsy
        # and therefore takes the chunked path below.
        if count and count < self._BLOB_MAX_DATA_SIZE:
            if progress_callback:
                progress_callback(0, count)
            data = stream.read(count)
            self.put_blob(container_name, blob_name, data, 'BlockBlob',
                          content_encoding, content_language, content_md5,
                          cache_control, x_ms_blob_content_type,
                          x_ms_blob_content_encoding,
                          x_ms_blob_content_language, x_ms_blob_content_md5,
                          x_ms_blob_cache_control, x_ms_meta_name_values,
                          x_ms_lease_id)
            if progress_callback:
                progress_callback(count, count)
        else:
            # Large or unknown size: create the blob, then upload 4MB blocks
            # and commit them with Put Block List.
            if progress_callback:
                progress_callback(0, count)
            self.put_blob(container_name, blob_name, None, 'BlockBlob',
                          content_encoding, content_language, content_md5,
                          cache_control, x_ms_blob_content_type,
                          x_ms_blob_content_encoding,
                          x_ms_blob_content_language, x_ms_blob_content_md5,
                          x_ms_blob_cache_control, x_ms_meta_name_values,
                          x_ms_lease_id)
            remain_bytes = count   # None when total size is unknown
            block_ids = []
            block_index = 0
            index = 0              # bytes uploaded so far (for progress)
            while True:
                # Read a full chunk, or just the remainder when the caller
                # supplied a byte count.
                request_count = self._BLOB_MAX_CHUNK_DATA_SIZE\
                    if remain_bytes is None else min(
                        remain_bytes,
                        self._BLOB_MAX_CHUNK_DATA_SIZE)
                data = stream.read(request_count)
                if data:
                    length = len(data)
                    index += length
                    remain_bytes = remain_bytes - \
                        length if remain_bytes else None
                    # Block ids must be uniform length; zero-pad the index.
                    block_id = '{0:08d}'.format(block_index)
                    self.put_block(container_name, blob_name,
                                   data, block_id,
                                   x_ms_lease_id=x_ms_lease_id)
                    block_ids.append(block_id)
                    block_index += 1
                    if progress_callback:
                        progress_callback(index, count)
                else:
                    # EOF (or remain_bytes exhausted) ends the upload loop.
                    break
            self.put_block_list(container_name, blob_name, block_ids,
                                content_md5, x_ms_blob_cache_control,
                                x_ms_blob_content_type,
                                x_ms_blob_content_encoding,
                                x_ms_blob_content_language,
                                x_ms_blob_content_md5, x_ms_meta_name_values,
                                x_ms_lease_id)

    def put_block_blob_from_bytes(self, container_name, blob_name, blob,
                                  index=0, count=None, content_encoding=None,
                                  content_language=None, content_md5=None,
                                  cache_control=None,
                                  x_ms_blob_content_type=None,
                                  x_ms_blob_content_encoding=None,
                                  x_ms_blob_content_language=None,
                                  x_ms_blob_content_md5=None,
                                  x_ms_blob_cache_control=None,
                                  x_ms_meta_name_values=None,
                                  x_ms_lease_id=None,
                                  progress_callback=None):
        '''
        Creates a new block blob from an array of bytes, or updates the
        content of an existing block blob, with automatic chunking and
        progress notifications.

        container_name: Name of existing container.
        blob_name: Name of blob to create or update.
        blob: Content of blob as an array of bytes.
        index: Start index in the array of bytes.
        count:
            Number of bytes to upload. Set to None or negative value to upload
            all bytes starting from index.
        content_encoding:
            Optional. Specifies which content encodings have been applied to
            the blob. This value is returned to the client when the Get Blob
            (REST API) operation is performed on the blob resource. The client
            can use this value when returned to decode the blob content.
        content_language:
            Optional. Specifies the natural languages used by this resource.
        content_md5:
            Optional. An MD5 hash of the blob content. This hash is used to
            verify the integrity of the blob during transport.
When this header is specified, the storage service checks the hash that has arrived with the one that was sent. If the two hashes do not match, the operation will fail with error code 400 (Bad Request). cache_control: Optional. The Blob service stores this value but does not use or modify it. x_ms_blob_content_type: Optional. Set the blob's content type. x_ms_blob_content_encoding: Optional. Set the blob's content encoding. x_ms_blob_content_language: Optional. Set the blob's content language. x_ms_blob_content_md5: Optional. Set the blob's MD5 hash. x_ms_blob_cache_control: Optional. Sets the blob's cache control. x_ms_meta_name_values: A dict containing name, value for metadata. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob, or None if the total size is unknown. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('blob', blob) _validate_not_none('index', index) _validate_type_bytes('blob', blob) if index < 0: raise TypeError(_ERROR_VALUE_NEGATIVE.format('index')) if count is None or count < 0: count = len(blob) - index if count < self._BLOB_MAX_DATA_SIZE: if progress_callback: progress_callback(0, count) data = blob[index: index + count] self.put_blob(container_name, blob_name, data, 'BlockBlob', content_encoding, content_language, content_md5, cache_control, x_ms_blob_content_type, x_ms_blob_content_encoding, x_ms_blob_content_language, x_ms_blob_content_md5, x_ms_blob_cache_control, x_ms_meta_name_values, x_ms_lease_id) if progress_callback: progress_callback(count, count) else: stream = BytesIO(blob) stream.seek(index) self.put_block_blob_from_file(container_name, blob_name, stream, count, content_encoding, content_language, content_md5, cache_control, x_ms_blob_content_type, x_ms_blob_content_encoding, 
x_ms_blob_content_language, x_ms_blob_content_md5, x_ms_blob_cache_control, x_ms_meta_name_values, x_ms_lease_id, progress_callback) def put_block_blob_from_text(self, container_name, blob_name, text, text_encoding='utf-8', content_encoding=None, content_language=None, content_md5=None, cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_blob_content_md5=None, x_ms_blob_cache_control=None, x_ms_meta_name_values=None, x_ms_lease_id=None, progress_callback=None): ''' Creates a new block blob from str/unicode, or updates the content of an existing block blob, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of blob to create or update. text: Text to upload to the blob. text_encoding: Encoding to use to convert the text to bytes. content_encoding: Optional. Specifies which content encodings have been applied to the blob. This value is returned to the client when the Get Blob (REST API) operation is performed on the blob resource. The client can use this value when returned to decode the blob content. content_language: Optional. Specifies the natural languages used by this resource. content_md5: Optional. An MD5 hash of the blob content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. If the two hashes do not match, the operation will fail with error code 400 (Bad Request). cache_control: Optional. The Blob service stores this value but does not use or modify it. x_ms_blob_content_type: Optional. Set the blob's content type. x_ms_blob_content_encoding: Optional. Set the blob's content encoding. x_ms_blob_content_language: Optional. Set the blob's content language. x_ms_blob_content_md5: Optional. Set the blob's MD5 hash. x_ms_blob_cache_control: Optional. Sets the blob's cache control. 
x_ms_meta_name_values: A dict containing name, value for metadata. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob, or None if the total size is unknown. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('text', text) if not isinstance(text, bytes): _validate_not_none('text_encoding', text_encoding) text = text.encode(text_encoding) self.put_block_blob_from_bytes(container_name, blob_name, text, 0, len(text), content_encoding, content_language, content_md5, cache_control, x_ms_blob_content_type, x_ms_blob_content_encoding, x_ms_blob_content_language, x_ms_blob_content_md5, x_ms_blob_cache_control, x_ms_meta_name_values, x_ms_lease_id, progress_callback) def put_page_blob_from_path(self, container_name, blob_name, file_path, content_encoding=None, content_language=None, content_md5=None, cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_blob_content_md5=None, x_ms_blob_cache_control=None, x_ms_meta_name_values=None, x_ms_lease_id=None, x_ms_blob_sequence_number=None, progress_callback=None): ''' Creates a new page blob from a file path, or updates the content of an existing page blob, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of blob to create or update. file_path: Path of the file to upload as the blob content. content_encoding: Optional. Specifies which content encodings have been applied to the blob. This value is returned to the client when the Get Blob (REST API) operation is performed on the blob resource. The client can use this value when returned to decode the blob content. content_language: Optional. Specifies the natural languages used by this resource. 
content_md5: Optional. An MD5 hash of the blob content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. If the two hashes do not match, the operation will fail with error code 400 (Bad Request). cache_control: Optional. The Blob service stores this value but does not use or modify it. x_ms_blob_content_type: Optional. Set the blob's content type. x_ms_blob_content_encoding: Optional. Set the blob's content encoding. x_ms_blob_content_language: Optional. Set the blob's content language. x_ms_blob_content_md5: Optional. Set the blob's MD5 hash. x_ms_blob_cache_control: Optional. Sets the blob's cache control. x_ms_meta_name_values: A dict containing name, value for metadata. x_ms_lease_id: Required if the blob has an active lease. x_ms_blob_sequence_number: Optional. Set for page blobs only. The sequence number is a user-controlled value that you can use to track requests. The value of the sequence number must be between 0 and 2^63 - 1. The default value is 0. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob, or None if the total size is unknown. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('file_path', file_path) count = path.getsize(file_path) with open(file_path, 'rb') as stream: self.put_page_blob_from_file(container_name, blob_name, stream, count, content_encoding, content_language, content_md5, cache_control, x_ms_blob_content_type, x_ms_blob_content_encoding, x_ms_blob_content_language, x_ms_blob_content_md5, x_ms_blob_cache_control, x_ms_meta_name_values, x_ms_lease_id, x_ms_blob_sequence_number, progress_callback) def put_page_blob_from_file(self, container_name, blob_name, stream, count, content_encoding=None, content_language=None, content_md5=None, cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_blob_content_md5=None, x_ms_blob_cache_control=None, x_ms_meta_name_values=None, x_ms_lease_id=None, x_ms_blob_sequence_number=None, progress_callback=None): ''' Creates a new page blob from a file/stream, or updates the content of an existing page blob, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of blob to create or update. stream: Opened file/stream to upload as the blob content. count: Number of bytes to read from the stream. This is required, a page blob cannot be created if the count is unknown. content_encoding: Optional. Specifies which content encodings have been applied to the blob. This value is returned to the client when the Get Blob (REST API) operation is performed on the blob resource. The client can use this value when returned to decode the blob content. content_language: Optional. Specifies the natural languages used by this resource. content_md5: Optional. An MD5 hash of the blob content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. 
If the two hashes do not match, the operation will fail with error code 400 (Bad Request). cache_control: Optional. The Blob service stores this value but does not use or modify it. x_ms_blob_content_type: Optional. Set the blob's content type. x_ms_blob_content_encoding: Optional. Set the blob's content encoding. x_ms_blob_content_language: Optional. Set the blob's content language. x_ms_blob_content_md5: Optional. Set the blob's MD5 hash. x_ms_blob_cache_control: Optional. Sets the blob's cache control. x_ms_meta_name_values: A dict containing name, value for metadata. x_ms_lease_id: Required if the blob has an active lease. x_ms_blob_sequence_number: Optional. Set for page blobs only. The sequence number is a user-controlled value that you can use to track requests. The value of the sequence number must be between 0 and 2^63 - 1. The default value is 0. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob, or None if the total size is unknown. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('stream', stream) _validate_not_none('count', count) if count < 0: raise TypeError(_ERROR_VALUE_NEGATIVE.format('count')) if count % _PAGE_SIZE != 0: raise TypeError(_ERROR_PAGE_BLOB_SIZE_ALIGNMENT.format(count)) if progress_callback: progress_callback(0, count) self.put_blob(container_name, blob_name, b'', 'PageBlob', content_encoding, content_language, content_md5, cache_control, x_ms_blob_content_type, x_ms_blob_content_encoding, x_ms_blob_content_language, x_ms_blob_content_md5, x_ms_blob_cache_control, x_ms_meta_name_values, x_ms_lease_id, count, x_ms_blob_sequence_number) remain_bytes = count page_start = 0 while True: request_count = min(remain_bytes, self._BLOB_MAX_CHUNK_DATA_SIZE) data = stream.read(request_count) if data: length = len(data) remain_bytes = remain_bytes - length page_end = page_start + length - 1 self.put_page(container_name, blob_name, data, 'bytes={0}-{1}'.format(page_start, page_end), 'update', x_ms_lease_id=x_ms_lease_id) page_start = page_start + length if progress_callback: progress_callback(page_start, count) else: break def put_page_blob_from_bytes(self, container_name, blob_name, blob, index=0, count=None, content_encoding=None, content_language=None, content_md5=None, cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_blob_content_md5=None, x_ms_blob_cache_control=None, x_ms_meta_name_values=None, x_ms_lease_id=None, x_ms_blob_sequence_number=None, progress_callback=None): ''' Creates a new page blob from an array of bytes, or updates the content of an existing page blob, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of blob to create or update. blob: Content of blob as an array of bytes. index: Start index in the array of bytes. count: Number of bytes to upload. 
Set to None or negative value to upload all bytes starting from index. content_encoding: Optional. Specifies which content encodings have been applied to the blob. This value is returned to the client when the Get Blob (REST API) operation is performed on the blob resource. The client can use this value when returned to decode the blob content. content_language: Optional. Specifies the natural languages used by this resource. content_md5: Optional. An MD5 hash of the blob content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. If the two hashes do not match, the operation will fail with error code 400 (Bad Request). cache_control: Optional. The Blob service stores this value but does not use or modify it. x_ms_blob_content_type: Optional. Set the blob's content type. x_ms_blob_content_encoding: Optional. Set the blob's content encoding. x_ms_blob_content_language: Optional. Set the blob's content language. x_ms_blob_content_md5: Optional. Set the blob's MD5 hash. x_ms_blob_cache_control: Optional. Sets the blob's cache control. x_ms_meta_name_values: A dict containing name, value for metadata. x_ms_lease_id: Required if the blob has an active lease. x_ms_blob_sequence_number: Optional. Set for page blobs only. The sequence number is a user-controlled value that you can use to track requests. The value of the sequence number must be between 0 and 2^63 - 1. The default value is 0. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob, or None if the total size is unknown. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('blob', blob) _validate_type_bytes('blob', blob) if index < 0: raise TypeError(_ERROR_VALUE_NEGATIVE.format('index')) if count is None or count < 0: count = len(blob) - index stream = BytesIO(blob) stream.seek(index) self.put_page_blob_from_file(container_name, blob_name, stream, count, content_encoding, content_language, content_md5, cache_control, x_ms_blob_content_type, x_ms_blob_content_encoding, x_ms_blob_content_language, x_ms_blob_content_md5, x_ms_blob_cache_control, x_ms_meta_name_values, x_ms_lease_id, x_ms_blob_sequence_number, progress_callback) def get_blob(self, container_name, blob_name, snapshot=None, x_ms_range=None, x_ms_lease_id=None, x_ms_range_get_content_md5=None): ''' Reads or downloads a blob from the system, including its metadata and properties. See get_blob_to_* for high level functions that handle the download of large blobs with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of existing blob. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_range: Optional. Return only the bytes of the blob in the specified range. x_ms_lease_id: Required if the blob has an active lease. x_ms_range_get_content_md5: Optional. When this header is set to true and specified together with the Range header, the service returns the MD5 hash for the range, as long as the range is less than or equal to 4 MB in size. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(container_name) + '/' + _str(blob_name) + '' request.headers = [ ('x-ms-range', _str_or_none(x_ms_range)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-range-get-content-md5', _str_or_none(x_ms_range_get_content_md5)) ] request.query = [('snapshot', _str_or_none(snapshot))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request, None) return _create_blob_result(response) def get_blob_to_path(self, container_name, blob_name, file_path, open_mode='wb', snapshot=None, x_ms_lease_id=None, progress_callback=None): ''' Downloads a blob to a file path, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of existing blob. file_path: Path of file to write to. open_mode: Mode to use when opening the file. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('file_path', file_path) _validate_not_none('open_mode', open_mode) with open(file_path, open_mode) as stream: self.get_blob_to_file(container_name, blob_name, stream, snapshot, x_ms_lease_id, progress_callback) def get_blob_to_file(self, container_name, blob_name, stream, snapshot=None, x_ms_lease_id=None, progress_callback=None): ''' Downloads a blob to a file/stream, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of existing blob. stream: Opened file/stream to write to. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('stream', stream) props = self.get_blob_properties(container_name, blob_name) blob_size = int(props['content-length']) if blob_size < self._BLOB_MAX_DATA_SIZE: if progress_callback: progress_callback(0, blob_size) data = self.get_blob(container_name, blob_name, snapshot, x_ms_lease_id=x_ms_lease_id) stream.write(data) if progress_callback: progress_callback(blob_size, blob_size) else: if progress_callback: progress_callback(0, blob_size) index = 0 while index < blob_size: chunk_range = 'bytes={0}-{1}'.format( index, index + self._BLOB_MAX_CHUNK_DATA_SIZE - 1) data = self.get_blob( container_name, blob_name, x_ms_range=chunk_range) length = len(data) index += length if length > 0: stream.write(data) if progress_callback: progress_callback(index, blob_size) if length < self._BLOB_MAX_CHUNK_DATA_SIZE: break else: break def get_blob_to_bytes(self, container_name, blob_name, snapshot=None, x_ms_lease_id=None, progress_callback=None): ''' Downloads a blob as an array of bytes, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of existing blob. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) stream = BytesIO() self.get_blob_to_file(container_name, blob_name, stream, snapshot, x_ms_lease_id, progress_callback) return stream.getvalue() def get_blob_to_text(self, container_name, blob_name, text_encoding='utf-8', snapshot=None, x_ms_lease_id=None, progress_callback=None): ''' Downloads a blob as unicode text, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of existing blob. text_encoding: Encoding to use when decoding the blob data. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('text_encoding', text_encoding) result = self.get_blob_to_bytes(container_name, blob_name, snapshot, x_ms_lease_id, progress_callback) return result.decode(text_encoding) def get_blob_metadata(self, container_name, blob_name, snapshot=None, x_ms_lease_id=None): ''' Returns all user-defined metadata for the specified blob or snapshot. container_name: Name of existing container. blob_name: Name of existing blob. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_lease_id: Required if the blob has an active lease. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=metadata' request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))] request.query = [('snapshot', _str_or_none(snapshot))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict_prefix(response, prefixes=['x-ms-meta']) def set_blob_metadata(self, container_name, blob_name, x_ms_meta_name_values=None, x_ms_lease_id=None): ''' Sets user-defined metadata for the specified blob as one or more name-value pairs. container_name: Name of existing container. blob_name: Name of existing blob. x_ms_meta_name_values: Dict containing name and value pairs. x_ms_lease_id: Required if the blob has an active lease. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=metadata' request.headers = [ ('x-ms-meta-name-values', x_ms_meta_name_values), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def lease_blob(self, container_name, blob_name, x_ms_lease_action, x_ms_lease_id=None, x_ms_lease_duration=60, x_ms_lease_break_period=None, x_ms_proposed_lease_id=None): ''' Establishes and manages a one-minute lock on a blob for write operations. container_name: Name of existing container. 
blob_name: Name of existing blob. x_ms_lease_action: Required. Possible values: acquire|renew|release|break|change x_ms_lease_id: Required if the blob has an active lease. x_ms_lease_duration: Specifies the duration of the lease, in seconds, or negative one (-1) for a lease that never expires. A non-infinite lease can be between 15 and 60 seconds. A lease duration cannot be changed using renew or change. For backwards compatibility, the default is 60, and the value is only used on an acquire operation. x_ms_lease_break_period: Optional. For a break operation, this is the proposed duration of seconds that the lease should continue before it is broken, between 0 and 60 seconds. This break period is only used if it is shorter than the time remaining on the lease. If longer, the time remaining on the lease is used. A new lease will not be available before the break period has expired, but the lease may be held for longer than the break period. If this header does not appear with a break operation, a fixed-duration lease breaks after the remaining lease period elapses, and an infinite lease breaks immediately. x_ms_proposed_lease_id: Optional for acquire, required for change. Proposed lease ID, in a GUID string format. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('x_ms_lease_action', x_ms_lease_action) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=lease' request.headers = [ ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-lease-action', _str_or_none(x_ms_lease_action)), ('x-ms-lease-duration', _str_or_none(x_ms_lease_duration\ if x_ms_lease_action == 'acquire' else None)), ('x-ms-lease-break-period', _str_or_none(x_ms_lease_break_period)), ('x-ms-proposed-lease-id', _str_or_none(x_ms_proposed_lease_id)), ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict_filter( response, filter=['x-ms-lease-id', 'x-ms-lease-time']) def snapshot_blob(self, container_name, blob_name, x_ms_meta_name_values=None, if_modified_since=None, if_unmodified_since=None, if_match=None, if_none_match=None, x_ms_lease_id=None): ''' Creates a read-only snapshot of a blob. container_name: Name of existing container. blob_name: Name of existing blob. x_ms_meta_name_values: Optional. Dict containing name and value pairs. if_modified_since: Optional. Datetime string. if_unmodified_since: DateTime string. if_match: Optional. snapshot the blob only if its ETag value matches the value specified. if_none_match: Optional. An ETag value x_ms_lease_id: Required if the blob has an active lease. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=snapshot' request.headers = [ ('x-ms-meta-name-values', x_ms_meta_name_values), ('If-Modified-Since', _str_or_none(if_modified_since)), ('If-Unmodified-Since', _str_or_none(if_unmodified_since)), ('If-Match', _str_or_none(if_match)), ('If-None-Match', _str_or_none(if_none_match)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict_filter( response, filter=['x-ms-snapshot', 'etag', 'last-modified']) def copy_blob(self, container_name, blob_name, x_ms_copy_source, x_ms_meta_name_values=None, x_ms_source_if_modified_since=None, x_ms_source_if_unmodified_since=None, x_ms_source_if_match=None, x_ms_source_if_none_match=None, if_modified_since=None, if_unmodified_since=None, if_match=None, if_none_match=None, x_ms_lease_id=None, x_ms_source_lease_id=None): ''' Copies a blob to a destination within the storage account. container_name: Name of existing container. blob_name: Name of existing blob. x_ms_copy_source: URL up to 2 KB in length that specifies a blob. A source blob in the same account can be private, but a blob in another account must be public or accept credentials included in this URL, such as a Shared Access Signature. Examples: https://myaccount.blob.core.windows.net/mycontainer/myblob https://myaccount.blob.core.windows.net/mycontainer/myblob?snapshot=<DateTime> x_ms_meta_name_values: Optional. Dict containing name and value pairs. x_ms_source_if_modified_since: Optional. An ETag value. 
Specify this conditional header to copy the source blob only if its ETag matches the value specified. x_ms_source_if_unmodified_since: Optional. An ETag value. Specify this conditional header to copy the blob only if its ETag does not match the value specified. x_ms_source_if_match: Optional. A DateTime value. Specify this conditional header to copy the blob only if the source blob has been modified since the specified date/time. x_ms_source_if_none_match: Optional. An ETag value. Specify this conditional header to copy the source blob only if its ETag matches the value specified. if_modified_since: Optional. Datetime string. if_unmodified_since: DateTime string. if_match: Optional. Snapshot the blob only if its ETag value matches the value specified. if_none_match: Optional. An ETag value x_ms_lease_id: Required if the blob has an active lease. x_ms_source_lease_id: Optional. Specify this to perform the Copy Blob operation only if the lease ID given matches the active lease ID of the source blob. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('x_ms_copy_source', x_ms_copy_source) if x_ms_copy_source.startswith('/'): # Backwards compatibility for earlier versions of the SDK where # the copy source can be in the following formats: # - Blob in named container: # /accountName/containerName/blobName # - Snapshot in named container: # /accountName/containerName/blobName?snapshot=<DateTime> # - Blob in root container: # /accountName/blobName # - Snapshot in root container: # /accountName/blobName?snapshot=<DateTime> account, _, source =\ x_ms_copy_source.partition('/')[2].partition('/') x_ms_copy_source = self.protocol + '://' + \ account + self.host_base + '/' + source request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(container_name) + '/' + _str(blob_name) + '' request.headers = [ ('x-ms-copy-source', _str_or_none(x_ms_copy_source)), ('x-ms-meta-name-values', x_ms_meta_name_values), ('x-ms-source-if-modified-since', _str_or_none(x_ms_source_if_modified_since)), ('x-ms-source-if-unmodified-since', _str_or_none(x_ms_source_if_unmodified_since)), ('x-ms-source-if-match', _str_or_none(x_ms_source_if_match)), ('x-ms-source-if-none-match', _str_or_none(x_ms_source_if_none_match)), ('If-Modified-Since', _str_or_none(if_modified_since)), ('If-Unmodified-Since', _str_or_none(if_unmodified_since)), ('If-Match', _str_or_none(if_match)), ('If-None-Match', _str_or_none(if_none_match)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-source-lease-id', _str_or_none(x_ms_source_lease_id)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict(response) def abort_copy_blob(self, container_name, blob_name, x_ms_copy_id, 
x_ms_lease_id=None): ''' Aborts a pending copy_blob operation, and leaves a destination blob with zero length and full metadata. container_name: Name of destination container. blob_name: Name of destination blob. x_ms_copy_id: Copy identifier provided in the x-ms-copy-id of the original copy_blob operation. x_ms_lease_id: Required if the destination blob has an active infinite lease. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('x_ms_copy_id', x_ms_copy_id) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(container_name) + '/' + \ _str(blob_name) + '?comp=copy©id=' + \ _str(x_ms_copy_id) request.headers = [ ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-copy-action', 'abort'), ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def delete_blob(self, container_name, blob_name, snapshot=None, x_ms_lease_id=None): ''' Marks the specified blob or snapshot for deletion. The blob is later deleted during garbage collection. To mark a specific snapshot for deletion provide the date/time of the snapshot via the snapshot parameter. container_name: Name of existing container. blob_name: Name of existing blob. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to delete. x_ms_lease_id: Required if the blob has an active lease. 
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(container_name) + '/' + _str(blob_name) + ''
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.query = [('snapshot', _str_or_none(snapshot))]
        # Helpers rewrite path/query for the storage emulator and then add
        # the signed authorization headers; order of these two calls matters
        # because the signature covers the final path and headers.
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def put_block(self, container_name, blob_name, block, blockid,
                  content_md5=None, x_ms_lease_id=None):
        '''
        Creates a new block to be committed as part of a blob.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        block: Content of the block.
        blockid:
            Required. A value that identifies the block. The string must be
            less than or equal to 64 bytes in size.
        content_md5:
            Optional. An MD5 hash of the block content. This hash is used to
            verify the integrity of the blob during transport. When this
            header is specified, the storage service checks the hash that has
            arrived with the one that was sent.
        x_ms_lease_id: Required if the blob has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('block', block)
        _validate_not_none('blockid', blockid)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=block'
        request.headers = [
            ('Content-MD5', _str_or_none(content_md5)),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id))
        ]
        # The block id is transmitted base64-encoded in the query string.
        request.query = [('blockid', _encode_base64(_str_or_none(blockid)))]
        request.body = _get_request_body_bytes_only('block', block)
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def put_block_list(self, container_name, blob_name, block_list,
                       content_md5=None, x_ms_blob_cache_control=None,
                       x_ms_blob_content_type=None,
                       x_ms_blob_content_encoding=None,
                       x_ms_blob_content_language=None,
                       x_ms_blob_content_md5=None, x_ms_meta_name_values=None,
                       x_ms_lease_id=None):
        '''
        Writes a blob by specifying the list of block IDs that make up the
        blob. In order to be written as part of a blob, a block must have
        been successfully written to the server in a prior Put Block (REST
        API) operation.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        block_list: A str list containing the block ids.
        content_md5:
            Optional. An MD5 hash of the block content. This hash is used to
            verify the integrity of the blob during transport. When this
            header is specified, the storage service checks the hash that has
            arrived with the one that was sent.
        x_ms_blob_cache_control:
            Optional. Sets the blob's cache control. If specified, this
            property is stored with the blob and returned with a read request.
        x_ms_blob_content_type:
            Optional. Sets the blob's content type. If specified, this
            property is stored with the blob and returned with a read request.
        x_ms_blob_content_encoding:
            Optional. Sets the blob's content encoding. If specified, this
            property is stored with the blob and returned with a read request.
        x_ms_blob_content_language:
            Optional. Set the blob's content language. If specified, this
            property is stored with the blob and returned with a read request.
        x_ms_blob_content_md5:
            Optional. An MD5 hash of the blob content. Note that this hash is
            not validated, as the hashes for the individual blocks were
            validated when each was uploaded.
        x_ms_meta_name_values:
            Optional. Dict containing name and value pairs.
        x_ms_lease_id: Required if the blob has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('block_list', block_list)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=blocklist'
        # NOTE(review): x-ms-meta-name-values is deliberately passed through
        # without _str_or_none — the header-update helper presumably expands
        # the metadata dict; confirm before "normalizing" it.
        request.headers = [
            ('Content-MD5', _str_or_none(content_md5)),
            ('x-ms-blob-cache-control', _str_or_none(x_ms_blob_cache_control)),
            ('x-ms-blob-content-type', _str_or_none(x_ms_blob_content_type)),
            ('x-ms-blob-content-encoding',
             _str_or_none(x_ms_blob_content_encoding)),
            ('x-ms-blob-content-language',
             _str_or_none(x_ms_blob_content_language)),
            ('x-ms-blob-content-md5', _str_or_none(x_ms_blob_content_md5)),
            ('x-ms-meta-name-values', x_ms_meta_name_values),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id))
        ]
        request.body = _get_request_body(
            _convert_block_list_to_xml(block_list))
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def get_block_list(self, container_name, blob_name, snapshot=None,
                       blocklisttype=None, x_ms_lease_id=None):
        '''
        Retrieves the list of blocks that have been uploaded as part of a
        block blob.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        snapshot:
            Optional. Datetime to determine the time to retrieve the blocks.
        blocklisttype:
            Specifies whether to return the list of committed blocks, the
            list of uncommitted blocks, or both lists together. Valid values
            are: committed, uncommitted, or all.
        x_ms_lease_id: Required if the blob has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=blocklist'
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.query = [
            ('snapshot', _str_or_none(snapshot)),
            ('blocklisttype', _str_or_none(blocklisttype))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)
        # Response XML is converted into a block-list object by the helper.
        return _convert_response_to_block_list(response)

    def put_page(self, container_name, blob_name, page, x_ms_range,
                 x_ms_page_write, timeout=None, content_md5=None,
                 x_ms_lease_id=None, x_ms_if_sequence_number_lte=None,
                 x_ms_if_sequence_number_lt=None,
                 x_ms_if_sequence_number_eq=None,
                 if_modified_since=None, if_unmodified_since=None,
                 if_match=None, if_none_match=None):
        '''
        Writes a range of pages to a page blob.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        page: Content of the page.
        x_ms_range:
            Required. Specifies the range of bytes to be written as a page.
            Both the start and end of the range must be specified. Must be in
            format: bytes=startByte-endByte. Given that pages must be aligned
            with 512-byte boundaries, the start offset must be a modulus of
            512 and the end offset must be a modulus of 512-1. Examples of
            valid byte ranges are 0-511, 512-1023, etc.
        x_ms_page_write:
            Required. You may specify one of the following options:
                update (lower case):
                    Writes the bytes specified by the request body into the
                    specified range. The Range and Content-Length headers
                    must match to perform the update.
                clear (lower case):
                    Clears the specified range and releases the space used in
                    storage for that range. To clear a range, set the
                    Content-Length header to zero, and the Range header to a
                    value that indicates the range to clear, up to maximum
                    blob size.
        timeout: the timeout parameter is expressed in seconds.
        content_md5:
            Optional. An MD5 hash of the page content. This hash is used to
            verify the integrity of the page during transport. When this
            header is specified, the storage service compares the hash of the
            content that has arrived with the header value that was sent. If
            the two hashes do not match, the operation will fail with error
            code 400 (Bad Request).
        x_ms_lease_id: Required if the blob has an active lease.
        x_ms_if_sequence_number_lte:
            Optional. If the blob's sequence number is less than or equal to
            the specified value, the request proceeds; otherwise it fails.
        x_ms_if_sequence_number_lt:
            Optional. If the blob's sequence number is less than the
            specified value, the request proceeds; otherwise it fails.
        x_ms_if_sequence_number_eq:
            Optional. If the blob's sequence number is equal to the specified
            value, the request proceeds; otherwise it fails.
        if_modified_since:
            Optional. A DateTime value. Specify this conditional header to
            write the page only if the blob has been modified since the
            specified date/time. If the blob has not been modified, the Blob
            service fails.
        if_unmodified_since:
            Optional. A DateTime value. Specify this conditional header to
            write the page only if the blob has not been modified since the
            specified date/time. If the blob has been modified, the Blob
            service fails.
        if_match:
            Optional. An ETag value. Specify an ETag value for this
            conditional header to write the page only if the blob's ETag
            value matches the value specified. If the values do not match,
            the Blob service fails.
        if_none_match:
            Optional. An ETag value. Specify an ETag value for this
            conditional header to write the page only if the blob's ETag
            value does not match the value specified. If the values are
            identical, the Blob service fails.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('page', page)
        _validate_not_none('x_ms_range', x_ms_range)
        _validate_not_none('x_ms_page_write', x_ms_page_write)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=page'
        # NOTE(review): wire header is 'x-ms-if-sequence-number-le' while the
        # parameter is ..._lte — keep the header name; it is the on-the-wire
        # contract.
        request.headers = [
            ('x-ms-range', _str_or_none(x_ms_range)),
            ('Content-MD5', _str_or_none(content_md5)),
            ('x-ms-page-write', _str_or_none(x_ms_page_write)),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
            ('x-ms-if-sequence-number-le',
             _str_or_none(x_ms_if_sequence_number_lte)),
            ('x-ms-if-sequence-number-lt',
             _str_or_none(x_ms_if_sequence_number_lt)),
            ('x-ms-if-sequence-number-eq',
             _str_or_none(x_ms_if_sequence_number_eq)),
            ('If-Modified-Since', _str_or_none(if_modified_since)),
            ('If-Unmodified-Since', _str_or_none(if_unmodified_since)),
            ('If-Match', _str_or_none(if_match)),
            ('If-None-Match', _str_or_none(if_none_match))
        ]
        request.query = [('timeout', _int_or_none(timeout))]
        request.body = _get_request_body_bytes_only('page', page)
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def get_page_ranges(self, container_name, blob_name, snapshot=None,
                        range=None, x_ms_range=None, x_ms_lease_id=None):
        '''
        Retrieves the page ranges for a blob.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        snapshot:
            Optional. The snapshot parameter is an opaque DateTime value
            that, when present, specifies the blob snapshot to retrieve
            information from.
        range:
            Optional. Specifies the range of bytes over which to list ranges,
            inclusively. If omitted, then all ranges for the blob are
            returned. NOTE: this parameter shadows the builtin range(); it is
            kept for backward compatibility with keyword callers.
        x_ms_range:
            Optional. Specifies the range of bytes to be written as a page.
            Both the start and end of the range must be specified. Must be in
            format: bytes=startByte-endByte. Given that pages must be aligned
            with 512-byte boundaries, the start offset must be a modulus of
            512 and the end offset must be a modulus of 512-1. Examples of
            valid byte ranges are 0-511, 512-1023, etc.
        x_ms_lease_id: Required if the blob has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=pagelist'
        request.headers = [
            ('Range', _str_or_none(range)),
            ('x-ms-range', _str_or_none(x_ms_range)),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id))
        ]
        request.query = [('snapshot', _str_or_none(snapshot))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)
        return _parse_simple_list(response, PageList, PageRange, "page_ranges")


================================================
FILE: CustomScript/azure/storage/cloudstorageaccount.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- from azure.storage.blobservice import BlobService from azure.storage.tableservice import TableService from azure.storage.queueservice import QueueService class CloudStorageAccount(object): """ Provides a factory for creating the blob, queue, and table services with a common account name and account key. Users can either use the factory or can construct the appropriate service directly. """ def __init__(self, account_name=None, account_key=None): self.account_name = account_name self.account_key = account_key def create_blob_service(self): return BlobService(self.account_name, self.account_key) def create_table_service(self): return TableService(self.account_name, self.account_key) def create_queue_service(self): return QueueService(self.account_name, self.account_key) ================================================ FILE: CustomScript/azure/storage/queueservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- from azure import ( WindowsAzureConflictError, WindowsAzureError, DEV_QUEUE_HOST, QUEUE_SERVICE_HOST_BASE, xml_escape, _convert_class_to_xml, _dont_fail_not_exist, _dont_fail_on_exist, _get_request_body, _int_or_none, _parse_enum_results_list, _parse_response, _parse_response_for_dict_filter, _parse_response_for_dict_prefix, _str, _str_or_none, _update_request_uri_query_local_storage, _validate_not_none, _ERROR_CONFLICT, ) from azure.http import ( HTTPRequest, HTTP_RESPONSE_NO_CONTENT, ) from azure.storage import ( Queue, QueueEnumResults, QueueMessagesList, StorageServiceProperties, _update_storage_queue_header, ) from azure.storage.storageclient import _StorageClient class QueueService(_StorageClient): ''' This is the main class managing queue resources. ''' def __init__(self, account_name=None, account_key=None, protocol='https', host_base=QUEUE_SERVICE_HOST_BASE, dev_host=DEV_QUEUE_HOST): ''' account_name: your storage account name, required for all operations. account_key: your storage account key, required for all operations. protocol: Optional. Protocol. Defaults to http. host_base: Optional. Live host base url. Defaults to Azure url. Override this for on-premise. dev_host: Optional. Dev host url. Defaults to localhost. ''' super(QueueService, self).__init__( account_name, account_key, protocol, host_base, dev_host) def get_queue_service_properties(self, timeout=None): ''' Gets the properties of a storage account's Queue Service, including Windows Azure Storage Analytics. timeout: Optional. The timeout parameter is expressed in seconds. 
        '''
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/?restype=service&comp=properties'
        request.query = [('timeout', _int_or_none(timeout))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)
        return _parse_response(response, StorageServiceProperties)

    def list_queues(self, prefix=None, marker=None, maxresults=None,
                    include=None):
        '''
        Lists all of the queues in a given storage account.

        prefix:
            Filters the results to return only queues with names that begin
            with the specified prefix.
        marker:
            A string value that identifies the portion of the list to be
            returned with the next list operation. The operation returns a
            NextMarker element within the response body if the list returned
            was not complete. This value may then be used as a query
            parameter in a subsequent call to request the next portion of the
            list of queues. The marker value is opaque to the client.
        maxresults:
            Specifies the maximum number of queues to return. If maxresults
            is not specified, the server will return up to 5,000 items.
        include:
            Optional. Include this parameter to specify that the container's
            metadata be returned as part of the response body.
        '''
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/?comp=list'
        request.query = [
            ('prefix', _str_or_none(prefix)),
            ('marker', _str_or_none(marker)),
            ('maxresults', _int_or_none(maxresults)),
            ('include', _str_or_none(include))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)
        return _parse_enum_results_list(
            response, QueueEnumResults, "Queues", Queue)

    def create_queue(self, queue_name, x_ms_meta_name_values=None,
                     fail_on_exist=False):
        '''
        Creates a queue under the given account.

        queue_name: name of the queue.
        x_ms_meta_name_values:
            Optional. A dict containing name-value pairs to associate with
            the queue as metadata.
        fail_on_exist: Specify whether throw exception when queue exists.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + ''
        request.headers = [('x-ms-meta-name-values', x_ms_meta_name_values)]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        if not fail_on_exist:
            try:
                # A NO_CONTENT status is treated as "queue already existed"
                # (False); any other success means it was created (True).
                # NOTE(review): presumed service semantics — confirm against
                # the Queue REST API Create Queue documentation.
                response = self._perform_request(request)
                if response.status == HTTP_RESPONSE_NO_CONTENT:
                    return False
                return True
            except WindowsAzureError as ex:
                _dont_fail_on_exist(ex)
                return False
        else:
            response = self._perform_request(request)
            if response.status == HTTP_RESPONSE_NO_CONTENT:
                raise WindowsAzureConflictError(
                    _ERROR_CONFLICT.format(response.message))
            return True

    def delete_queue(self, queue_name, fail_not_exist=False):
        '''
        Permanently deletes the specified queue.

        queue_name: Name of the queue.
        fail_not_exist:
            Specify whether throw exception when queue doesn't exist.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + ''
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        if not fail_not_exist:
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                # Swallow only the "does not exist" error; re-raise others.
                _dont_fail_not_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def get_queue_metadata(self, queue_name):
        '''
        Retrieves user-defined metadata and queue properties on the specified
        queue. Metadata is associated with the queue as name-values pairs.

        queue_name: Name of the queue.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + '?comp=metadata'
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)
        # Only metadata and approximate-message-count headers are surfaced.
        return _parse_response_for_dict_prefix(
            response,
            prefixes=['x-ms-meta', 'x-ms-approximate-messages-count'])

    def set_queue_metadata(self, queue_name, x_ms_meta_name_values=None):
        '''
        Sets user-defined metadata on the specified queue. Metadata is
        associated with the queue as name-value pairs.

        queue_name: Name of the queue.
        x_ms_meta_name_values:
            Optional. A dict containing name-value pairs to associate with
            the queue as metadata.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + '?comp=metadata'
        request.headers = [('x-ms-meta-name-values', x_ms_meta_name_values)]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def put_message(self, queue_name, message_text, visibilitytimeout=None,
                    messagettl=None):
        '''
        Adds a new message to the back of the message queue. A visibility
        timeout can also be specified to make the message invisible until the
        visibility timeout expires. A message must be in a format that can be
        included in an XML request with UTF-8 encoding. The encoded message
        can be up to 64KB in size for versions 2011-08-18 and newer, or 8KB
        in size for previous versions.

        queue_name: Name of the queue.
        message_text: Message content.
        visibilitytimeout:
            Optional. If not specified, the default value is 0. Specifies the
            new visibility timeout value, in seconds, relative to server
            time. The new value must be larger than or equal to 0, and cannot
            be larger than 7 days. The visibility timeout of a message cannot
            be set to a value later than the expiry time. visibilitytimeout
            should be set to a value smaller than the time-to-live value.
        messagettl:
            Optional. Specifies the time-to-live interval for the message, in
            seconds. The maximum time-to-live allowed is 7 days. If this
            parameter is omitted, the default time-to-live is 7 days.
        '''
        _validate_not_none('queue_name', queue_name)
        _validate_not_none('message_text', message_text)
        request = HTTPRequest()
        request.method = 'POST'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + '/messages'
        request.query = [
            ('visibilitytimeout', _str_or_none(visibilitytimeout)),
            ('messagettl', _str_or_none(messagettl))
        ]
        # Message text is XML-escaped and wrapped in the QueueMessage
        # envelope required by the service.
        request.body = _get_request_body(
            '<?xml version="1.0" encoding="utf-8"?> \
<QueueMessage> \
    <MessageText>' + xml_escape(_str(message_text)) + '</MessageText> \
</QueueMessage>')
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def get_messages(self, queue_name, numofmessages=None,
                     visibilitytimeout=None):
        '''
        Retrieves one or more messages from the front of the queue.

        queue_name: Name of the queue.
        numofmessages:
            Optional. A nonzero integer value that specifies the number of
            messages to retrieve from the queue, up to a maximum of 32. If
            fewer are visible, the visible messages are returned. By default,
            a single message is retrieved from the queue with this operation.
        visibilitytimeout:
            Specifies the new visibility timeout value, in seconds, relative
            to server time. The new value must be larger than or equal to 1
            second, and cannot be larger than 7 days, or larger than 2 hours
            on REST protocol versions prior to version 2011-08-18. The
            visibility timeout of a message can be set to a value later than
            the expiry time.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + '/messages'
        request.query = [
            ('numofmessages', _str_or_none(numofmessages)),
            ('visibilitytimeout', _str_or_none(visibilitytimeout))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)
        return _parse_response(response, QueueMessagesList)

    def peek_messages(self, queue_name, numofmessages=None):
        '''
        Retrieves one or more messages from the front of the queue, but does
        not alter the visibility of the message.

        queue_name: Name of the queue.
        numofmessages:
            Optional. A nonzero integer value that specifies the number of
            messages to peek from the queue, up to a maximum of 32. By
            default, a single message is peeked from the queue with this
            operation.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + '/messages?peekonly=true'
        request.query = [('numofmessages', _str_or_none(numofmessages))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)
        return _parse_response(response, QueueMessagesList)

    def delete_message(self, queue_name, message_id, popreceipt):
        '''
        Deletes the specified message.

        queue_name: Name of the queue.
        message_id: Message to delete.
        popreceipt:
            Required. A valid pop receipt value returned from an earlier call
            to the Get Messages or Update Message operation.
        '''
        _validate_not_none('queue_name', queue_name)
        _validate_not_none('message_id', message_id)
        _validate_not_none('popreceipt', popreceipt)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + \
            _str(queue_name) + '/messages/' + _str(message_id) + ''
        request.query = [('popreceipt', _str_or_none(popreceipt))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def clear_messages(self, queue_name):
        '''
        Deletes all messages from the specified queue.

        queue_name: Name of the queue.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + '/messages'
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def update_message(self, queue_name, message_id, message_text, popreceipt,
                       visibilitytimeout):
        '''
        Updates the visibility timeout of a message. You can also use this
        operation to update the contents of a message.

        queue_name: Name of the queue.
        message_id: Message to update.
        message_text: Content of message.
        popreceipt:
            Required. A valid pop receipt value returned from an earlier call
            to the Get Messages or Update Message operation.
        visibilitytimeout:
            Required. Specifies the new visibility timeout value, in seconds,
            relative to server time. The new value must be larger than or
            equal to 0, and cannot be larger than 7 days. The visibility
            timeout of a message cannot be set to a value later than the
            expiry time. A message can be updated until it has been deleted
            or has expired.
        '''
        _validate_not_none('queue_name', queue_name)
        _validate_not_none('message_id', message_id)
        _validate_not_none('message_text', message_text)
        _validate_not_none('popreceipt', popreceipt)
        _validate_not_none('visibilitytimeout', visibilitytimeout)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(queue_name) + '/messages/' + _str(message_id) + ''
        request.query = [
            ('popreceipt', _str_or_none(popreceipt)),
            ('visibilitytimeout', _str_or_none(visibilitytimeout))
        ]
        request.body = _get_request_body(
            '<?xml version="1.0" encoding="utf-8"?> \
<QueueMessage> \
    <MessageText>' + xml_escape(_str(message_text)) + '</MessageText> \
</QueueMessage>')
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)
        # The new pop receipt and next-visible time come back as headers.
        return _parse_response_for_dict_filter(
            response,
            filter=['x-ms-popreceipt', 'x-ms-time-next-visible'])

    def set_queue_service_properties(self, storage_service_properties,
                                     timeout=None):
        '''
        Sets the properties of a storage account's Queue service, including
        Windows Azure Storage Analytics.

        storage_service_properties: StorageServiceProperties object.
        timeout: Optional. The timeout parameter is expressed in seconds.
        '''
        _validate_not_none('storage_service_properties',
                           storage_service_properties)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/?restype=service&comp=properties'
        request.query = [('timeout', _int_or_none(timeout))]
        request.body = _get_request_body(
            _convert_class_to_xml(storage_service_properties))
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)


================================================
FILE: CustomScript/azure/storage/sharedaccesssignature.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import _sign_string, url_quote
from azure.storage import X_MS_VERSION

#-------------------------------------------------------------------------
# Constants for the share access signature
SIGNED_START = 'st'
SIGNED_EXPIRY = 'se'
SIGNED_RESOURCE = 'sr'
SIGNED_PERMISSION = 'sp'
SIGNED_IDENTIFIER = 'si'
SIGNED_SIGNATURE = 'sig'
SIGNED_VERSION = 'sv'
RESOURCE_BLOB = 'b'
RESOURCE_CONTAINER = 'c'
SIGNED_RESOURCE_TYPE = 'resource'
SHARED_ACCESS_PERMISSION = 'permission'

#--------------------------------------------------------------------------


class WebResource(object):

    '''
    Class that stands for the resource to get the share access signature

    path: the resource path.
    properties: dict of name and values. Contains 2 item: resource type and
            permission
    request_url: the url of the webresource include all the queries.
    '''

    def __init__(self, path=None, request_url=None, properties=None):
        self.path = path
        self.properties = properties or {}
        self.request_url = request_url


class Permission(object):

    '''
    Permission class. Contains the path and query_string for the path.

    path: the resource path
    query_string: dict of name, values. Contains SIGNED_START, SIGNED_EXPIRY
            SIGNED_RESOURCE, SIGNED_PERMISSION, SIGNED_IDENTIFIER,
            SIGNED_SIGNATURE name values.
    '''

    def __init__(self, path=None, query_string=None):
        self.path = path
        self.query_string = query_string


class SharedAccessPolicy(object):

    ''' SharedAccessPolicy class. '''

    def __init__(self, access_policy, signed_identifier=None):
        self.id = signed_identifier
        self.access_policy = access_policy


class SharedAccessSignature(object):

    '''
    The main class used to do the signing and generating the signature.

    account_name:
        the storage account name used to generate shared access signature
    account_key: the access key to genenerate share access signature
    permission_set: the permission cache used to signed the request url.
    '''

    def __init__(self, account_name, account_key, permission_set=None):
        self.account_name = account_name
        self.account_key = account_key
        self.permission_set = permission_set

    def generate_signed_query_string(self, path, resource_type,
                                     shared_access_policy,
                                     version=X_MS_VERSION):
        '''
        Generates the query string for path, resource type and shared access
        policy.

        path: the resource
        resource_type: could be blob or container
        shared_access_policy: shared access policy
        version:
            x-ms-version for storage service, or None to get a signed query
            string compatible with pre 2012-02-12 clients, where the version
            is not included in the query string.
        '''
        query_string = {}
        if shared_access_policy.access_policy.start:
            query_string[
                SIGNED_START] = shared_access_policy.access_policy.start

        if version:
            query_string[SIGNED_VERSION] = version
        query_string[SIGNED_EXPIRY] = shared_access_policy.access_policy.expiry
        query_string[SIGNED_RESOURCE] = resource_type
        query_string[
            SIGNED_PERMISSION] = shared_access_policy.access_policy.permission
        if shared_access_policy.id:
            query_string[SIGNED_IDENTIFIER] = shared_access_policy.id
        query_string[SIGNED_SIGNATURE] = self._generate_signature(
            path, shared_access_policy, version)
        return query_string

    def sign_request(self, web_resource):
        ''' sign request to generate request_url with sharedaccesssignature
        info for web_resource.'''
        if self.permission_set:
            # Use the first cached permission that matches this resource
            # type, requested permission, and path.
            for shared_access_signature in self.permission_set:
                if self._permission_matches_request(
                    shared_access_signature, web_resource,
                    web_resource.properties[
                        SIGNED_RESOURCE_TYPE],
                        web_resource.properties[SHARED_ACCESS_PERMISSION]):
                    if web_resource.request_url.find('?') == -1:
                        web_resource.request_url += '?'
                    else:
                        web_resource.request_url += '&'

                    web_resource.request_url += self._convert_query_string(
                        shared_access_signature.query_string)
                    break
        return web_resource

    def _convert_query_string(self, query_string):
        ''' Converts query string to str. The order of name, values is very
        important and can't be wrong.'''
        convert_str = ''
        if SIGNED_START in query_string:
            convert_str += SIGNED_START + '=' + \
                url_quote(query_string[SIGNED_START]) + '&'
        convert_str += SIGNED_EXPIRY + '=' + \
            url_quote(query_string[SIGNED_EXPIRY]) + '&'
        convert_str += SIGNED_PERMISSION + '=' + \
            query_string[SIGNED_PERMISSION] + '&'
        convert_str += SIGNED_RESOURCE + '=' + \
            query_string[SIGNED_RESOURCE] + '&'

        if SIGNED_IDENTIFIER in query_string:
            convert_str += SIGNED_IDENTIFIER + '=' + \
                query_string[SIGNED_IDENTIFIER] + '&'
        if SIGNED_VERSION in query_string:
            convert_str += SIGNED_VERSION + '=' + \
                query_string[SIGNED_VERSION] + '&'
        convert_str += SIGNED_SIGNATURE + '=' + \
            url_quote(query_string[SIGNED_SIGNATURE]) + '&'
        return convert_str

    def _generate_signature(self, path, shared_access_policy, version):
        ''' Generates signature for a given path and shared access policy. '''

        def get_value_to_append(value, no_new_line=False):
            # Empty/None values still contribute a newline to the
            # string-to-sign unless no_new_line is set (last field).
            return_value = ''
            if value:
                return_value = value
            if not no_new_line:
                return_value += '\n'
            return return_value

        if path[0] != '/':
            path = '/' + path

        canonicalized_resource = '/' + self.account_name + path

        # Form the string to sign from shared_access_policy and canonicalized
        # resource. The order of values is important.
        string_to_sign = \
            (get_value_to_append(shared_access_policy.access_policy.permission) +
             get_value_to_append(shared_access_policy.access_policy.start) +
             get_value_to_append(shared_access_policy.access_policy.expiry) +
             get_value_to_append(canonicalized_resource))
        if version:
            string_to_sign += get_value_to_append(shared_access_policy.id)
            string_to_sign += get_value_to_append(version, True)
        else:
            string_to_sign += get_value_to_append(shared_access_policy.id, True)

        return self._sign(string_to_sign)

    def _permission_matches_request(self, shared_access_signature,
                                    web_resource, resource_type,
                                    required_permission):
        ''' Check whether requested permission matches given
        shared_access_signature, web_resource and resource type. '''
        required_resource_type = resource_type
        if required_resource_type == RESOURCE_BLOB:
            # A container-level signature also grants access to blobs, so a
            # blob request accepts either 'b' or 'c' signed resources.
            required_resource_type += RESOURCE_CONTAINER

        for name, value in shared_access_signature.query_string.items():
            if name == SIGNED_RESOURCE and \
                    required_resource_type.find(value) == -1:
                return False
            elif name == SIGNED_PERMISSION and \
                    required_permission.find(value) == -1:
                return False

        return web_resource.path.find(shared_access_signature.path) != -1

    def _sign(self, string_to_sign):
        ''' use HMAC-SHA256 to sign the string and convert it as base64
        encoded string. '''
        return _sign_string(self.account_key, string_to_sign)


================================================
FILE: CustomScript/azure/storage/storageclient.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- import os import sys from azure import ( WindowsAzureError, DEV_ACCOUNT_NAME, DEV_ACCOUNT_KEY, _ERROR_STORAGE_MISSING_INFO, ) from azure.http import HTTPError from azure.http.httpclient import _HTTPClient from azure.storage import _storage_error_handler #-------------------------------------------------------------------------- # constants for azure app setting environment variables AZURE_STORAGE_ACCOUNT = 'AZURE_STORAGE_ACCOUNT' AZURE_STORAGE_ACCESS_KEY = 'AZURE_STORAGE_ACCESS_KEY' EMULATED = 'EMULATED' #-------------------------------------------------------------------------- class _StorageClient(object): ''' This is the base class for BlobManager, TableManager and QueueManager. ''' def __init__(self, account_name=None, account_key=None, protocol='https', host_base='', dev_host=''): ''' account_name: your storage account name, required for all operations. account_key: your storage account key, required for all operations. protocol: Optional. Protocol. Defaults to http. host_base: Optional. Live host base url. Defaults to Azure url. Override this for on-premise. dev_host: Optional. Dev host url. Defaults to localhost. ''' self.account_name = account_name self.account_key = account_key self.requestid = None self.protocol = protocol self.host_base = host_base self.dev_host = dev_host # the app is not run in azure emulator or use default development # storage account and key if app is run in emulator. self.use_local_storage = False # check whether it is run in emulator. 
if EMULATED in os.environ: self.is_emulated = os.environ[EMULATED].lower() != 'false' else: self.is_emulated = False # get account_name and account key. If they are not set when # constructing, get the account and key from environment variables if # the app is not run in azure emulator or use default development # storage account and key if app is run in emulator. if not self.account_name or not self.account_key: if self.is_emulated: self.account_name = DEV_ACCOUNT_NAME self.account_key = DEV_ACCOUNT_KEY self.protocol = 'http' self.use_local_storage = True else: self.account_name = os.environ.get(AZURE_STORAGE_ACCOUNT) self.account_key = os.environ.get(AZURE_STORAGE_ACCESS_KEY) if not self.account_name or not self.account_key: raise WindowsAzureError(_ERROR_STORAGE_MISSING_INFO) self._httpclient = _HTTPClient( service_instance=self, account_key=self.account_key, account_name=self.account_name, protocol=self.protocol) self._batchclient = None self._filter = self._perform_request_worker def with_filter(self, filter): ''' Returns a new service which will process requests with the specified filter. Filtering operations can include logging, automatic retrying, etc... The filter is a lambda which receives the HTTPRequest and another lambda. The filter can perform any pre-processing on the request, pass it off to the next lambda, and then perform any post-processing on the response. ''' res = type(self)(self.account_name, self.account_key, self.protocol) old_filter = self._filter def new_filter(request): return filter(request, old_filter) res._filter = new_filter return res def set_proxy(self, host, port, user=None, password=None): ''' Sets the proxy server host and port for the HTTP CONNECT Tunnelling. host: Address of the proxy. Ex: '192.168.0.100' port: Port of the proxy. Ex: 6000 user: User for proxy authorization. password: Password for proxy authorization. 
''' self._httpclient.set_proxy(host, port, user, password) def _get_host(self): if self.use_local_storage: return self.dev_host else: return self.account_name + self.host_base def _perform_request_worker(self, request): return self._httpclient.perform_request(request) def _perform_request(self, request, text_encoding='utf-8'): ''' Sends the request and return response. Catches HTTPError and hand it to error handler ''' try: if self._batchclient is not None: return self._batchclient.insert_request_to_batch(request) else: resp = self._filter(request) if sys.version_info >= (3,) and isinstance(resp, bytes) and \ text_encoding: resp = resp.decode(text_encoding) except HTTPError as ex: _storage_error_handler(ex) return resp ================================================ FILE: CustomScript/azure/storage/tableservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#-------------------------------------------------------------------------- from azure import ( WindowsAzureError, TABLE_SERVICE_HOST_BASE, DEV_TABLE_HOST, _convert_class_to_xml, _convert_response_to_feeds, _dont_fail_not_exist, _dont_fail_on_exist, _get_request_body, _int_or_none, _parse_response, _parse_response_for_dict, _parse_response_for_dict_filter, _str, _str_or_none, _update_request_uri_query_local_storage, _validate_not_none, ) from azure.http import HTTPRequest from azure.http.batchclient import _BatchClient from azure.storage import ( StorageServiceProperties, _convert_entity_to_xml, _convert_response_to_entity, _convert_table_to_xml, _convert_xml_to_entity, _convert_xml_to_table, _sign_storage_table_request, _update_storage_table_header, ) from azure.storage.storageclient import _StorageClient class TableService(_StorageClient): ''' This is the main class managing Table resources. ''' def __init__(self, account_name=None, account_key=None, protocol='https', host_base=TABLE_SERVICE_HOST_BASE, dev_host=DEV_TABLE_HOST): ''' account_name: your storage account name, required for all operations. account_key: your storage account key, required for all operations. protocol: Optional. Protocol. Defaults to http. host_base: Optional. Live host base url. Defaults to Azure url. Override this for on-premise. dev_host: Optional. Dev host url. Defaults to localhost. 
''' super(TableService, self).__init__( account_name, account_key, protocol, host_base, dev_host) def begin_batch(self): if self._batchclient is None: self._batchclient = _BatchClient( service_instance=self, account_key=self.account_key, account_name=self.account_name) return self._batchclient.begin_batch() def commit_batch(self): try: ret = self._batchclient.commit_batch() finally: self._batchclient = None return ret def cancel_batch(self): self._batchclient = None def get_table_service_properties(self): ''' Gets the properties of a storage account's Table service, including Windows Azure Storage Analytics. ''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/?restype=service&comp=properties' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response(response, StorageServiceProperties) def set_table_service_properties(self, storage_service_properties): ''' Sets the properties of a storage account's Table Service, including Windows Azure Storage Analytics. storage_service_properties: StorageServiceProperties object. ''' _validate_not_none('storage_service_properties', storage_service_properties) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/?restype=service&comp=properties' request.body = _get_request_body( _convert_class_to_xml(storage_service_properties)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict(response) def query_tables(self, table_name=None, top=None, next_table_name=None): ''' Returns a list of tables under the specified account. table_name: Optional. The specific table to query. top: Optional. 
Maximum number of tables to return. next_table_name: Optional. When top is used, the next table name is stored in result.x_ms_continuation['NextTableName'] ''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() if table_name is not None: uri_part_table_name = "('" + table_name + "')" else: uri_part_table_name = "" request.path = '/Tables' + uri_part_table_name + '' request.query = [ ('$top', _int_or_none(top)), ('NextTableName', _str_or_none(next_table_name)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _convert_response_to_feeds(response, _convert_xml_to_table) def create_table(self, table, fail_on_exist=False): ''' Creates a new table in the storage account. table: Name of the table to create. Table name may contain only alphanumeric characters and cannot begin with a numeric character. It is case-insensitive and must be from 3 to 63 characters long. fail_on_exist: Specify whether throw exception when table exists. ''' _validate_not_none('table', table) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/Tables' request.body = _get_request_body(_convert_table_to_xml(table)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) if not fail_on_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_on_exist(ex) return False else: self._perform_request(request) return True def delete_table(self, table_name, fail_not_exist=False): ''' table_name: Name of the table to delete. fail_not_exist: Specify whether throw exception when table doesn't exist. 
''' _validate_not_none('table_name', table_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/Tables(\'' + _str(table_name) + '\')' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) if not fail_not_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_not_exist(ex) return False else: self._perform_request(request) return True def get_entity(self, table_name, partition_key, row_key, select=''): ''' Get an entity in a table; includes the $select options. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. select: Property names to select. ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('select', select) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(table_name) + \ '(PartitionKey=\'' + _str(partition_key) + \ '\',RowKey=\'' + \ _str(row_key) + '\')?$select=' + \ _str(select) + '' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _convert_response_to_entity(response) def query_entities(self, table_name, filter=None, select=None, top=None, next_partition_key=None, next_row_key=None): ''' Get entities in a table; includes the $filter and $select options. table_name: Table to query. filter: Optional. Filter as described at http://msdn.microsoft.com/en-us/library/windowsazure/dd894031.aspx select: Optional. Property names to select from the entities. top: Optional. Maximum number of entities to return. next_partition_key: Optional. 
When top is used, the next partition key is stored in result.x_ms_continuation['NextPartitionKey'] next_row_key: Optional. When top is used, the next partition key is stored in result.x_ms_continuation['NextRowKey'] ''' _validate_not_none('table_name', table_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(table_name) + '()' request.query = [ ('$filter', _str_or_none(filter)), ('$select', _str_or_none(select)), ('$top', _int_or_none(top)), ('NextPartitionKey', _str_or_none(next_partition_key)), ('NextRowKey', _str_or_none(next_row_key)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _convert_response_to_feeds(response, _convert_xml_to_entity) def insert_entity(self, table_name, entity, content_type='application/atom+xml'): ''' Inserts a new entity into a table. table_name: Table name. entity: Required. The entity object to insert. Could be a dict format or entity object. content_type: Required. Must be set to application/atom+xml ''' _validate_not_none('table_name', table_name) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/' + _str(table_name) + '' request.headers = [('Content-Type', _str_or_none(content_type))] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _convert_response_to_entity(response) def update_entity(self, table_name, partition_key, row_key, entity, content_type='application/atom+xml', if_match='*'): ''' Updates an existing entity in a table. 
The Update Entity operation replaces the entire entity and can be used to remove properties. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. entity: Required. The entity object to insert. Could be a dict format or entity object. content_type: Required. Must be set to application/atom+xml if_match: Optional. Specifies the condition for which the merge should be performed. To force an unconditional merge, set to the wildcard character (*). ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [ ('Content-Type', _str_or_none(content_type)), ('If-Match', _str_or_none(if_match)) ] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict_filter(response, filter=['etag']) def merge_entity(self, table_name, partition_key, row_key, entity, content_type='application/atom+xml', if_match='*'): ''' Updates an existing entity by updating the entity's properties. This operation does not replace the existing entity as the Update Entity operation does. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. entity: Required. The entity object to insert. Can be a dict format or entity object. content_type: Required. Must be set to application/atom+xml if_match: Optional. Specifies the condition for which the merge should be performed. 
To force an unconditional merge, set to the wildcard character (*). ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'MERGE' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [ ('Content-Type', _str_or_none(content_type)), ('If-Match', _str_or_none(if_match)) ] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict_filter(response, filter=['etag']) def delete_entity(self, table_name, partition_key, row_key, content_type='application/atom+xml', if_match='*'): ''' Deletes an existing entity in a table. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. content_type: Required. Must be set to application/atom+xml if_match: Optional. Specifies the condition for which the delete should be performed. To force an unconditional delete, set to the wildcard character (*). 
''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('content_type', content_type) _validate_not_none('if_match', if_match) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [ ('Content-Type', _str_or_none(content_type)), ('If-Match', _str_or_none(if_match)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) self._perform_request(request) def insert_or_replace_entity(self, table_name, partition_key, row_key, entity, content_type='application/atom+xml'): ''' Replaces an existing entity or inserts a new entity if it does not exist in the table. Because this operation can insert or update an entity, it is also known as an "upsert" operation. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. entity: Required. The entity object to insert. Could be a dict format or entity object. content_type: Required. 
Must be set to application/atom+xml ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [('Content-Type', _str_or_none(content_type))] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict_filter(response, filter=['etag']) def insert_or_merge_entity(self, table_name, partition_key, row_key, entity, content_type='application/atom+xml'): ''' Merges an existing entity or inserts a new entity if it does not exist in the table. Because this operation can insert or update an entity, it is also known as an "upsert" operation. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. entity: Required. The entity object to insert. Could be a dict format or entity object. content_type: Required. 
Must be set to application/atom+xml ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'MERGE' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [('Content-Type', _str_or_none(content_type))] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict_filter(response, filter=['etag']) def _perform_request_worker(self, request): auth = _sign_storage_table_request(request, self.account_name, self.account_key) request.headers.append(('Authorization', auth)) return self._httpclient.perform_request(request) ================================================ FILE: CustomScript/customscript.py ================================================ #!/usr/bin/env python # # CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import os import os.path import re import shutil import subprocess import sys import time import traceback from azure.storage import BlobService from codecs import * from Utils.WAAgentUtil import waagent import Utils.HandlerUtil as Util import Utils.ScriptUtil as ScriptUtil if sys.version_info[0] == 3: import urllib.request as urllib from urllib.parse import urlparse elif sys.version_info[0] == 2: import urllib2 as urllib from urlparse import urlparse ExtensionShortName = 'CustomScriptForLinux' # Global Variables DownloadDirectory = 'download' # CustomScript-specific Operation DownloadOp = "Download" RunScriptOp = "RunScript" # Change permission of log path ext_log_path = '/var/log/azure/' if os.path.exists(ext_log_path): os.chmod('/var/log/azure/', 0o700) #Main function is the only entrence to this extension handler def main(): #Global Variables definition waagent.LoggerInit('/var/log/waagent.log','/dev/stdout') waagent.Log("%s started to handle." %(ExtensionShortName)) hutil = None try: for a in sys.argv[1:]: if re.match("^([-/]*)(disable)", a): dummy_command("Disable", "success", "Disable succeeded") elif re.match("^([-/]*)(uninstall)", a): dummy_command("Uninstall", "success", "Uninstall succeeded") elif re.match("^([-/]*)(install)", a): dummy_command("Install", "success", "Install succeeded") elif re.match("^([-/]*)(enable)", a): hutil = parse_context("Enable") enable(hutil) elif re.match("^([-/]*)(daemon)", a): hutil = parse_context("Executing") daemon(hutil) elif re.match("^([-/]*)(update)", a): dummy_command("Update", "success", "Update succeeded") except Exception as e: err_msg = "Failed with error: {0}, {1}".format(e, traceback.format_exc()) waagent.Error(err_msg) if hutil is not None: hutil.error(err_msg) hutil.do_exit(1, 'Enable','failed','0', 'Enable failed: {0}'.format(err_msg)) def dummy_command(operation, status, msg): hutil = parse_context(operation) hutil.do_exit(0, operation, status, '0', msg) def parse_context(operation): hutil = 
Util.HandlerUtility(waagent.Log, waagent.Error, ExtensionShortName, console_logger=waagent.LogToConsole, file_logger=waagent.LogToFile) hutil.do_parse_context(operation) return hutil def enable(hutil): """ Ensure the same configuration is executed only once If the previous enable failed, we do not have retry logic here, since the custom script may not work in an intermediate state. """ hutil.exit_if_enabled() start_daemon(hutil) def download_files_with_retry(hutil, retry_count, wait): hutil.log(("Will try to download files, " "number of retries = {0}, " "wait SECONDS between retrievals = {1}s").format(retry_count, wait)) for download_retry_count in range(0, retry_count + 1): try: download_files(hutil) break except Exception as e: error_msg = "{0}, retry = {1}, maxRetry = {2}.".format(e, download_retry_count, retry_count) hutil.error(error_msg) if download_retry_count < retry_count: hutil.log("Sleep {0} seconds".format(wait)) time.sleep(wait) else: waagent.AddExtensionEvent(name=ExtensionShortName, op=DownloadOp, isSuccess=False, version=hutil.get_extension_version(), message="(01100)"+error_msg) raise msg = ("Succeeded to download files, " "retry count = {0}").format(download_retry_count) hutil.log(msg) waagent.AddExtensionEvent(name=ExtensionShortName, op=DownloadOp, isSuccess=True, version=hutil.get_extension_version(), message="(01303)"+msg) return retry_count - download_retry_count def check_idns_with_retry(hutil, retry_count, wait): is_idns_ready = False for check_idns_retry_count in range(0, retry_count + 1): is_idns_ready = check_idns() if is_idns_ready: break else: if check_idns_retry_count < retry_count: hutil.error("Internal DNS is not ready, retry to check.") hutil.log("Sleep {0} seconds".format(wait)) time.sleep(wait) if is_idns_ready: msg = ("Internal DNS is ready, " "retry count = {0}").format(check_idns_retry_count) hutil.log(msg) waagent.AddExtensionEvent(name=ExtensionShortName, op="CheckIDNS", isSuccess=True, version=hutil.get_extension_version(), 
message="(01306)"+msg) else: error_msg = ("Internal DNS is not ready, " "retry count = {0}, ignore it.").format(check_idns_retry_count) hutil.error(error_msg) waagent.AddExtensionEvent(name=ExtensionShortName, op="CheckIDNS", isSuccess=False, version=hutil.get_extension_version(), message="(01306)"+error_msg) def check_idns(): ret = waagent.Run("host $(hostname)") return not ret def download_files(hutil): public_settings = hutil.get_public_settings() if public_settings is None: raise ValueError("Public configuration couldn't be None.") cmd = get_command_to_execute(hutil) blob_uris = public_settings.get('fileUris') protected_settings = hutil.get_protected_settings() storage_account_name = None storage_account_key = None if protected_settings: storage_account_name = protected_settings.get("storageAccountName") storage_account_key = protected_settings.get("storageAccountKey") if storage_account_name is not None: storage_account_name = storage_account_name.strip() if storage_account_key is not None: storage_account_key = storage_account_key.strip() if (not blob_uris or not isinstance(blob_uris, list) or len(blob_uris) == 0): error_msg = "fileUris value provided is empty or invalid." hutil.log(error_msg + " Continue with executing command...") waagent.AddExtensionEvent(name=ExtensionShortName, op=DownloadOp, isSuccess=False, version=hutil.get_extension_version(), message="(01001)"+error_msg) return hutil.do_status_report('Downloading','transitioning', '0', 'Downloading files...') if storage_account_name and storage_account_key: hutil.log("Downloading scripts from azure storage...") download_blobs(storage_account_name, storage_account_key, blob_uris, cmd, hutil) elif not(storage_account_name or storage_account_key): hutil.log("No azure storage account and key specified in protected " "settings. 
Downloading scripts from external links...") download_external_files(blob_uris, cmd, hutil) else: #Storage account and key should appear in pairs error_msg = "Azure storage account and key should appear in pairs." hutil.error(error_msg) waagent.AddExtensionEvent(name=ExtensionShortName, op=DownloadOp, isSuccess=False, version=hutil.get_extension_version(), message="(01000)"+error_msg) raise ValueError(error_msg) def start_daemon(hutil): cmd = get_command_to_execute(hutil) if cmd: args = [os.path.join(os.getcwd(), "shim.sh"), "-daemon"] # This process will start a new background process by calling # shim.sh -daemon # to run the script and will exit itself immediately. # Redirect stdout and stderr to /dev/null. Otherwise daemon process # will throw Broke pipe exception when parent process exit. devnull = open(os.devnull, 'w') subprocess.Popen(args, stdout=devnull, stderr=devnull) hutil.do_exit(0, 'Enable', 'transitioning', '0', 'Launching the script...') else: error_msg = "commandToExecute is empty or invalid" hutil.error(error_msg) waagent.AddExtensionEvent(name=ExtensionShortName, op=RunScriptOp, isSuccess=False, version=hutil.get_extension_version(), message="(01002)"+error_msg) raise ValueError(error_msg) def daemon(hutil): retry_count = 10 wait = 20 enable_idns_check = True public_settings = hutil.get_public_settings() if public_settings: if 'retrycount' in public_settings: retry_count = public_settings.get('retrycount') if 'wait' in public_settings: wait = public_settings.get('wait') if 'enableInternalDNSCheck' in public_settings: # removed strtobool/distutils dependency, implementation is based on strtobool specification enable_idns_check_setting = public_settings.get('enableInternalDNSCheck') enable_idns_check = True if ((enable_idns_check_setting.lower() == "yes") | (enable_idns_check_setting.lower() == "y") | (enable_idns_check_setting.lower() == "true") | (enable_idns_check_setting.lower() == "t") | (enable_idns_check_setting.lower() == "on") | 
(enable_idns_check_setting.lower() == "1")) else False prepare_download_dir(hutil.get_seq_no()) retry_count = download_files_with_retry(hutil, retry_count, wait) # The internal DNS needs some time to be ready. # Wait and retry to check if there is time in retry window. # The check may be removed safely if iDNS is always ready. if enable_idns_check: check_idns_with_retry(hutil, retry_count, wait) cmd = get_command_to_execute(hutil) args = ScriptUtil.parse_args(cmd) if args: ScriptUtil.run_command(hutil, args, prepare_download_dir(hutil.get_seq_no()), 'Daemon', ExtensionShortName, hutil.get_extension_version()) else: error_msg = "commandToExecute is empty or invalid." hutil.error(error_msg) waagent.AddExtensionEvent(name=ExtensionShortName, op=RunScriptOp, isSuccess=False, version=hutil.get_extension_version(), message="(01002)"+error_msg) raise ValueError(error_msg) def download_blobs(storage_account_name, storage_account_key, blob_uris, command, hutil): for blob_uri in blob_uris: if blob_uri: download_blob(storage_account_name, storage_account_key, blob_uri, command, hutil) def download_blob(storage_account_name, storage_account_key, blob_uri, command, hutil): try: seqNo = hutil.get_seq_no() download_dir = prepare_download_dir(seqNo) result = download_and_save_blob(storage_account_name, storage_account_key, blob_uri, download_dir, hutil) blob_name, _, _, download_path = result preprocess_files(download_path, hutil) if command and blob_name in command: os.chmod(download_path, 0o100) except Exception as e: error_msg = "Failed to download blob with uri: {0} with error {1}".format(blob_uri, e) raise Exception(error_msg) def download_and_save_blob(storage_account_name, storage_account_key, blob_uri, download_dir, hutil): container_name = get_container_name_from_uri(blob_uri, hutil) blob_name = get_blob_name_from_uri(blob_uri, hutil) host_base = get_host_base_from_uri(blob_uri) # If blob_name is a path, extract the file_name last_sep = blob_name.rfind('/') if last_sep != 
-1: file_name = blob_name[last_sep+1:] else: file_name = blob_name download_path = os.path.join(download_dir, file_name) # Guest agent already ensure the plugin is enabled one after another. # The blob download will not conflict. blob_service = BlobService(storage_account_name, storage_account_key, host_base=host_base) blob_service.get_blob_to_path(container_name, blob_name, download_path) return blob_name, container_name, host_base, download_path def download_external_files(uris, command, hutil): for uri in uris: if uri: download_external_file(uri, command, hutil) def download_external_file(uri, command, hutil): seqNo = hutil.get_seq_no() download_dir = prepare_download_dir(seqNo) path = get_path_from_uri(uri) file_name = path.split('/')[-1] file_path = os.path.join(download_dir, file_name) try: download_and_save_file(uri, file_path) preprocess_files(file_path, hutil) if command and file_name in command: os.chmod(file_path, 0o100) except Exception as e: error_msg = ("Failed to download external file with uri: {0} " "with error {1}").format(uri, e) raise Exception(error_msg) def download_and_save_file(uri, file_path, timeout=30, buf_size=1024): src = urllib.urlopen(uri, timeout=timeout) with open(file_path, 'wb') as dest: buf = src.read(buf_size) while(buf): dest.write(buf) buf = src.read(buf_size) def preprocess_files(file_path, hutil): """ The file is preprocessed if it satisfies any of the following condistions: the file's extension is '.sh' or '.py' the content of the file starts with '#!' """ ret = to_process(file_path) if ret: dos2unix(file_path) hutil.log("Converting {0} from DOS to Unix formats: Done".format(file_path)) remove_bom(file_path) hutil.log("Removing BOM of {0}: Done".format(file_path)) def to_process(file_path, extensions=['.sh', ".py"]): for extension in extensions: if file_path.endswith(extension): return True with open(file_path, 'rb') as f: contents = f.read(64) if b'#!' 
in contents: return True return False def dos2unix(file_path): with open(file_path, 'rU') as f: contents = f.read() temp_file_path = file_path + ".tmp" with open(temp_file_path, 'wb') as f_temp: f_temp.write(contents.encode()) shutil.move(temp_file_path, file_path) def remove_bom(file_path): with open(file_path, 'rb') as f: contents = f.read() bom_list = [BOM, BOM_BE, BOM_LE, BOM_UTF16, BOM_UTF16_BE, BOM_UTF16_LE, BOM_UTF8] for bom in bom_list: if contents.startswith(bom): break else: return new_contents = None for encoding in ["utf-8-sig", "utf-16"]: try: new_contents = contents.decode(encoding).encode('utf-8') break except UnicodeDecodeError: continue if new_contents is not None: temp_file_path = file_path + ".tmp" with open(temp_file_path, 'wb') as f_temp: f_temp.write(new_contents) shutil.move(temp_file_path, file_path) def get_blob_name_from_uri(uri, hutil): return get_properties_from_uri(uri, hutil)['blob_name'] def get_container_name_from_uri(uri, hutil): return get_properties_from_uri(uri, hutil)['container_name'] def get_host_base_from_uri(blob_uri): uri = urlparse(blob_uri) netloc = uri.netloc if netloc is None: return None return netloc[netloc.find('.'):] def get_properties_from_uri(uri, hutil): path = get_path_from_uri(uri) if path.endswith('/'): path = path[:-1] if path[0] == '/': path = path[1:] first_sep = path.find('/') if first_sep == -1: hutil.error("Failed to extract container, blob, from {}".format(path)) blob_name = path[first_sep+1:] container_name = path[:first_sep] return {'blob_name': blob_name, 'container_name': container_name} def get_path_from_uri(uriStr): uri = urlparse(uriStr) return uri.path def prepare_download_dir(seqNo): download_dir_main = os.path.join(os.getcwd(), DownloadDirectory) create_directory_if_not_exists(download_dir_main) download_dir = os.path.join(download_dir_main, seqNo) create_directory_if_not_exists(download_dir) return download_dir def create_directory_if_not_exists(directory): """create directory if no exists""" 
if not os.path.exists(directory): os.makedirs(directory) def get_command_to_execute(hutil): public_settings = hutil.get_public_settings() protected_settings = hutil.get_protected_settings() cmd_public = public_settings.get('commandToExecute') cmd_protected = None if protected_settings is not None: cmd_protected = protected_settings.get('commandToExecute') if cmd_public and cmd_protected: err_msg = ("commandToExecute was specified both in public settings " "and protected settings. It can only be specified in one of them.") hutil.error(err_msg) hutil.do_exit(1, 'Enable','failed','0', 'Enable failed: {0}'.format(err_msg)) if cmd_public: hutil.log("Command to execute:" + cmd_public) return cmd_public else: return cmd_protected if __name__ == '__main__' : main() ================================================ FILE: CustomScript/manifest.xml ================================================ <?xml version='1.0' encoding='utf-8' ?> <ExtensionImage xmlns="http://schemas.microsoft.com/windowsazure"> <ProviderNameSpace>Microsoft.OSTCExtensions</ProviderNameSpace> <Type>CustomScriptForLinux</Type> <Version>1.5.5</Version> <Label>Microsoft Azure Custom Script Extension for Linux Virtual Machines</Label> <HostingResources>VmRole</HostingResources> <MediaLink></MediaLink> <Description>Please consider using Microsoft.Azure.Extensions.CustomScript instead.</Description> <IsInternalExtension>true</IsInternalExtension> <Eula>https://github.com/Azure/azure-linux-extensions/blob/master/LICENSE-2_0.txt</Eula> <PrivacyUri>http://www.microsoft.com/privacystatement/en-us/OnlineServices/Default.aspx</PrivacyUri> <HomepageUri>https://github.com/Azure/azure-linux-extensions</HomepageUri> <IsJsonExtension>true</IsJsonExtension> <SupportedOS>Linux</SupportedOS> <CompanyName>Microsoft</CompanyName> <!--%REGIONS%--> </ExtensionImage> ================================================ FILE: CustomScript/references ================================================ Utils/ 
================================================ FILE: CustomScript/shim.sh ================================================ #!/usr/bin/env bash # The shim scripts provide a single entry point for CSE and will invoke the customscript.py entry point using the # appropriate python interpreter version. # Arguments passed to the shim layer are redirected to the invoked script without any validation. COMMAND="./customscript.py" PYTHON="" ARG="$@" function find_python(){ local python_exec_command=$1 # Check if there is python defined. if command -v python >/dev/null 2>&1 ; then eval ${python_exec_command}="python" else # Python was not found. Searching for Python3 now. if command -v python3 >/dev/null 2>&1 ; then eval ${python_exec_command}="python3" fi fi } find_python PYTHON if [ -z "$PYTHON" ] then echo "No Python interpreter found on the box" >&2 exit 51 # Not Supported else echo "Found: `${PYTHON} --version`" fi ${PYTHON} ${COMMAND} ${ARG} exit $? # DONE ================================================ FILE: CustomScript/test/HandlerEnvironment.json ================================================ [{ "name": "Microsoft.OSTCExtensions.CustomScriptForLinuxTest", "seqNo": "0", "version": 1.0, "handlerEnvironment": { "logFolder": "/root/CustomScriptForLinuxTest", "configFolder": "/root/CustomScriptForLinuxTest/config", "statusFolder": "/root/CustomScriptForLinuxTest/status", "heartbeatFile": "/root/CustomScriptForLinuxTest/heartbeat.log"}}] ================================================ FILE: CustomScript/test/MockUtil.py ================================================ #!/usr/bin/env python # #CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. class MockUtil: def __init__(self, test): self.test = test def get_log_dir(self): return "/tmp" def log(self, msg): print(msg) def error(self, msg): print(msg) def get_seq_no(self): return "0" def do_status_report(self, operation, status, status_code, message): self.test.assertNotEqual(None, message) def do_exit(self,exit_code,operation,status,code,message): self.test.assertNotEqual(None, message) ================================================ FILE: CustomScript/test/create_test_blob.py ================================================ import blob import blob_mooncake import customscript as cs from azure.storage import BlobService def create_blob(blob, txt): uri = blob.uri host_base = cs.get_host_base_from_uri(uri) service = BlobService(blob.name, blob.key, host_base=host_base) container_name = cs.get_container_name_from_uri(uri) blob_name = cs.get_blob_name_from_uri(uri) service.put_block_blob_from_text(container_name, blob_name, txt) if __name__ == "__main__": create_blob(blob, "public azure\n") create_blob(blob_mooncake, "mooncake\n") ================================================ FILE: CustomScript/test/env.py ================================================ #!/usr/bin/env python # # CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import os root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(root) ================================================ FILE: CustomScript/test/run_all.sh ================================================ #!/bin/bash # # This script is used to set up a test env for extensions # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # script=$(dirname $0) root=$script cd $root root=`pwd` echo "Run unit test:" ls test_*.py ls test_*.py | sed -e 's/\.py//'|xargs python -m unittest ================================================ FILE: CustomScript/test/test_blob_download.py ================================================ #!/usr/bin/env python # #CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest import customscript as cs import blob as test_blob import blob_mooncake as test_blob_mooncake class TestBlobDownload(unittest.TestCase): def test_download_blob(self): download_dir = "/tmp" cs.download_and_save_blob(test_blob.name, test_blob.key, test_blob.uri, download_dir) cs.download_and_save_blob(test_blob_mooncake.name, test_blob_mooncake.key, test_blob_mooncake.uri, download_dir) if __name__ == '__main__': unittest.main() ================================================ FILE: CustomScript/test/test_file_download.py ================================================ #!/usr/bin/env python # #CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import unittest import os import tempfile import customscript as cs class TestFileDownload(unittest.TestCase): def test_download_blob(self): pass def download_to_tmp(self, uri): tmpFile = tempfile.TemporaryFile() file_path = os.path.abspath(tmpFile.name) cs.download_and_save_file(uri, file_path) file_size = os.path.getsize(file_path) self.assertNotEqual(file_size, 0) tmpFile.close() os.unlink(tmpFile.name) def test_download_bin_file(self): uri = "http://www.bing.com/rms/Homepage$HPBottomBrand_default/ic/1f76acf2/d3a8cfeb.png" self.download_to_tmp(uri) def test_download_text_file(self): uri = "http://www.bing.com/" self.download_to_tmp(uri) if __name__ == '__main__': unittest.main() ================================================ FILE: CustomScript/test/test_preprocess_file.py ================================================ #!/usr/bin/env python # #CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import unittest import os import zipfile import codecs import shutil from MockUtil import MockUtil import customscript as cs class TestPreprocessFile(unittest.TestCase): @classmethod def setUpClass(cls): try: os.remove('master.zip') shutil.rmtree('encoding') except: pass os.system('wget https://github.com/bingosummer/scripts/archive/master.zip') zipFile = zipfile.ZipFile('master.zip') zipFile.extractall() zipFile.close() shutil.move('scripts-master', 'encoding') def test_bin_file(self): print("\nTest: Is it a binary file") file_path = "encoding/mslogo.png" self.assertFalse(cs.is_text_file(file_path)[0]) def test_text_file(self): print("\nTest: Is it a text file") files = [file for file in os.listdir('encoding') if file.endswith('py') or file.endswith('sh') or file.endswith('txt')] for file in files: file_path = os.path.join('encoding', file) try: self.assertTrue(cs.is_text_file(file_path)[0]) except: print(file) raise def test_bom(self): print("\nTest: Remove BOM") hutil = MockUtil(self) files = [file for file in os.listdir('encoding') if 'bom' in file] for file in files: file_path = os.path.join('encoding', file) cs.preprocess_files(file_path, hutil) with open(file_path, 'r') as f: contents = f.read() if "utf8" in file: self.assertFalse(contents.startswith(codecs.BOM_UTF8)) if "utf16_le" in file: self.assertFalse(contents.startswith(codecs.BOM_LE)) if "utf16_be" in file: self.assertFalse(contents.startswith(codecs.BOM_BE)) def test_windows_line_break(self): print("\nTest: Convert text files from DOS to Unix formats") hutil = MockUtil(self) files = [file for file in os.listdir('encoding') if 'dos' in file] for file in files: file_path = os.path.join('encoding', file) cs.preprocess_files(file_path, hutil) with open(file_path, 'r') as f: contents = f.read() self.assertFalse("\r\n" in contents) if __name__ == '__main__': unittest.main() ================================================ FILE: CustomScript/test/test_uri_utils.py 
================================================ #!/usr/bin/env python # #CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest import customscript as cs class TestUriUtils(unittest.TestCase): def test_get_path_from_uri(self): uri = "http://qingfu2.blob.core.windows.net/vhds/abc.sh?st=2014-06-27Z&se=2014-06-27&sr=c&sp=r&sig=KBwcWOx" path = cs.get_path_from_uri(uri) self.assertEqual(path, "/vhds/abc.sh") def test_get_blob_name_from_uri(self): uri = "http://qingfu2.blob.core.windows.net/vhds/abc.sh?st=2014-06-27Z&se=2014-06-27&sr=c&sp=r&sig=KBwcWOx" blob = cs.get_blob_name_from_uri(uri) self.assertEqual(blob, "abc.sh") def test_get_container_name_from_uri(self): uri = "http://qingfu2.blob.core.windows.net/vhds/abc.sh?st=2014-06-27Z&se=2014-06-27&sr=c&sp=r&sig=KBwcWOx" container = cs.get_container_name_from_uri(uri) self.assertEqual(container, "vhds") def test_get_host_base_from_uri(self): blob_uri = "http://qingfu2.blob.core.windows.net/vhds/abc.sh?st=2014-06-27Z&se=2014-06-27&sr=c&sp=r&sig=KBwcWOx" host_base = cs.get_host_base_from_uri(blob_uri) self.assertEqual(host_base, ".blob.core.windows.net") blob_uri = "https://yue.blob.core.chinacloudapi.cn/" host_base = cs.get_host_base_from_uri(blob_uri) self.assertEqual(host_base, ".blob.core.chinacloudapi.cn") if __name__ == '__main__': unittest.main() ================================================ FILE: CustomScript/test/timeout.sh 
================================================ #!/bin/bash for i in $(seq 1500) do echo `date` + The script is running... >&2 echo `date` + ERROR:The script is running... sleep 1 done ================================================ FILE: DSC/HandlerManifest.json ================================================ [ { "name" : "DSCForLinux", "version": 1.0, "handlerManifest": { "disableCommand": "./extension_shim.sh -c ./dsc.py -d", "enableCommand": "./extension_shim.sh -c ./dsc.py -e", "installCommand": "./extension_shim.sh -c ./dsc.py -i", "uninstallCommand": "./extension_shim.sh -c ./dsc.py -u", "updateCommand": "./extension_shim.sh -c ./dsc.py -p", "rebootAfterInstall": false, "reportHeartbeat": false } } ] ================================================ FILE: DSC/Makefile ================================================ all: package SOURCES = \ httpclientfactory.py \ subprocessfactory.py \ curlhttpclient.py \ serializerfactory.py \ httpclient.py \ urllib2httpclient.py \ urllib3httpclient.py \ dsc.py \ test \ HandlerManifest.json \ manifest.xml \ azure \ packages \ ../Utils \ ../Common/WALinuxAgent-2.0.16/waagent clean: rm -rf output package: $(SOURCES) mkdir -p output cp -t output -r $(SOURCES) cd output && zip -r ../DSC.zip * > /dev/null .PHONY: all clean package ================================================ FILE: DSC/README.md ================================================ # DSCForLinux Extension Allow the owner of the Azure Virtual Machines to configure the VM using Desired State Configuration (DSC) for Linux. Latest version is 2.71 About how to create MOF document, please refer to below documents. 
* [Get started with Desired State Configuration (DSC) for Linux](https://technet.microsoft.com/en-us/library/mt126211.aspx)
* [Built-In Desired State Configuration Resources for Linux](https://msdn.microsoft.com/en-us/powershell/dsc/lnxbuiltinresources)
* [DSC for Linux releases](https://github.com/Microsoft/PowerShell-DSC-for-Linux/releases)

DSCForLinux Extension can:
* Register the Linux VM to an Azure Automation account in order to pull configurations from the Azure Automation service (Register ExtensionAction)
* Push MOF configurations to the Linux VM (Push ExtensionAction)
* Apply a Meta MOF configuration to the Linux VM to configure a Pull Server in order to pull Node Configuration (Pull ExtensionAction)
* Install custom DSC modules to the Linux VM (Install ExtensionAction)
* Remove custom DSC modules from the Linux VM (Remove ExtensionAction)

# User Guide

## 1. Configuration schema

### 1.1. Public configuration

Here are all the supported public configuration parameters:

* `FileUri`: (optional, string) the URI of the MOF file/Meta MOF file/custom resource ZIP file.
* `ResourceName`: (optional, string) the name of the custom resource module
* `ExtensionAction`: (optional, string) Specifies what an extension does. Valid values: Register, Push, Pull, Install, Remove. If not specified, it's considered as the Push Action by default.
* `NodeConfigurationName`: (optional, string) the name of a node configuration to apply.
* `RefreshFrequencyMins`: (optional, int) Specifies how often (in minutes) DSC attempts to obtain the configuration from the pull server. If the configuration on the pull server differs from the current one on the target node, it is copied to the pending store and applied.
* `ConfigurationMode`: (optional, string) Specifies how DSC should apply the configuration. Valid values are: ApplyOnly, ApplyAndMonitor, ApplyAndAutoCorrect.
* `ConfigurationModeFrequencyMins`: (optional, int) Specifies how often (in minutes) DSC ensures that the configuration is in the desired state. > **NOTE:** If you are using a version < 2.3, mode parameter is same as ExtensionAction. Mode seems to be a overloaded term. Therefore to avoid the confusion, ExtensionAction is being used from 2.3 version onwards. For backward compatibility, the extension supports both mode and ExtensionAction. ### 1.2 Protected configuration Here're all the supported protected configuration parameters: * `StorageAccountName`: (optional, string) the name of the storage account that contains the file * `StorageAccountKey`: (optional, string) the key of the storage account that contains the file * `RegistrationUrl`: (optional, string) the URL of the Azure Automation account * `RegistrationKey`: (optional, string) the access key of the Azure Automation account ## 2. Deploying the Extension to a VM You can deploy it using Azure CLI, Azure PowerShell and ARM template. ### 2.1. Using [**Azure CLI**][azure-cli] Before deploying DSCForLinux Extension, you should configure your `public.json` and `protected.json`, according to the different scenarios in section 3. #### 2.1.1. Classic The Classic mode is also called Azure Service Management mode. You can switch to it by running: ``` $ azure config mode asm ``` You can deploy DSCForLinux Extension by running: ``` $ azure vm extension set <vm-name> DSCForLinux Microsoft.OSTCExtensions <version> \ --private-config-path protected.json --public-config-path public.json ``` To learn the latest extension version available, run: ``` $ azure vm extension list ``` #### 2.1.2. 
Resource Manager You can switch to Azure Resource Manager mode by running: ``` $ azure config mode arm ``` You can deploy DSCForLinux Extension by running: ``` $ azure vm extension set <resource-group> <vm-name> \ DSCForLinux Microsoft.OSTCExtensions <version> \ --private-config-path protected.json --public-config-path public.json ``` > **NOTE:** In ARM mode, `azure vm extension list` is not available for now. ### 2.2. Using [**Azure PowerShell**][azure-powershell] #### 2.2.1 Classic You can login to your Azure account (Azure Service Management mode) by running: ```powershell Add-AzureAccount ``` And deploy DSCForLinux Extension by running: ```powershell $vmname = '<vm-name>' $vm = Get-AzureVM -ServiceName $vmname -Name $vmname $extensionName = 'DSCForLinux' $publisher = 'Microsoft.OSTCExtensions' $version = '<version>' # You need to change the content of the $privateConfig and $publicConfig # according to different scenarios in section 3 $privateConfig = '{ "StorageAccountName": "<storage-account-name>", "StorageAccountKey": "<storage-account-key>" }' $publicConfig = '{ "ExtensionAction": "Push", "FileUri": "<mof-file-uri>" }' Set-AzureVMExtension -ExtensionName $extensionName -VM $vm -Publisher $publisher ` -Version $version -PrivateConfiguration $privateConfig ` -PublicConfiguration $publicConfig | Update-AzureVM ``` #### 2.2.2.Resource Manager You can login to your Azure account (Azure Resource Manager mode) by running: ```powershell Login-AzureRmAccount ``` Click [**HERE**](https://azure.microsoft.com/en-us/documentation/articles/powershell-azure-resource-manager/) to learn more about how to use Azure PowerShell with Azure Resource Manager. 
You can deploy DSCForLinux Extension by running: ```powershell $rgName = '<resource-group-name>' $vmName = '<vm-name>' $location = '<location>' $extensionName = 'DSCForLinux' $publisher = 'Microsoft.OSTCExtensions' $version = '<version>' # You need to change the content of the $privateConfig and $publicConfig # according to different scenarios in section 3 $privateConfig = '{ "StorageAccountName": "<storage-account-name>", "StorageAccountKey": "<storage-account-key>" }' $publicConfig = '{ "ExtensionAction": "Push", "FileUri": "<mof-file-uri>" }' Set-AzureRmVMExtension -ResourceGroupName $rgName -VMName $vmName -Location $location ` -Name $extensionName -Publisher $publisher -ExtensionType $extensionName ` -TypeHandlerVersion $version -SettingString $publicConfig -ProtectedSettingString $privateConfig ``` ### 2.3. Using [**ARM Template**][arm-template] The sample ARM template is [201-dsc-linux-azure-storage-on-ubuntu](https://github.com/Azure/azure-quickstart-templates/tree/master/201-dsc-linux-azure-storage-on-ubuntu) and [201-dsc-linux-public-storage-on-ubuntu](https://github.com/Azure/azure-quickstart-templates/tree/master/201-dsc-linux-public-storage-on-ubuntu). For more details about ARM template, please visit [Authoring Azure Resource Manager templates](https://azure.microsoft.com/en-us/documentation/articles/resource-group-authoring-templates/). ## 3. 
Scenarios ### 3.1 Register to Azure Automation account protected.json ```json { "RegistrationUrl": "<azure-automation-account-url>", "RegistrationKey": "<azure-automation-account-key>" } ``` public.json ```json { "ExtensionAction" : "Register", "NodeConfigurationName" : "<node-configuration-name>", "RefreshFrequencyMins" : "<value>", "ConfigurationMode" : "<ApplyAndMonitor | ApplyAndAutoCorrect | ApplyOnly>", "ConfigurationModeFrequencyMins" : "<value>" } ``` powershell format ```powershell $privateConfig = '{ "RegistrationUrl": "<azure-automation-account-url>", "RegistrationKey": "<azure-automation-account-key>" }' $publicConfig = '{ "ExtensionAction" : "Register", "NodeConfigurationName": "<node-configuration-name>", "RefreshFrequencyMins": "<value>", "ConfigurationMode": "<ApplyAndMonitor | ApplyAndAutoCorrect | ApplyOnly>", "ConfigurationModeFrequencyMins": "<value>" }' ``` ### 3.2 Apply a MOF configuration file (in Azure Storage Account) to the VM protected.json ```json { "StorageAccountName": "<storage-account-name>", "StorageAccountKey": "<storage-account-key>" } ``` public.json ```json { "FileUri": "<mof-file-uri>", "ExtensionAction": "Push" } ``` powershell format ```powershell $privateConfig = '{ "StorageAccountName": "<storage-account-name>", "StorageAccountKey": "<storage-account-key>" }' $publicConfig = '{ "FileUri": "<mof-file-uri>", "ExtensionAction": "Push" }' ``` ### 3.3. Apply a MOF configuration file (in public storage) to the VM public.json ```json { "FileUri": "<mof-file-uri>" } ``` powershell format ```powershell $publicConfig = '{ "FileUri": "<mof-file-uri>" }' ``` ### 3.4. 
Apply a meta MOF configuration file (in Azure Storage Account) to the VM protected.json ```json { "StorageAccountName": "<storage-account-name>", "StorageAccountKey": "<storage-account-key>" } ``` public.json ```json { "ExtensionAction": "Pull", "FileUri": "<meta-mof-file-uri>" } ``` powershell format ```powershell $privateConfig = '{ "StorageAccountName": "<storage-account-name>", "StorageAccountKey": "<storage-account-key>" }' $publicConfig = '{ "ExtensionAction": "Pull", "FileUri": "<meta-mof-file-uri>" }' ``` ### 3.5. Apply a meta MOF configuration file (in public storage) to the VM public.json ```json { "FileUri": "<meta-mof-file-uri>", "ExtensionAction": "Pull" } ``` powershell format ```powershell $publicConfig = '{ "FileUri": "<meta-mof-file-uri>", "ExtensionAction": "Pull" }' ``` ### 3.6. Install a custom resource module (ZIP file in Azure Storage Account) to the VM protected.json ```json { "StorageAccountName": "<storage-account-name>", "StorageAccountKey": "<storage-account-key>" } ``` public.json ```json { "ExtensionAction": "Install", "FileUri": "<resource-zip-file-uri>" } ``` powershell format ```powershell $privateConfig = '{ "StorageAccountName": "<storage-account-name>", "StorageAccountKey": "<storage-account-key>" }' $publicConfig = '{ "ExtensionAction": "Install", "FileUri": "<resource-zip-file-uri>" }' ``` ### 3.7. Install a custom resource module (ZIP file in public storage) to the VM public.json ```json { "ExtensionAction": "Install", "FileUri": "<resource-zip-file-uri>" } ``` powershell format ```powershell $publicConfig = '{ "ExtensionAction": "Install", "FileUri": "<resource-zip-file-uri>" }' ``` ### 3.8. Remove a custom resource module from the VM public.json ```json { "ResourceName": "<resource-name>", "ExtensionAction": "Remove" } ``` powershell format ```powershell $publicConfig = '{ "ResourceName": "<resource-name>", "ExtensionAction": "Remove" }' ``` ## 4. 
Supported Linux Distributions
- Ubuntu 14.04 LTS, 16.04 LTS, 18.04 LTS and 20.04 LTS
- Debian 8, 9 and 10
- Oracle Linux 6 and 7
- CentOS 6, 7 and 8
- RHEL 6, 7 and 8
- SUSE Linux Enterprise Server 12 and 15

## 5. Debug

* The status of the extension is reported back to Azure so that the user can see the status on the Azure Portal
* The operation log of the extension is the `/var/log/azure/<extension-name>/<version>/extension.log` file.

## 6. Known issue

* To distribute MOF configurations to the Linux VM with Pull Servers, you need to make sure the cron service is running in the VM.

## Changelog

```
# 2.5 (2017-05-25)
- Added support for Oracle distros

# 2.4 (2017-05-14)
- Added more logging

# 2.3 (2017-05-08)
- Update to OMI v1.1.0-8 and Linux DSC v1.1.1-294
- Added optional public.json parameters: 'NodeConfigurationName', 'RefreshFrequencyMins', 'ConfigurationMode' and 'ConfigurationModeFrequencyMins'.
- Added a new parameter 'ExtensionAction' to replace 'mode' to avoid confusion with DSC terminology: push/pull mode.
- Supports mode parameter for backward compatibility.

# 2.0 (2016-03-10)
- Pick up Linux DSC v1.1.1
- Add function to register Azure Automation
- Refine extension configurations

# 1.0 (2015-09-24)
- Initial version
```

[azure-powershell]: https://azure.microsoft.com/en-us/documentation/articles/powershell-install-configure/
[azure-cli]: https://azure.microsoft.com/en-us/documentation/articles/xplat-cli/
[arm-template]: http://azure.microsoft.com/en-us/documentation/templates/
[arm-overview]: https://azure.microsoft.com/en-us/documentation/articles/resource-group-overview/


================================================
FILE: DSC/azure/__init__.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- import ast import base64 import hashlib import hmac import sys import types import warnings import inspect if sys.version_info < (3,): from urllib2 import quote as url_quote from urllib2 import unquote as url_unquote _strtype = basestring else: from urllib.parse import quote as url_quote from urllib.parse import unquote as url_unquote _strtype = str from datetime import datetime from xml.dom import minidom from xml.sax.saxutils import escape as xml_escape #-------------------------------------------------------------------------- # constants __author__ = 'Microsoft Corp. <ptvshelp@microsoft.com>' __version__ = '0.8.4' # Live ServiceClient URLs BLOB_SERVICE_HOST_BASE = '.blob.core.windows.net' QUEUE_SERVICE_HOST_BASE = '.queue.core.windows.net' TABLE_SERVICE_HOST_BASE = '.table.core.windows.net' SERVICE_BUS_HOST_BASE = '.servicebus.windows.net' MANAGEMENT_HOST = 'management.core.windows.net' # Development ServiceClient URLs DEV_BLOB_HOST = '127.0.0.1:10000' DEV_QUEUE_HOST = '127.0.0.1:10001' DEV_TABLE_HOST = '127.0.0.1:10002' # Default credentials for Development Storage Service DEV_ACCOUNT_NAME = 'devstoreaccount1' DEV_ACCOUNT_KEY = 'Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==' # All of our error messages _ERROR_CANNOT_FIND_PARTITION_KEY = 'Cannot find partition key in request.' _ERROR_CANNOT_FIND_ROW_KEY = 'Cannot find row key in request.' 
# --- Error/warning message templates used throughout the azure package. ---
# NOTE(review): the '{0}' placeholders are filled via str.format by callers.
_ERROR_INCORRECT_TABLE_IN_BATCH = \
    'Table should be the same in a batch operations'
_ERROR_INCORRECT_PARTITION_KEY_IN_BATCH = \
    'Partition Key should be the same in a batch operations'
_ERROR_DUPLICATE_ROW_KEY_IN_BATCH = \
    'Row Keys should not be the same in a batch operations'
_ERROR_BATCH_COMMIT_FAIL = 'Batch Commit Fail'
_ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_DELETE = \
    'Message is not peek locked and cannot be deleted.'
_ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_UNLOCK = \
    'Message is not peek locked and cannot be unlocked.'
_ERROR_QUEUE_NOT_FOUND = 'Queue was not found'
_ERROR_TOPIC_NOT_FOUND = 'Topic was not found'
_ERROR_CONFLICT = 'Conflict ({0})'
_ERROR_NOT_FOUND = 'Not found ({0})'
_ERROR_UNKNOWN = 'Unknown error ({0})'
_ERROR_SERVICEBUS_MISSING_INFO = \
    'You need to provide servicebus namespace, access key and Issuer'
_ERROR_STORAGE_MISSING_INFO = \
    'You need to provide both account name and access key'
_ERROR_ACCESS_POLICY = \
    'share_access_policy must be either SignedIdentifier or AccessPolicy ' + \
    'instance'
_WARNING_VALUE_SHOULD_BE_BYTES = \
    'Warning: {0} must be bytes data type. It will be converted ' + \
    'automatically, with utf-8 text encoding.'
_ERROR_VALUE_SHOULD_BE_BYTES = '{0} should be of type bytes.'
_ERROR_VALUE_NONE = '{0} should not be None.'
_ERROR_VALUE_NEGATIVE = '{0} should not be negative.'
_ERROR_CANNOT_SERIALIZE_VALUE_TO_ENTITY = \
    'Cannot serialize the specified value ({0}) to an entity. Please use ' + \
    'an EntityProperty (which can specify custom types), int, str, bool, ' + \
    'or datetime.'
_ERROR_PAGE_BLOB_SIZE_ALIGNMENT = \
    'Invalid page blob size: {0}. ' + \
    'The size must be aligned to a 512-byte boundary.'

# User-Agent header value sent with every request issued by this SDK.
_USER_AGENT_STRING = 'pyazure/' + __version__

# XML namespace for OData/Atom metadata attributes (e.g. the m:etag
# attribute read by _get_entry_properties_from_node below).
METADATA_NS = 'http://schemas.microsoft.com/ado/2007/08/dataservices/metadata'


class WindowsAzureData(object):

    ''' This is the base of data class.
    It is only used to check whether it is instance or not. '''
    pass


class WindowsAzureError(Exception):

    ''' WindowsAzure Exception base class. '''

    def __init__(self, message):
        super(WindowsAzureError, self).__init__(message)


class WindowsAzureConflictError(WindowsAzureError):

    '''Indicates that the resource could not be created because it already
    exists'''

    def __init__(self, message):
        super(WindowsAzureConflictError, self).__init__(message)


class WindowsAzureMissingResourceError(WindowsAzureError):

    '''Indicates that a request for a resource (queue, table, container,
    etc...) failed because the specified resource does not exist'''

    def __init__(self, message):
        super(WindowsAzureMissingResourceError, self).__init__(message)


class WindowsAzureBatchOperationError(WindowsAzureError):

    '''Indicates that a batch operation failed'''

    def __init__(self, message, code):
        super(WindowsAzureBatchOperationError, self).__init__(message)
        # Error code reported by the individual operation that failed
        # inside the batch.
        self.code = code


class Feed(object):
    # Marker base class used by _convert_response_to_feeds for the list
    # of deserialized Atom entries.
    pass


class _Base64String(str):
    # Marker type: a str attribute whose wire representation is
    # base64-encoded (decoded during deserialization).
    pass


class HeaderDict(dict):

    # Dict whose lookups are case-insensitive (keys stored/queried in
    # lower case), as required for HTTP header access.
    def __getitem__(self, index):
        return super(HeaderDict, self).__getitem__(index.lower())


def _encode_base64(data):
    # Accepts text or bytes; returns the base64 encoding as a native str.
    if isinstance(data, _unicode_type):
        data = data.encode('utf-8')
    encoded = base64.b64encode(data)
    return encoded.decode('utf-8')


def _decode_base64_to_bytes(data):
    # Accepts text or bytes; returns the decoded raw bytes.
    if isinstance(data, _unicode_type):
        data = data.encode('utf-8')
    return base64.b64decode(data)


def _decode_base64_to_text(data):
    # Base64-decode and interpret the result as utf-8 text.
    decoded_bytes = _decode_base64_to_bytes(data)
    return decoded_bytes.decode('utf-8')


def _get_readable_id(id_name, id_prefix_to_skip):
    """simplified an id to be more friendly for us people"""
    # id_name is in the form 'https://namespace.host.suffix/name'
    # where name may contain a forward slash!
    pos = id_name.find('//')
    if pos != -1:
        pos += 2
        if id_prefix_to_skip:
            pos = id_name.find(id_prefix_to_skip, pos)
            if pos != -1:
                pos += len(id_prefix_to_skip)
        pos = id_name.find('/', pos)
        if pos != -1:
            return id_name[pos + 1:]
    return id_name


def _get_entry_properties_from_node(entry, include_id, id_prefix_to_skip=None,
                                    use_title_as_id=False):
    ''' get properties from entry xml '''
    # Collects etag / updated / author from an Atom <entry> node; when
    # include_id is set, 'name' comes from either <title> or <id>.
    properties = {}

    etag = entry.getAttributeNS(METADATA_NS, 'etag')
    if etag:
        properties['etag'] = etag
    for updated in _get_child_nodes(entry, 'updated'):
        properties['updated'] = updated.firstChild.nodeValue
    for name in _get_children_from_path(entry, 'author', 'name'):
        if name.firstChild is not None:
            properties['author'] = name.firstChild.nodeValue
    if include_id:
        if use_title_as_id:
            for title in _get_child_nodes(entry, 'title'):
                properties['name'] = title.firstChild.nodeValue
        else:
            for id in _get_child_nodes(entry, 'id'):
                properties['name'] = _get_readable_id(
                    id.firstChild.nodeValue, id_prefix_to_skip)

    return properties


def _get_entry_properties(xmlstr, include_id, id_prefix_to_skip=None):
    ''' get properties from entry xml '''
    # Parses the XML text and merges the properties of every top-level
    # <entry> into one dict.
    xmldoc = minidom.parseString(xmlstr)
    properties = {}

    for entry in _get_child_nodes(xmldoc, 'entry'):
        properties.update(
            _get_entry_properties_from_node(entry,
                                            include_id,
                                            id_prefix_to_skip))

    return properties


def _get_first_child_node_value(parent_node, node_name):
    # Text of the first matching direct child, or None (implicit) when the
    # child is absent or empty.
    xml_attrs = _get_child_nodes(parent_node, node_name)
    if xml_attrs:
        xml_attr = xml_attrs[0]
        if xml_attr.firstChild:
            value = xml_attr.firstChild.nodeValue
            return value


def _get_child_nodes(node, tagName):
    # getElementsByTagName is recursive; filter to direct children only.
    return [childNode for childNode in node.getElementsByTagName(tagName)
            if childNode.parentNode == node]


def _get_children_from_path(node, *path):
    '''descends through a hierarchy of nodes returning the list of children
    at the inner most level.
    Only returns children who share a common parent, not cousins.'''
    # Walks path element by element; a str element matches by tag name, a
    # tuple/list element is treated as (namespace, tagName).
    cur = node
    for index, child in enumerate(path):
        if isinstance(child, _strtype):
            next = _get_child_nodes(cur, child)
        else:
            next = _get_child_nodesNS(cur, *child)
        if index == len(path) - 1:
            return next
        elif not next:
            break
        cur = next[0]
    return []


def _get_child_nodesNS(node, ns, tagName):
    # Namespace-aware variant of _get_child_nodes: direct children only.
    return [childNode for childNode in node.getElementsByTagNameNS(ns, tagName)
            if childNode.parentNode == node]


def _create_entry(entry_body):
    ''' Adds common part of entry to a given entry body and return the whole
    xml. '''
    updated_str = datetime.utcnow().isoformat()
    # utcnow() is naive (utcoffset() is None), so the explicit UTC offset
    # suffix is appended for the Atom <updated> element.
    if datetime.utcnow().utcoffset() is None:
        updated_str += '+00:00'

    entry_start = '''<?xml version="1.0" encoding="utf-8" standalone="yes"?>
<entry xmlns:d="http://schemas.microsoft.com/ado/2007/08/dataservices" xmlns:m="http://schemas.microsoft.com/ado/2007/08/dataservices/metadata" xmlns="http://www.w3.org/2005/Atom" >
<title /><updated>{updated}</updated><author><name /></author><id />
<content type="application/xml">
    {body}</content></entry>'''
    return entry_start.format(updated=updated_str, body=entry_body)


def _to_datetime(strtime):
    # Parses the wire timestamp format used by the service
    # (ISO-8601 with fractional seconds, no timezone).
    return datetime.strptime(strtime, "%Y-%m-%dT%H:%M:%S.%f")

# Explicit python-name -> wire-name overrides consulted before the
# generic CamelCase conversion in _get_serialization_name.
_KNOWN_SERIALIZATION_XFORMS = {
    'include_apis': 'IncludeAPIs',
    'message_id': 'MessageId',
    'content_md5': 'Content-MD5',
    'last_modified': 'Last-Modified',
    'cache_control': 'Cache-Control',
    'account_admin_live_email_id': 'AccountAdminLiveEmailId',
    'service_admin_live_email_id': 'ServiceAdminLiveEmailId',
    'subscription_id': 'SubscriptionID',
    'fqdn': 'FQDN',
    'private_id': 'PrivateID',
    'os_virtual_hard_disk': 'OSVirtualHardDisk',
    'logical_disk_size_in_gb': 'LogicalDiskSizeInGB',
    'logical_size_in_gb': 'LogicalSizeInGB',
    'os': 'OS',
    'persistent_vm_downtime_info': 'PersistentVMDowntimeInfo',
    'copy_id': 'CopyId',
}


def _get_serialization_name(element_name):
    """converts a Python name into a serializable name"""
    # 1) explicit override table, 2) 'x_ms_*' headers keep case but use
    # dashes, 3) otherwise snake_case -> CamelCase with special handling
    # for '_id' suffixes and well-known dashed prefixes.
    known = _KNOWN_SERIALIZATION_XFORMS.get(element_name)
    if known is not None:
        return known

    if element_name.startswith('x_ms_'):
        return element_name.replace('_', '-')
    if element_name.endswith('_id'):
        element_name = element_name.replace('_id', 'ID')
    for name in ['content_', 'last_modified', 'if_', 'cache_control']:
        if element_name.startswith(name):
            element_name = element_name.replace('_', '-_')

    return ''.join(name.capitalize() for name in element_name.split('_'))

# Python 2/3 compatibility shims: _str converts to the native str type,
# _unicode_type is the text type of the running interpreter.
if sys.version_info < (3,):
    _unicode_type = unicode

    def _str(value):
        if isinstance(value, unicode):
            return value.encode('utf-8')

        return str(value)
else:
    _str = str
    _unicode_type = str


def _str_or_none(value):
    # None passes through; everything else becomes a native str.
    if value is None:
        return None

    return _str(value)


def _int_or_none(value):
    # None passes through; everything else becomes the str of its int value.
    if value is None:
        return None

    return str(int(value))


def _bool_or_none(value):
    # None passes through; bools serialize to 'true'/'false'; anything
    # else becomes its str representation.
    if value is None:
        return None

    if isinstance(value, bool):
        if value:
            return 'true'
        else:
            return 'false'

    return str(value)


def _convert_class_to_xml(source, xml_prefix=True):
    # Recursively serializes a WindowsAzureData instance (or list of them)
    # to XML; xml_prefix controls emission of the <?xml ...?> declaration
    # and is disabled for nested calls.
    if source is None:
        return ''

    xmlstr = ''
    if xml_prefix:
        xmlstr = '<?xml version="1.0" encoding="utf-8"?>'

    if isinstance(source, list):
        for value in source:
            xmlstr += _convert_class_to_xml(value, False)
    elif isinstance(source, WindowsAzureData):
        class_name = source.__class__.__name__
        xmlstr += '<' + class_name + '>'
        for name, value in vars(source).items():
            if value is not None:
                if isinstance(value, list) or \
                    isinstance(value, WindowsAzureData):
                    xmlstr += _convert_class_to_xml(value, False)
                else:
                    xmlstr += ('<' + _get_serialization_name(name) + '>' +
                               xml_escape(str(value)) + '</' +
                               _get_serialization_name(name) + '>')
        xmlstr += '</' + class_name + '>'
    return xmlstr


def _find_namespaces_from_child(parent, child, namespaces):
    """Recursively searches from the parent to the child,
    gathering all the applicable namespaces along the way"""
    for cur_child in parent.childNodes:
        if cur_child is child:
            return True
        if _find_namespaces_from_child(cur_child, child, namespaces):
            # we are the parent node
            for key in cur_child.attributes.keys():
                if key.startswith('xmlns:') or key == 'xmlns':
                    namespaces[key] = cur_child.attributes[key]
            break
    return False


def _find_namespaces(parent, child):
    # Namespaces declared on the document element plus every declaration
    # on the path from parent down to child (inner declarations win).
    res = {}
    for key in parent.documentElement.attributes.keys():
        if key.startswith('xmlns:') or key == 'xmlns':
            res[key] = parent.documentElement.attributes[key]
    _find_namespaces_from_child(parent, child, res)
    return res


def _clone_node_with_namespaces(node_to_clone, original_doc):
    # Deep-clones a node and copies onto it every namespace declaration
    # that was in scope in the original document, so the clone can be
    # serialized standalone.
    clone = node_to_clone.cloneNode(True)

    for key, value in _find_namespaces(original_doc, node_to_clone).items():
        clone.attributes[key] = value

    return clone


def _convert_response_to_feeds(response, convert_callback):
    # Deserializes an Atom feed response into a _list_of(Feed).
    # convert_callback is either a WindowsAzureData subclass (entries are
    # filled field-by-field) or a callable taking the entry's XML text.
    # x-ms-continuation-* headers, if present, are attached to the list.
    if response is None:
        return None

    feeds = _list_of(Feed)

    x_ms_continuation = HeaderDict()
    for name, value in response.headers:
        if 'x-ms-continuation' in name:
            x_ms_continuation[name[len('x-ms-continuation') + 1:]] = value
    if x_ms_continuation:
        setattr(feeds, 'x_ms_continuation', x_ms_continuation)

    xmldoc = minidom.parseString(response.body)
    xml_entries = _get_children_from_path(xmldoc, 'feed', 'entry')
    if not xml_entries:
        # in some cases, response contains only entry but no feed
        xml_entries = _get_children_from_path(xmldoc, 'entry')
    if inspect.isclass(convert_callback) and issubclass(convert_callback,
                                                        WindowsAzureData):
        for xml_entry in xml_entries:
            return_obj = convert_callback()
            for node in _get_children_from_path(xml_entry,
                                                'content',
                                                convert_callback.__name__):
                _fill_data_to_return_object(node, return_obj)
            for name, value in _get_entry_properties_from_node(
                xml_entry,
                include_id=True,
                use_title_as_id=True).items():
                setattr(return_obj, name, value)
            feeds.append(return_obj)
    else:
        for xml_entry in xml_entries:
            new_node = _clone_node_with_namespaces(xml_entry, xmldoc)
            feeds.append(convert_callback(new_node.toxml('utf-8')))

    return feeds


def _validate_type_bytes(param_name, param):
    # Raises TypeError unless param is a bytes object.
    if not isinstance(param, bytes):
        raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES.format(param_name))


def _validate_not_none(param_name, param):
    # Raises TypeError when a required parameter is None.
    if param is None:
        raise TypeError(_ERROR_VALUE_NONE.format(param_name))


def _fill_list_of(xmldoc, element_type, xml_element_name):
    # Deserializes every direct child named xml_element_name into an
    # element_type instance.
    xmlelements = _get_child_nodes(xmldoc, xml_element_name)
    return [_parse_response_body_from_xml_node(xmlelement, element_type) \
        for xmlelement in xmlelements]


def _fill_scalar_list_of(xmldoc, element_type, parent_xml_element_name,
                         xml_element_name):
    '''Converts an xml fragment into a list of scalar types.  The parent xml
    element contains a flat list of xml elements which are converted into the
    specified scalar type and added to the list.
    Example:
    xmldoc=
<Endpoints>
    <Endpoint>http://{storage-service-name}.blob.core.windows.net/</Endpoint>
    <Endpoint>http://{storage-service-name}.queue.core.windows.net/</Endpoint>
    <Endpoint>http://{storage-service-name}.table.core.windows.net/</Endpoint>
</Endpoints>
    element_type=str
    parent_xml_element_name='Endpoints'
    xml_element_name='Endpoint'
    '''
    # NOTE(review): returns None (implicitly) when the parent element is
    # absent, a list otherwise.
    xmlelements = _get_child_nodes(xmldoc, parent_xml_element_name)
    if xmlelements:
        xmlelements = _get_child_nodes(xmlelements[0], xml_element_name)
        return [_get_node_value(xmlelement, element_type) \
            for xmlelement in xmlelements]


def _fill_dict(xmldoc, element_name):
    # Builds a {childTagName: text} dict from the first matching element;
    # returns None (implicitly) when the element is absent.
    xmlelements = _get_child_nodes(xmldoc, element_name)
    if xmlelements:
        return_obj = {}
        for child in xmlelements[0].childNodes:
            if child.firstChild:
                return_obj[child.nodeName] = child.firstChild.nodeValue
        return return_obj


def _fill_dict_of(xmldoc, parent_xml_element_name, pair_xml_element_name,
                  key_xml_element_name, value_xml_element_name):
    '''Converts an xml fragment into a dictionary. The parent xml element
    contains a list of xml elements where each element has a child element for
    the key, and another for the value.
Example: xmldoc= <ExtendedProperties> <ExtendedProperty> <Name>Ext1</Name> <Value>Val1</Value> </ExtendedProperty> <ExtendedProperty> <Name>Ext2</Name> <Value>Val2</Value> </ExtendedProperty> </ExtendedProperties> element_type=str parent_xml_element_name='ExtendedProperties' pair_xml_element_name='ExtendedProperty' key_xml_element_name='Name' value_xml_element_name='Value' ''' return_obj = {} xmlelements = _get_child_nodes(xmldoc, parent_xml_element_name) if xmlelements: xmlelements = _get_child_nodes(xmlelements[0], pair_xml_element_name) for pair in xmlelements: keys = _get_child_nodes(pair, key_xml_element_name) values = _get_child_nodes(pair, value_xml_element_name) if keys and values: key = keys[0].firstChild.nodeValue value = values[0].firstChild.nodeValue return_obj[key] = value return return_obj def _fill_instance_child(xmldoc, element_name, return_type): '''Converts a child of the current dom element to the specified type. ''' xmlelements = _get_child_nodes( xmldoc, _get_serialization_name(element_name)) if not xmlelements: return None return_obj = return_type() _fill_data_to_return_object(xmlelements[0], return_obj) return return_obj def _fill_instance_element(element, return_type): """Converts a DOM element into the specified object""" return _parse_response_body_from_xml_node(element, return_type) def _fill_data_minidom(xmldoc, element_name, data_member): xmlelements = _get_child_nodes( xmldoc, _get_serialization_name(element_name)) if not xmlelements or not xmlelements[0].childNodes: return None value = xmlelements[0].firstChild.nodeValue if data_member is None: return value elif isinstance(data_member, datetime): return _to_datetime(value) elif type(data_member) is bool: return value.lower() != 'false' else: return type(data_member)(value) def _get_node_value(xmlelement, data_type): value = xmlelement.firstChild.nodeValue if data_type is datetime: return _to_datetime(value) elif data_type is bool: return value.lower() != 'false' else: return 
data_type(value) def _get_request_body_bytes_only(param_name, param_value): '''Validates the request body passed in and converts it to bytes if our policy allows it.''' if param_value is None: return b'' if isinstance(param_value, bytes): return param_value # Previous versions of the SDK allowed data types other than bytes to be # passed in, and they would be auto-converted to bytes. We preserve this # behavior when running under 2.7, but issue a warning. # Python 3 support is new, so we reject anything that's not bytes. if sys.version_info < (3,): warnings.warn(_WARNING_VALUE_SHOULD_BE_BYTES.format(param_name)) return _get_request_body(param_value) raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES.format(param_name)) def _get_request_body(request_body): '''Converts an object into a request body. If it's None we'll return an empty string, if it's one of our objects it'll convert it to XML and return it. Otherwise we just use the object directly''' if request_body is None: return b'' if isinstance(request_body, WindowsAzureData): request_body = _convert_class_to_xml(request_body) if isinstance(request_body, bytes): return request_body if isinstance(request_body, _unicode_type): return request_body.encode('utf-8') request_body = str(request_body) if isinstance(request_body, _unicode_type): return request_body.encode('utf-8') return request_body def _parse_enum_results_list(response, return_type, resp_type, item_type): """resp_body is the XML we received resp_type is a string, such as Containers, return_type is the type we're constructing, such as ContainerEnumResults item_type is the type object of the item to be created, such as Container This function then returns a ContainerEnumResults object with the containers member populated with the results. """ # parsing something like: # <EnumerationResults ... 
> # <Queues> # <Queue> # <Something /> # <SomethingElse /> # </Queue> # </Queues> # </EnumerationResults> respbody = response.body return_obj = return_type() doc = minidom.parseString(respbody) items = [] for enum_results in _get_child_nodes(doc, 'EnumerationResults'): # path is something like Queues, Queue for child in _get_children_from_path(enum_results, resp_type, resp_type[:-1]): items.append(_fill_instance_element(child, item_type)) for name, value in vars(return_obj).items(): # queues, Queues, this is the list its self which we populated # above if name == resp_type.lower(): # the list its self. continue value = _fill_data_minidom(enum_results, name, value) if value is not None: setattr(return_obj, name, value) setattr(return_obj, resp_type.lower(), items) return return_obj def _parse_simple_list(response, type, item_type, list_name): respbody = response.body res = type() res_items = [] doc = minidom.parseString(respbody) type_name = type.__name__ item_name = item_type.__name__ for item in _get_children_from_path(doc, type_name, item_name): res_items.append(_fill_instance_element(item, item_type)) setattr(res, list_name, res_items) return res def _parse_response(response, return_type): ''' Parse the HTTPResponse's body and fill all the data into a class of return_type. ''' return _parse_response_body_from_xml_text(response.body, return_type) def _parse_service_resources_response(response, return_type): ''' Parse the HTTPResponse's body and fill all the data into a class of return_type. 
''' return _parse_response_body_from_service_resources_xml_text(response.body, return_type) def _fill_data_to_return_object(node, return_obj): members = dict(vars(return_obj)) for name, value in members.items(): if isinstance(value, _list_of): setattr(return_obj, name, _fill_list_of(node, value.list_type, value.xml_element_name)) elif isinstance(value, _scalar_list_of): setattr(return_obj, name, _fill_scalar_list_of(node, value.list_type, _get_serialization_name(name), value.xml_element_name)) elif isinstance(value, _dict_of): setattr(return_obj, name, _fill_dict_of(node, _get_serialization_name(name), value.pair_xml_element_name, value.key_xml_element_name, value.value_xml_element_name)) elif isinstance(value, _xml_attribute): real_value = None if node.hasAttribute(value.xml_element_name): real_value = node.getAttribute(value.xml_element_name) if real_value is not None: setattr(return_obj, name, real_value) elif isinstance(value, WindowsAzureData): setattr(return_obj, name, _fill_instance_child(node, name, value.__class__)) elif isinstance(value, dict): setattr(return_obj, name, _fill_dict(node, _get_serialization_name(name))) elif isinstance(value, _Base64String): value = _fill_data_minidom(node, name, '') if value is not None: value = _decode_base64_to_text(value) # always set the attribute, so we don't end up returning an object # with type _Base64String setattr(return_obj, name, value) else: value = _fill_data_minidom(node, name, value) if value is not None: setattr(return_obj, name, value) def _parse_response_body_from_xml_node(node, return_type): ''' parse the xml and fill all the data into a class of return_type ''' return_obj = return_type() _fill_data_to_return_object(node, return_obj) return return_obj def _parse_response_body_from_xml_text(respbody, return_type): ''' parse the xml and fill all the data into a class of return_type ''' doc = minidom.parseString(respbody) return_obj = return_type() xml_name = return_type._xml_name if hasattr(return_type, 
'_xml_name') else return_type.__name__ for node in _get_child_nodes(doc, xml_name): _fill_data_to_return_object(node, return_obj) return return_obj def _parse_response_body_from_service_resources_xml_text(respbody, return_type): ''' parse the xml and fill all the data into a class of return_type ''' doc = minidom.parseString(respbody) return_obj = _list_of(return_type) for node in _get_children_from_path(doc, "ServiceResources", "ServiceResource"): local_obj = return_type() _fill_data_to_return_object(node, local_obj) return_obj.append(local_obj) return return_obj class _dict_of(dict): """a dict which carries with it the xml element names for key,val. Used for deserializaion and construction of the lists""" def __init__(self, pair_xml_element_name, key_xml_element_name, value_xml_element_name): self.pair_xml_element_name = pair_xml_element_name self.key_xml_element_name = key_xml_element_name self.value_xml_element_name = value_xml_element_name super(_dict_of, self).__init__() class _list_of(list): """a list which carries with it the type that's expected to go in it. Used for deserializaion and construction of the lists""" def __init__(self, list_type, xml_element_name=None): self.list_type = list_type if xml_element_name is None: self.xml_element_name = list_type.__name__ else: self.xml_element_name = xml_element_name super(_list_of, self).__init__() class _scalar_list_of(list): """a list of scalar types which carries with it the type that's expected to go in it along with its xml element name. Used for deserializaion and construction of the lists""" def __init__(self, list_type, xml_element_name): self.list_type = list_type self.xml_element_name = xml_element_name super(_scalar_list_of, self).__init__() class _xml_attribute: """a accessor to XML attributes expected to go in it along with its xml element name. 
    Used for deserialization and construction"""

    def __init__(self, xml_element_name):
        # Name of the XML attribute this descriptor maps to.
        self.xml_element_name = xml_element_name


def _update_request_uri_query_local_storage(request, use_local_storage):
    ''' create correct uri and query for the request '''
    # The development storage emulator addresses resources as
    # /devstoreaccount1/<path>, so the account name is prepended.
    uri, query = _update_request_uri_query(request)
    if use_local_storage:
        return '/' + DEV_ACCOUNT_NAME + uri, query
    return uri, query


def _update_request_uri_query(request):
    '''pulls the query string out of the URI and moves it into
    the query portion of the request object.  If there are already
    query parameters on the request the parameters in the URI will
    appear after the existing parameters'''

    if '?' in request.path:
        request.path, _, query_string = request.path.partition('?')
        if query_string:
            query_params = query_string.split('&')
            for query in query_params:
                if '=' in query:
                    name, _, value = query.partition('=')
                    request.query.append((name, value))

    request.path = url_quote(request.path, '/()$=\',')

    # add encoded queries to request.path.
    if request.query:
        request.path += '?'
        for name, value in request.query:
            if value is not None:
                request.path += name + '=' + \
                    url_quote(value, '/()$=\',') + '&'
        # Trim the trailing '&' appended by the loop above.
        request.path = request.path[:-1]

    return request.path, request.query


def _dont_fail_on_exist(error):
    ''' don't throw exception if the resource exists.
    This is called by create_* APIs with fail_on_exist=False'''
    # Swallows only the conflict error; anything else is re-raised.
    if isinstance(error, WindowsAzureConflictError):
        return False
    else:
        raise error


def _dont_fail_not_exist(error):
    ''' don't throw exception if the resource doesn't exist.
    This is called by create_* APIs with fail_on_exist=False'''
    # Swallows only the missing-resource error; anything else is re-raised.
    if isinstance(error, WindowsAzureMissingResourceError):
        return False
    else:
        raise error


def _general_error_handler(http_error):
    ''' Simple error handler for azure.'''
    # Maps HTTP 409/404 to the dedicated exception types; every other
    # status becomes a generic WindowsAzureError carrying the body text.
    if http_error.status == 409:
        raise WindowsAzureConflictError(
            _ERROR_CONFLICT.format(str(http_error)))
    elif http_error.status == 404:
        raise WindowsAzureMissingResourceError(
            _ERROR_NOT_FOUND.format(str(http_error)))
    else:
        if http_error.respbody is not None:
            raise WindowsAzureError(
                _ERROR_UNKNOWN.format(str(http_error)) + '\n' + \
                http_error.respbody.decode('utf-8'))
        else:
            raise WindowsAzureError(_ERROR_UNKNOWN.format(str(http_error)))


def _parse_response_for_dict(response):
    ''' Extracts name-values from response header. Filter out the standard
    http headers.'''

    if response is None:
        return None
    # Standard transport-level headers excluded from the result.
    http_headers = ['server', 'date', 'location', 'host',
                    'via', 'proxy-connection', 'connection']
    return_dict = HeaderDict()
    if response.headers:
        for name, value in response.headers:
            if not name.lower() in http_headers:
                return_dict[name] = value

    return return_dict


def _parse_response_for_dict_prefix(response, prefixes):
    ''' Extracts name-values for names starting with prefix from response
    header. Filter out the standard http headers.'''

    if response is None:
        return None
    return_dict = {}
    orig_dict = _parse_response_for_dict(response)
    if orig_dict:
        for name, value in orig_dict.items():
            for prefix_value in prefixes:
                # Case-insensitive prefix match; first matching prefix wins.
                if name.lower().startswith(prefix_value.lower()):
                    return_dict[name] = value
                    break
        return return_dict
    else:
        return None


def _parse_response_for_dict_filter(response, filter):
    ''' Extracts name-values for names in filter from response header.
Filter out the standard http headers.''' if response is None: return None return_dict = {} orig_dict = _parse_response_for_dict(response) if orig_dict: for name, value in orig_dict.items(): if name.lower() in filter: return_dict[name] = value return return_dict else: return None def _sign_string(key, string_to_sign, key_is_base64=True): if key_is_base64: key = _decode_base64_to_bytes(key) else: if isinstance(key, _unicode_type): key = key.encode('utf-8') if isinstance(string_to_sign, _unicode_type): string_to_sign = string_to_sign.encode('utf-8') signed_hmac_sha256 = hmac.HMAC(key, string_to_sign, hashlib.sha256) digest = signed_hmac_sha256.digest() encoded_digest = _encode_base64(digest) return encoded_digest ================================================ FILE: DSC/azure/azure.pyproj ================================================ <?xml version="1.0" encoding="utf-8"?> <Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003" ToolsVersion="4.0"> <PropertyGroup> <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration> <SchemaVersion>2.0</SchemaVersion> <ProjectGuid>{25b2c65a-0553-4452-8907-8b5b17544e68}</ProjectGuid> <ProjectHome> </ProjectHome> <StartupFile>storage\blobservice.py</StartupFile> <SearchPath>..</SearchPath> <WorkingDirectory>.</WorkingDirectory> <OutputPath>.</OutputPath> <Name>azure</Name> <RootNamespace>azure</RootNamespace> <IsWindowsApplication>False</IsWindowsApplication> <LaunchProvider>Standard Python launcher</LaunchProvider> <CommandLineArguments /> <InterpreterPath /> <InterpreterArguments /> <InterpreterId>{9a7a9026-48c1-4688-9d5d-e5699d47d074}</InterpreterId> <InterpreterVersion>3.4</InterpreterVersion> <SccProjectName>SAK</SccProjectName> <SccProvider>SAK</SccProvider> <SccAuxPath>SAK</SccAuxPath> <SccLocalPath>SAK</SccLocalPath> </PropertyGroup> <PropertyGroup Condition=" '$(Configuration)' == 'Debug' "> <DebugSymbols>true</DebugSymbols> 
<EnableUnmanagedDebugging>false</EnableUnmanagedDebugging> </PropertyGroup> <PropertyGroup Condition=" '$(Configuration)' == 'Release' "> <DebugSymbols>true</DebugSymbols> <EnableUnmanagedDebugging>false</EnableUnmanagedDebugging> </PropertyGroup> <ItemGroup> <Compile Include="http\batchclient.py" /> <Compile Include="http\httpclient.py" /> <Compile Include="http\winhttp.py" /> <Compile Include="http\__init__.py" /> <Compile Include="servicemanagement\servicebusmanagementservice.py" /> <Compile Include="servicemanagement\servicemanagementclient.py" /> <Compile Include="servicemanagement\servicemanagementservice.py" /> <Compile Include="servicemanagement\sqldatabasemanagementservice.py" /> <Compile Include="servicemanagement\websitemanagementservice.py" /> <Compile Include="servicemanagement\__init__.py" /> <Compile Include="servicebus\servicebusservice.py" /> <Compile Include="storage\blobservice.py" /> <Compile Include="storage\queueservice.py" /> <Compile Include="storage\cloudstorageaccount.py" /> <Compile Include="storage\tableservice.py" /> <Compile Include="storage\sharedaccesssignature.py" /> <Compile Include="__init__.py" /> <Compile Include="servicebus\__init__.py" /> <Compile Include="storage\storageclient.py" /> <Compile Include="storage\__init__.py" /> </ItemGroup> <ItemGroup> <Folder Include="http" /> <Folder Include="servicemanagement" /> <Folder Include="servicebus\" /> <Folder Include="storage" /> </ItemGroup> <ItemGroup> <InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\2.6" /> <InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\2.7" /> <InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\3.3" /> <InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\3.4" /> <InterpreterReference Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\2.7" /> <InterpreterReference Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\3.3" /> <InterpreterReference 
Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\3.4" /> </ItemGroup> <PropertyGroup> <VisualStudioVersion Condition="'$(VisualStudioVersion)' == ''">10.0</VisualStudioVersion> <VSToolsPath Condition="'$(VSToolsPath)' == ''">$(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)</VSToolsPath> <PtvsTargetsFile>$(VSToolsPath)\Python Tools\Microsoft.PythonTools.targets</PtvsTargetsFile> </PropertyGroup> <Import Condition="Exists($(PtvsTargetsFile))" Project="$(PtvsTargetsFile)" /> <Import Condition="!Exists($(PtvsTargetsFile))" Project="$(MSBuildToolsPath)\Microsoft.Common.targets" /> </Project> ================================================ FILE: DSC/azure/http/__init__.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- HTTP_RESPONSE_NO_CONTENT = 204 class HTTPError(Exception): ''' HTTP Exception when response status code >= 300 ''' def __init__(self, status, message, respheader, respbody): '''Creates a new HTTPError with the specified status, message, response headers and body''' self.status = status self.respheader = respheader self.respbody = respbody Exception.__init__(self, message) class HTTPResponse(object): """Represents a response from an HTTP request. 
An HTTPResponse has the following attributes: status: the status code of the response message: the message headers: the returned headers, as a list of (name, value) pairs body: the body of the response """ def __init__(self, status, message, headers, body): self.status = status self.message = message self.headers = headers self.body = body class HTTPRequest(object): '''Represents an HTTP Request. An HTTP Request consists of the following attributes: host: the host name to connect to method: the method to use to connect (string such as GET, POST, PUT, etc.) path: the uri fragment query: query parameters specified as a list of (name, value) pairs headers: header values specified as (name, value) pairs body: the body of the request. protocol_override: specify to use this protocol instead of the global one stored in _HTTPClient. ''' def __init__(self): self.host = '' self.method = '' self.path = '' self.query = [] # list of (name, value) self.headers = [] # list of (header name, header value) self.body = '' self.protocol_override = None ================================================ FILE: DSC/azure/http/batchclient.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#-------------------------------------------------------------------------- import sys import uuid from azure import ( _update_request_uri_query, WindowsAzureError, WindowsAzureBatchOperationError, _get_children_from_path, url_unquote, _ERROR_CANNOT_FIND_PARTITION_KEY, _ERROR_CANNOT_FIND_ROW_KEY, _ERROR_INCORRECT_TABLE_IN_BATCH, _ERROR_INCORRECT_PARTITION_KEY_IN_BATCH, _ERROR_DUPLICATE_ROW_KEY_IN_BATCH, _ERROR_BATCH_COMMIT_FAIL, ) from azure.http import HTTPError, HTTPRequest, HTTPResponse from azure.http.httpclient import _HTTPClient from azure.storage import ( _update_storage_table_header, METADATA_NS, _sign_storage_table_request, ) from xml.dom import minidom _DATASERVICES_NS = 'http://schemas.microsoft.com/ado/2007/08/dataservices' if sys.version_info < (3,): def _new_boundary(): return str(uuid.uuid1()) else: def _new_boundary(): return str(uuid.uuid1()).encode('utf-8') class _BatchClient(_HTTPClient): ''' This is the class that is used for batch operation for storage table service. It only supports one changeset. ''' def __init__(self, service_instance, account_key, account_name, protocol='http'): _HTTPClient.__init__(self, service_instance, account_name=account_name, account_key=account_key, protocol=protocol) self.is_batch = False self.batch_requests = [] self.batch_table = '' self.batch_partition_key = '' self.batch_row_keys = [] def get_request_table(self, request): ''' Extracts table name from request.uri. The request.uri has either "/mytable(...)" or "/mytable" format. request: the request to insert, update or delete entity ''' if '(' in request.path: pos = request.path.find('(') return request.path[1:pos] else: return request.path[1:] def get_request_partition_key(self, request): ''' Extracts PartitionKey from request.body if it is a POST request or from request.path if it is not a POST request. Only insert operation request is a POST request and the PartitionKey is in the request body. 
        request: the request to insert, update or delete entity
        '''
        if request.method == 'POST':
            # Insert requests carry the entity as an Atom entry in the
            # request body; read PartitionKey from the XML properties element.
            doc = minidom.parseString(request.body)
            part_key = _get_children_from_path(
                doc, 'entry', 'content', (METADATA_NS, 'properties'),
                (_DATASERVICES_NS, 'PartitionKey'))
            if not part_key:
                raise WindowsAzureError(_ERROR_CANNOT_FIND_PARTITION_KEY)
            return part_key[0].firstChild.nodeValue
        else:
            # Update/merge/delete requests encode the key in the URI, e.g.
            # /mytable(PartitionKey='pk',RowKey='rk'); extract it from there.
            uri = url_unquote(request.path)
            pos1 = uri.find('PartitionKey=\'')
            pos2 = uri.find('\',', pos1)
            if pos1 == -1 or pos2 == -1:
                raise WindowsAzureError(_ERROR_CANNOT_FIND_PARTITION_KEY)
            return uri[pos1 + len('PartitionKey=\''):pos2]

    def get_request_row_key(self, request):
        '''
        Extracts RowKey from request.body if it is a POST request or from
        request.path if it is not a POST request. Only insert operation request
        is a POST request and the RowKey is in the request body.

        request: the request to insert, update or delete entity
        '''
        if request.method == 'POST':
            doc = minidom.parseString(request.body)
            row_key = _get_children_from_path(
                doc, 'entry', 'content', (METADATA_NS, 'properties'),
                (_DATASERVICES_NS, 'RowKey'))
            if not row_key:
                raise WindowsAzureError(_ERROR_CANNOT_FIND_ROW_KEY)
            return row_key[0].firstChild.nodeValue
        else:
            # Non-POST: RowKey is embedded in the URI between
            # "RowKey='" and "')".
            uri = url_unquote(request.path)
            pos1 = uri.find('RowKey=\'')
            pos2 = uri.find('\')', pos1)
            if pos1 == -1 or pos2 == -1:
                raise WindowsAzureError(_ERROR_CANNOT_FIND_ROW_KEY)
            row_key = uri[pos1 + len('RowKey=\''):pos2]
            return row_key

    def validate_request_table(self, request):
        '''
        Validates that all requests have the same table name.
        Set the table name if it is the first request for the batch operation.

        request: the request to insert, update or delete entity
        '''
        if self.batch_table:
            # A batch (one changeset) may only target a single table.
            if self.get_request_table(request) != self.batch_table:
                raise WindowsAzureError(_ERROR_INCORRECT_TABLE_IN_BATCH)
        else:
            self.batch_table = self.get_request_table(request)

    def validate_request_partition_key(self, request):
        '''
        Validates that all requests have the same PartitionKey.
        Set the PartitionKey if it is the first request for the batch
        operation.

        request: the request to insert, update or delete entity
        '''
        if self.batch_partition_key:
            # All entities in a batch must share one PartitionKey.
            if self.get_request_partition_key(request) != \
                self.batch_partition_key:
                raise WindowsAzureError(_ERROR_INCORRECT_PARTITION_KEY_IN_BATCH)
        else:
            self.batch_partition_key = self.get_request_partition_key(request)

    def validate_request_row_key(self, request):
        '''
        Validates that all requests have the different RowKey and adds RowKey
        to existing RowKey list.

        request: the request to insert, update or delete entity
        '''
        if self.batch_row_keys:
            # The same entity (RowKey) may appear only once per batch.
            if self.get_request_row_key(request) in self.batch_row_keys:
                raise WindowsAzureError(_ERROR_DUPLICATE_ROW_KEY_IN_BATCH)
        else:
            self.batch_row_keys.append(self.get_request_row_key(request))

    def begin_batch(self):
        '''
        Starts the batch operation. Initializes the batch variables

        is_batch: batch operation flag.
        batch_table: the table name of the batch operation
        batch_partition_key: the PartitionKey of the batch requests.
        batch_row_keys: the RowKey list of adding requests.
        batch_requests: the list of the requests.
        '''
        self.is_batch = True
        self.batch_table = ''
        self.batch_partition_key = ''
        self.batch_row_keys = []
        self.batch_requests = []

    def insert_request_to_batch(self, request):
        '''
        Adds request to batch operation.

        request: the request to insert, update or delete entity
        '''
        # Enforce the batch invariants (one table, one partition,
        # unique row keys) before queuing the request.
        self.validate_request_table(request)
        self.validate_request_partition_key(request)
        self.validate_request_row_key(request)
        self.batch_requests.append(request)

    def commit_batch(self):
        ''' Resets batch flag and commits the batch requests. '''
        if self.is_batch:
            self.is_batch = False
            self.commit_batch_requests()

    def commit_batch_requests(self):
        ''' Commits the batch requests. '''

        # Fresh MIME boundaries for the outer batch part and the inner
        # changeset part of the multipart request.
        batch_boundary = b'batch_' + _new_boundary()
        changeset_boundary = b'changeset_' + _new_boundary()

        # Commit the batch only when the request list is not empty.
        if self.batch_requests:
            # Build a single POST to the $batch endpoint that wraps every
            # queued request as a MIME part inside one changeset.
            request = HTTPRequest()
            request.method = 'POST'
            request.host = self.batch_requests[0].host
            request.path = '/$batch'
            request.headers = [
                ('Content-Type', 'multipart/mixed; boundary=' + \
                    batch_boundary.decode('utf-8')),
                ('Accept', 'application/atom+xml,application/xml'),
                ('Accept-Charset', 'UTF-8')]

            # Outer batch part header, then the changeset container.
            request.body = b'--' + batch_boundary + b'\n'
            request.body += b'Content-Type: multipart/mixed; boundary='
            request.body += changeset_boundary + b'\n\n'

            # Each changeset part needs a unique, increasing Content-ID.
            content_id = 1

            # Adds each request body to the POST data.
            for batch_request in self.batch_requests:
                request.body += b'--' + changeset_boundary + b'\n'
                request.body += b'Content-Type: application/http\n'
                request.body += b'Content-Transfer-Encoding: binary\n\n'

                # Serialized request line: "<METHOD> http://<host><path> HTTP/1.1"
                request.body += batch_request.method.encode('utf-8')
                request.body += b' http://'
                request.body += batch_request.host.encode('utf-8')
                request.body += batch_request.path.encode('utf-8')
                request.body += b' HTTP/1.1\n'
                request.body += b'Content-ID: '
                request.body += str(content_id).encode('utf-8') + b'\n'
                content_id += 1

                # Add different headers for different type requests.
                if not batch_request.method == 'DELETE':
                    # Insert/update parts carry an Atom entry payload.
                    request.body += \
                        b'Content-Type: application/atom+xml;type=entry\n'
                    for name, value in batch_request.headers:
                        if name == 'If-Match':
                            request.body += name.encode('utf-8') + b': '
                            request.body += value.encode('utf-8') + b'\n'
                            break
                    request.body += b'Content-Length: '
                    request.body += str(len(batch_request.body)).encode('utf-8')
                    request.body += b'\n\n'
                    request.body += batch_request.body + b'\n'
                else:
                    for name, value in batch_request.headers:
                        # If-Match should be already included in
                        # batch_request.headers, but in case it is missing,
                        # just add it.
                        if name == 'If-Match':
                            request.body += name.encode('utf-8') + b': '
                            request.body += value.encode('utf-8') + b'\n\n'
                            break
                    else:
                        # for/else: no If-Match header was found, so delete
                        # unconditionally.
                        request.body += b'If-Match: *\n\n'

            # Close the changeset and the outer batch part.
            request.body += b'--' + changeset_boundary + b'--' + b'\n'
            request.body += b'--' + batch_boundary + b'--'

            request.path, request.query = _update_request_uri_query(request)
            request.headers = _update_storage_table_header(request)
            auth = _sign_storage_table_request(request,
                                               self.account_name,
                                               self.account_key)
            request.headers.append(('Authorization', auth))

            # Submit the whole request as batch request.
            response = self.perform_request(request)
            if response.status >= 300:
                raise HTTPError(response.status,
                                _ERROR_BATCH_COMMIT_FAIL,
                                self.respheader,
                                response.body)

            # http://www.odata.org/documentation/odata-version-2-0/batch-processing/
            # The body of a ChangeSet response is either a response for all the
            # successfully processed change request within the ChangeSet,
            # formatted exactly as it would have appeared outside of a batch,
            # or a single response indicating a failure of the entire
            # ChangeSet.
            responses = self._parse_batch_response(response.body)
            if responses and responses[0].status >= 300:
                self._report_batch_error(responses[0])

    def cancel_batch(self):
        '''
        Resets the batch flag.
''' self.is_batch = False def _parse_batch_response(self, body): parts = body.split(b'--changesetresponse_') responses = [] for part in parts: httpLocation = part.find(b'HTTP/') if httpLocation > 0: response = self._parse_batch_response_part(part[httpLocation:]) responses.append(response) return responses def _parse_batch_response_part(self, part): lines = part.splitlines(); # First line is the HTTP status/reason status, _, reason = lines[0].partition(b' ')[2].partition(b' ') # Followed by headers and body headers = [] body = b'' isBody = False for line in lines[1:]: if line == b'' and not isBody: isBody = True elif isBody: body += line else: headerName, _, headerVal = line.partition(b':') headers.append((headerName.lower(), headerVal)) return HTTPResponse(int(status), reason.strip(), headers, body) def _report_batch_error(self, response): xml = response.body.decode('utf-8') doc = minidom.parseString(xml) n = _get_children_from_path(doc, (METADATA_NS, 'error'), 'code') code = n[0].firstChild.nodeValue if n and n[0].firstChild else '' n = _get_children_from_path(doc, (METADATA_NS, 'error'), 'message') message = n[0].firstChild.nodeValue if n and n[0].firstChild else xml raise WindowsAzureBatchOperationError(message, code) ================================================ FILE: DSC/azure/http/httpclient.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- import base64 import os import sys if sys.version_info < (3,): from httplib import ( HTTPSConnection, HTTPConnection, HTTP_PORT, HTTPS_PORT, ) from urlparse import urlparse else: from http.client import ( HTTPSConnection, HTTPConnection, HTTP_PORT, HTTPS_PORT, ) from urllib.parse import urlparse from azure.http import HTTPError, HTTPResponse from azure import _USER_AGENT_STRING, _update_request_uri_query class _HTTPClient(object): ''' Takes the request and sends it to cloud service and returns the response. ''' def __init__(self, service_instance, cert_file=None, account_name=None, account_key=None, protocol='https'): ''' service_instance: service client instance. cert_file: certificate file name/location. This is only used in hosted service management. account_name: the storage account. account_key: the storage account access key. ''' self.service_instance = service_instance self.status = None self.respheader = None self.message = None self.cert_file = cert_file self.account_name = account_name self.account_key = account_key self.protocol = protocol self.proxy_host = None self.proxy_port = None self.proxy_user = None self.proxy_password = None self.use_httplib = self.should_use_httplib() def should_use_httplib(self): if sys.platform.lower().startswith('win') and self.cert_file: # On Windows, auto-detect between Windows Store Certificate # (winhttp) and OpenSSL .pem certificate file (httplib). # # We used to only support certificates installed in the Windows # Certificate Store. # cert_file example: CURRENT_USER\my\CertificateName # # We now support using an OpenSSL .pem certificate file, # for a consistent experience across all platforms. # cert_file example: account\certificate.pem # # When using OpenSSL .pem certificate file on Windows, make sure # you are on CPython 2.7.4 or later. 
# If it's not an existing file on disk, then treat it as a path in # the Windows Certificate Store, which means we can't use httplib. if not os.path.isfile(self.cert_file): return False return True def set_proxy(self, host, port, user, password): ''' Sets the proxy server host and port for the HTTP CONNECT Tunnelling. host: Address of the proxy. Ex: '192.168.0.100' port: Port of the proxy. Ex: 6000 user: User for proxy authorization. password: Password for proxy authorization. ''' self.proxy_host = host self.proxy_port = port self.proxy_user = user self.proxy_password = password def get_uri(self, request): ''' Return the target uri for the request.''' protocol = request.protocol_override \ if request.protocol_override else self.protocol port = HTTP_PORT if protocol == 'http' else HTTPS_PORT return protocol + '://' + request.host + ':' + str(port) + request.path def get_connection(self, request): ''' Create connection for the request. ''' protocol = request.protocol_override \ if request.protocol_override else self.protocol target_host = request.host target_port = HTTP_PORT if protocol == 'http' else HTTPS_PORT if not self.use_httplib: import azure.http.winhttp connection = azure.http.winhttp._HTTPConnection( target_host, cert_file=self.cert_file, protocol=protocol) proxy_host = self.proxy_host proxy_port = self.proxy_port else: if ':' in target_host: target_host, _, target_port = target_host.rpartition(':') if self.proxy_host: proxy_host = target_host proxy_port = target_port host = self.proxy_host port = self.proxy_port else: host = target_host port = target_port if protocol == 'http': connection = HTTPConnection(host, int(port)) else: connection = HTTPSConnection( host, int(port), cert_file=self.cert_file) if self.proxy_host: headers = None if self.proxy_user and self.proxy_password: auth = base64.encodestring( "{0}:{1}".format(self.proxy_user, self.proxy_password)) headers = {'Proxy-Authorization': 'Basic {0}'.format(auth)} connection.set_tunnel(proxy_host, 
int(proxy_port), headers) return connection def send_request_headers(self, connection, request_headers): if self.use_httplib: if self.proxy_host: for i in connection._buffer: if i.startswith("Host: "): connection._buffer.remove(i) connection.putheader( 'Host', "{0}:{1}".format(connection._tunnel_host, connection._tunnel_port)) for name, value in request_headers: if value: connection.putheader(name, value) connection.putheader('User-Agent', _USER_AGENT_STRING) connection.endheaders() def send_request_body(self, connection, request_body): if request_body: assert isinstance(request_body, bytes) connection.send(request_body) elif (not isinstance(connection, HTTPSConnection) and not isinstance(connection, HTTPConnection)): connection.send(None) def perform_request(self, request): ''' Sends request to cloud service server and return the response. ''' connection = self.get_connection(request) try: connection.putrequest(request.method, request.path) if not self.use_httplib: if self.proxy_host and self.proxy_user: connection.set_proxy_credentials( self.proxy_user, self.proxy_password) self.send_request_headers(connection, request.headers) self.send_request_body(connection, request.body) resp = connection.getresponse() self.status = int(resp.status) self.message = resp.reason self.respheader = headers = resp.getheaders() # for consistency across platforms, make header names lowercase for i, value in enumerate(headers): headers[i] = (value[0].lower(), value[1]) respbody = None if resp.length is None: respbody = resp.read() elif resp.length > 0: respbody = resp.read(resp.length) response = HTTPResponse( int(resp.status), resp.reason, headers, respbody) if self.status == 307: new_url = urlparse(dict(headers)['location']) request.host = new_url.hostname request.path = new_url.path request.path, request.query = _update_request_uri_query(request) return self.perform_request(request) if self.status >= 300: raise HTTPError(self.status, self.message, self.respheader, respbody) return 
response finally: connection.close() ================================================ FILE: DSC/azure/http/winhttp.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- from ctypes import ( c_void_p, c_long, c_ulong, c_longlong, c_ulonglong, c_short, c_ushort, c_wchar_p, c_byte, byref, Structure, Union, POINTER, WINFUNCTYPE, HRESULT, oledll, WinDLL, ) import ctypes import sys if sys.version_info >= (3,): def unicode(text): return text #------------------------------------------------------------------------------ # Constants that are used in COM operations VT_EMPTY = 0 VT_NULL = 1 VT_I2 = 2 VT_I4 = 3 VT_BSTR = 8 VT_BOOL = 11 VT_I1 = 16 VT_UI1 = 17 VT_UI2 = 18 VT_UI4 = 19 VT_I8 = 20 VT_UI8 = 21 VT_ARRAY = 8192 HTTPREQUEST_PROXYSETTING_PROXY = 2 HTTPREQUEST_SETCREDENTIALS_FOR_PROXY = 1 HTTPREQUEST_PROXY_SETTING = c_long HTTPREQUEST_SETCREDENTIALS_FLAGS = c_long #------------------------------------------------------------------------------ # Com related APIs that are used. 
_ole32 = oledll.ole32 _oleaut32 = WinDLL('oleaut32') _CLSIDFromString = _ole32.CLSIDFromString _CoInitialize = _ole32.CoInitialize _CoInitialize.argtypes = [c_void_p] _CoCreateInstance = _ole32.CoCreateInstance _SysAllocString = _oleaut32.SysAllocString _SysAllocString.restype = c_void_p _SysAllocString.argtypes = [c_wchar_p] _SysFreeString = _oleaut32.SysFreeString _SysFreeString.argtypes = [c_void_p] # SAFEARRAY* # SafeArrayCreateVector(_In_ VARTYPE vt,_In_ LONG lLbound,_In_ ULONG # cElements); _SafeArrayCreateVector = _oleaut32.SafeArrayCreateVector _SafeArrayCreateVector.restype = c_void_p _SafeArrayCreateVector.argtypes = [c_ushort, c_long, c_ulong] # HRESULT # SafeArrayAccessData(_In_ SAFEARRAY *psa, _Out_ void **ppvData); _SafeArrayAccessData = _oleaut32.SafeArrayAccessData _SafeArrayAccessData.argtypes = [c_void_p, POINTER(c_void_p)] # HRESULT # SafeArrayUnaccessData(_In_ SAFEARRAY *psa); _SafeArrayUnaccessData = _oleaut32.SafeArrayUnaccessData _SafeArrayUnaccessData.argtypes = [c_void_p] # HRESULT # SafeArrayGetUBound(_In_ SAFEARRAY *psa, _In_ UINT nDim, _Out_ LONG # *plUbound); _SafeArrayGetUBound = _oleaut32.SafeArrayGetUBound _SafeArrayGetUBound.argtypes = [c_void_p, c_ulong, POINTER(c_long)] #------------------------------------------------------------------------------ class BSTR(c_wchar_p): ''' BSTR class in python. ''' def __init__(self, value): super(BSTR, self).__init__(_SysAllocString(value)) def __del__(self): _SysFreeString(self) class VARIANT(Structure): ''' VARIANT structure in python. Does not match the definition in MSDN exactly & it is only mapping the used fields. Field names are also slighty different. 
''' class _tagData(Union): class _tagRecord(Structure): _fields_ = [('pvoid', c_void_p), ('precord', c_void_p)] _fields_ = [('llval', c_longlong), ('ullval', c_ulonglong), ('lval', c_long), ('ulval', c_ulong), ('ival', c_short), ('boolval', c_ushort), ('bstrval', BSTR), ('parray', c_void_p), ('record', _tagRecord)] _fields_ = [('vt', c_ushort), ('wReserved1', c_ushort), ('wReserved2', c_ushort), ('wReserved3', c_ushort), ('vdata', _tagData)] @staticmethod def create_empty(): variant = VARIANT() variant.vt = VT_EMPTY variant.vdata.llval = 0 return variant @staticmethod def create_safearray_from_str(text): variant = VARIANT() variant.vt = VT_ARRAY | VT_UI1 length = len(text) variant.vdata.parray = _SafeArrayCreateVector(VT_UI1, 0, length) pvdata = c_void_p() _SafeArrayAccessData(variant.vdata.parray, byref(pvdata)) ctypes.memmove(pvdata, text, length) _SafeArrayUnaccessData(variant.vdata.parray) return variant @staticmethod def create_bstr_from_str(text): variant = VARIANT() variant.vt = VT_BSTR variant.vdata.bstrval = BSTR(text) return variant @staticmethod def create_bool_false(): variant = VARIANT() variant.vt = VT_BOOL variant.vdata.boolval = 0 return variant def is_safearray_of_bytes(self): return self.vt == VT_ARRAY | VT_UI1 def str_from_safearray(self): assert self.vt == VT_ARRAY | VT_UI1 pvdata = c_void_p() count = c_long() _SafeArrayGetUBound(self.vdata.parray, 1, byref(count)) count = c_long(count.value + 1) _SafeArrayAccessData(self.vdata.parray, byref(pvdata)) text = ctypes.string_at(pvdata, count) _SafeArrayUnaccessData(self.vdata.parray) return text def __del__(self): _VariantClear(self) # HRESULT VariantClear(_Inout_ VARIANTARG *pvarg); _VariantClear = _oleaut32.VariantClear _VariantClear.argtypes = [POINTER(VARIANT)] class GUID(Structure): ''' GUID structure in python. 
''' _fields_ = [("data1", c_ulong), ("data2", c_ushort), ("data3", c_ushort), ("data4", c_byte * 8)] def __init__(self, name=None): if name is not None: _CLSIDFromString(unicode(name), byref(self)) class _WinHttpRequest(c_void_p): ''' Maps the Com API to Python class functions. Not all methods in IWinHttpWebRequest are mapped - only the methods we use. ''' _AddRef = WINFUNCTYPE(c_long) \ (1, 'AddRef') _Release = WINFUNCTYPE(c_long) \ (2, 'Release') _SetProxy = WINFUNCTYPE(HRESULT, HTTPREQUEST_PROXY_SETTING, VARIANT, VARIANT) \ (7, 'SetProxy') _SetCredentials = WINFUNCTYPE(HRESULT, BSTR, BSTR, HTTPREQUEST_SETCREDENTIALS_FLAGS) \ (8, 'SetCredentials') _Open = WINFUNCTYPE(HRESULT, BSTR, BSTR, VARIANT) \ (9, 'Open') _SetRequestHeader = WINFUNCTYPE(HRESULT, BSTR, BSTR) \ (10, 'SetRequestHeader') _GetResponseHeader = WINFUNCTYPE(HRESULT, BSTR, POINTER(c_void_p)) \ (11, 'GetResponseHeader') _GetAllResponseHeaders = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \ (12, 'GetAllResponseHeaders') _Send = WINFUNCTYPE(HRESULT, VARIANT) \ (13, 'Send') _Status = WINFUNCTYPE(HRESULT, POINTER(c_long)) \ (14, 'Status') _StatusText = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \ (15, 'StatusText') _ResponseText = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \ (16, 'ResponseText') _ResponseBody = WINFUNCTYPE(HRESULT, POINTER(VARIANT)) \ (17, 'ResponseBody') _ResponseStream = WINFUNCTYPE(HRESULT, POINTER(VARIANT)) \ (18, 'ResponseStream') _WaitForResponse = WINFUNCTYPE(HRESULT, VARIANT, POINTER(c_ushort)) \ (21, 'WaitForResponse') _Abort = WINFUNCTYPE(HRESULT) \ (22, 'Abort') _SetTimeouts = WINFUNCTYPE(HRESULT, c_long, c_long, c_long, c_long) \ (23, 'SetTimeouts') _SetClientCertificate = WINFUNCTYPE(HRESULT, BSTR) \ (24, 'SetClientCertificate') def open(self, method, url): ''' Opens the request. method: the request VERB 'GET', 'POST', etc. 
url: the url to connect ''' _WinHttpRequest._SetTimeouts(self, 0, 65000, 65000, 65000) flag = VARIANT.create_bool_false() _method = BSTR(method) _url = BSTR(url) _WinHttpRequest._Open(self, _method, _url, flag) def set_request_header(self, name, value): ''' Sets the request header. ''' _name = BSTR(name) _value = BSTR(value) _WinHttpRequest._SetRequestHeader(self, _name, _value) def get_all_response_headers(self): ''' Gets back all response headers. ''' bstr_headers = c_void_p() _WinHttpRequest._GetAllResponseHeaders(self, byref(bstr_headers)) bstr_headers = ctypes.cast(bstr_headers, c_wchar_p) headers = bstr_headers.value _SysFreeString(bstr_headers) return headers def send(self, request=None): ''' Sends the request body. ''' # Sends VT_EMPTY if it is GET, HEAD request. if request is None: var_empty = VARIANT.create_empty() _WinHttpRequest._Send(self, var_empty) else: # Sends request body as SAFEArray. _request = VARIANT.create_safearray_from_str(request) _WinHttpRequest._Send(self, _request) def status(self): ''' Gets status of response. ''' status = c_long() _WinHttpRequest._Status(self, byref(status)) return int(status.value) def status_text(self): ''' Gets status text of response. ''' bstr_status_text = c_void_p() _WinHttpRequest._StatusText(self, byref(bstr_status_text)) bstr_status_text = ctypes.cast(bstr_status_text, c_wchar_p) status_text = bstr_status_text.value _SysFreeString(bstr_status_text) return status_text def response_body(self): ''' Gets response body as a SAFEARRAY and converts the SAFEARRAY to str. If it is an xml file, it always contains 3 characters before <?xml, so we remove them. 
''' var_respbody = VARIANT() _WinHttpRequest._ResponseBody(self, byref(var_respbody)) if var_respbody.is_safearray_of_bytes(): respbody = var_respbody.str_from_safearray() if respbody[3:].startswith(b'<?xml') and\ respbody.startswith(b'\xef\xbb\xbf'): respbody = respbody[3:] return respbody else: return '' def set_client_certificate(self, certificate): '''Sets client certificate for the request. ''' _certificate = BSTR(certificate) _WinHttpRequest._SetClientCertificate(self, _certificate) def set_tunnel(self, host, port): ''' Sets up the host and the port for the HTTP CONNECT Tunnelling.''' url = host if port: url = url + u':' + port var_host = VARIANT.create_bstr_from_str(url) var_empty = VARIANT.create_empty() _WinHttpRequest._SetProxy( self, HTTPREQUEST_PROXYSETTING_PROXY, var_host, var_empty) def set_proxy_credentials(self, user, password): _WinHttpRequest._SetCredentials( self, BSTR(user), BSTR(password), HTTPREQUEST_SETCREDENTIALS_FOR_PROXY) def __del__(self): if self.value is not None: _WinHttpRequest._Release(self) class _Response(object): ''' Response class corresponding to the response returned from httplib HTTPConnection. ''' def __init__(self, _status, _status_text, _length, _headers, _respbody): self.status = _status self.reason = _status_text self.length = _length self.headers = _headers self.respbody = _respbody def getheaders(self): '''Returns response headers.''' return self.headers def read(self, _length): '''Returns resonse body. ''' return self.respbody[:_length] class _HTTPConnection(object): ''' Class corresponding to httplib HTTPConnection class. 
''' def __init__(self, host, cert_file=None, key_file=None, protocol='http'): ''' initialize the IWinHttpWebRequest Com Object.''' self.host = unicode(host) self.cert_file = cert_file self._httprequest = _WinHttpRequest() self.protocol = protocol clsid = GUID('{2087C2F4-2CEF-4953-A8AB-66779B670495}') iid = GUID('{016FE2EC-B2C8-45F8-B23B-39E53A75396B}') _CoInitialize(None) _CoCreateInstance(byref(clsid), 0, 1, byref(iid), byref(self._httprequest)) def close(self): pass def set_tunnel(self, host, port=None, headers=None): ''' Sets up the host and the port for the HTTP CONNECT Tunnelling. ''' self._httprequest.set_tunnel(unicode(host), unicode(str(port))) def set_proxy_credentials(self, user, password): self._httprequest.set_proxy_credentials( unicode(user), unicode(password)) def putrequest(self, method, uri): ''' Connects to host and sends the request. ''' protocol = unicode(self.protocol + '://') url = protocol + self.host + unicode(uri) self._httprequest.open(unicode(method), url) # sets certificate for the connection if cert_file is set. if self.cert_file is not None: self._httprequest.set_client_certificate(unicode(self.cert_file)) def putheader(self, name, value): ''' Sends the headers of request. ''' if sys.version_info < (3,): name = str(name).decode('utf-8') value = str(value).decode('utf-8') self._httprequest.set_request_header(name, value) def endheaders(self): ''' No operation. Exists only to provide the same interface of httplib HTTPConnection.''' pass def send(self, request_body): ''' Sends request body. 
''' if not request_body: self._httprequest.send() else: self._httprequest.send(request_body) def getresponse(self): ''' Gets the response and generates the _Response object''' status = self._httprequest.status() status_text = self._httprequest.status_text() resp_headers = self._httprequest.get_all_response_headers() fixed_headers = [] for resp_header in resp_headers.split('\n'): if (resp_header.startswith('\t') or\ resp_header.startswith(' ')) and fixed_headers: # append to previous header fixed_headers[-1] += resp_header else: fixed_headers.append(resp_header) headers = [] for resp_header in fixed_headers: if ':' in resp_header: pos = resp_header.find(':') headers.append( (resp_header[:pos].lower(), resp_header[pos + 1:].strip())) body = self._httprequest.response_body() length = len(body) return _Response(status, status_text, length, headers, body) ================================================ FILE: DSC/azure/servicebus/__init__.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#-------------------------------------------------------------------------- import ast import json import sys from datetime import datetime from xml.dom import minidom from azure import ( WindowsAzureData, WindowsAzureError, xml_escape, _create_entry, _general_error_handler, _get_entry_properties, _get_child_nodes, _get_children_from_path, _get_first_child_node_value, _ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_DELETE, _ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_UNLOCK, _ERROR_QUEUE_NOT_FOUND, _ERROR_TOPIC_NOT_FOUND, ) from azure.http import HTTPError # default rule name for subscription DEFAULT_RULE_NAME = '$Default' #----------------------------------------------------------------------------- # Constants for Azure app environment settings. AZURE_SERVICEBUS_NAMESPACE = 'AZURE_SERVICEBUS_NAMESPACE' AZURE_SERVICEBUS_ACCESS_KEY = 'AZURE_SERVICEBUS_ACCESS_KEY' AZURE_SERVICEBUS_ISSUER = 'AZURE_SERVICEBUS_ISSUER' # namespace used for converting rules to objects XML_SCHEMA_NAMESPACE = 'http://www.w3.org/2001/XMLSchema-instance' class Queue(WindowsAzureData): ''' Queue class corresponding to Queue Description: http://msdn.microsoft.com/en-us/library/windowsazure/hh780773''' def __init__(self, lock_duration=None, max_size_in_megabytes=None, requires_duplicate_detection=None, requires_session=None, default_message_time_to_live=None, dead_lettering_on_message_expiration=None, duplicate_detection_history_time_window=None, max_delivery_count=None, enable_batched_operations=None, size_in_bytes=None, message_count=None): self.lock_duration = lock_duration self.max_size_in_megabytes = max_size_in_megabytes self.requires_duplicate_detection = requires_duplicate_detection self.requires_session = requires_session self.default_message_time_to_live = default_message_time_to_live self.dead_lettering_on_message_expiration = \ dead_lettering_on_message_expiration self.duplicate_detection_history_time_window = \ duplicate_detection_history_time_window self.max_delivery_count = max_delivery_count 
self.enable_batched_operations = enable_batched_operations self.size_in_bytes = size_in_bytes self.message_count = message_count class Topic(WindowsAzureData): ''' Topic class corresponding to Topic Description: http://msdn.microsoft.com/en-us/library/windowsazure/hh780749. ''' def __init__(self, default_message_time_to_live=None, max_size_in_megabytes=None, requires_duplicate_detection=None, duplicate_detection_history_time_window=None, enable_batched_operations=None, size_in_bytes=None): self.default_message_time_to_live = default_message_time_to_live self.max_size_in_megabytes = max_size_in_megabytes self.requires_duplicate_detection = requires_duplicate_detection self.duplicate_detection_history_time_window = \ duplicate_detection_history_time_window self.enable_batched_operations = enable_batched_operations self.size_in_bytes = size_in_bytes @property def max_size_in_mega_bytes(self): import warnings warnings.warn( 'This attribute has been changed to max_size_in_megabytes.') return self.max_size_in_megabytes @max_size_in_mega_bytes.setter def max_size_in_mega_bytes(self, value): self.max_size_in_megabytes = value class Subscription(WindowsAzureData): ''' Subscription class corresponding to Subscription Description: http://msdn.microsoft.com/en-us/library/windowsazure/hh780763. 
''' def __init__(self, lock_duration=None, requires_session=None, default_message_time_to_live=None, dead_lettering_on_message_expiration=None, dead_lettering_on_filter_evaluation_exceptions=None, enable_batched_operations=None, max_delivery_count=None, message_count=None): self.lock_duration = lock_duration self.requires_session = requires_session self.default_message_time_to_live = default_message_time_to_live self.dead_lettering_on_message_expiration = \ dead_lettering_on_message_expiration self.dead_lettering_on_filter_evaluation_exceptions = \ dead_lettering_on_filter_evaluation_exceptions self.enable_batched_operations = enable_batched_operations self.max_delivery_count = max_delivery_count self.message_count = message_count class Rule(WindowsAzureData): ''' Rule class corresponding to Rule Description: http://msdn.microsoft.com/en-us/library/windowsazure/hh780753. ''' def __init__(self, filter_type=None, filter_expression=None, action_type=None, action_expression=None): self.filter_type = filter_type self.filter_expression = filter_expression self.action_type = action_type self.action_expression = action_type class Message(WindowsAzureData): ''' Message class that used in send message/get mesage apis. ''' def __init__(self, body=None, service_bus_service=None, location=None, custom_properties=None, type='application/atom+xml;type=entry;charset=utf-8', broker_properties=None): self.body = body self.location = location self.broker_properties = broker_properties self.custom_properties = custom_properties self.type = type self.service_bus_service = service_bus_service self._topic_name = None self._subscription_name = None self._queue_name = None if not service_bus_service: return # if location is set, then extracts the queue name for queue message and # extracts the topic and subscriptions name if it is topic message. 
if location: if '/subscriptions/' in location: pos = location.find('/subscriptions/') pos1 = location.rfind('/', 0, pos - 1) self._topic_name = location[pos1 + 1:pos] pos += len('/subscriptions/') pos1 = location.find('/', pos) self._subscription_name = location[pos:pos1] elif '/messages/' in location: pos = location.find('/messages/') pos1 = location.rfind('/', 0, pos - 1) self._queue_name = location[pos1 + 1:pos] def delete(self): ''' Deletes itself if find queue name or topic name and subscription name. ''' if self._queue_name: self.service_bus_service.delete_queue_message( self._queue_name, self.broker_properties['SequenceNumber'], self.broker_properties['LockToken']) elif self._topic_name and self._subscription_name: self.service_bus_service.delete_subscription_message( self._topic_name, self._subscription_name, self.broker_properties['SequenceNumber'], self.broker_properties['LockToken']) else: raise WindowsAzureError(_ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_DELETE) def unlock(self): ''' Unlocks itself if find queue name or topic name and subscription name. 
''' if self._queue_name: self.service_bus_service.unlock_queue_message( self._queue_name, self.broker_properties['SequenceNumber'], self.broker_properties['LockToken']) elif self._topic_name and self._subscription_name: self.service_bus_service.unlock_subscription_message( self._topic_name, self._subscription_name, self.broker_properties['SequenceNumber'], self.broker_properties['LockToken']) else: raise WindowsAzureError(_ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_UNLOCK) def add_headers(self, request): ''' add addtional headers to request for message request.''' # Adds custom properties if self.custom_properties: for name, value in self.custom_properties.items(): if sys.version_info < (3,) and isinstance(value, unicode): request.headers.append( (name, '"' + value.encode('utf-8') + '"')) elif isinstance(value, str): request.headers.append((name, '"' + str(value) + '"')) elif isinstance(value, datetime): request.headers.append( (name, '"' + value.strftime('%a, %d %b %Y %H:%M:%S GMT') + '"')) else: request.headers.append((name, str(value).lower())) # Adds content-type request.headers.append(('Content-Type', self.type)) # Adds BrokerProperties if self.broker_properties: request.headers.append( ('BrokerProperties', str(self.broker_properties))) return request.headers def _create_message(response, service_instance): ''' Create message from response. response: response from service bus cloud server. service_instance: the service bus client. ''' respbody = response.body custom_properties = {} broker_properties = None message_type = None message_location = None # gets all information from respheaders. 
for name, value in response.headers: if name.lower() == 'brokerproperties': broker_properties = json.loads(value) elif name.lower() == 'content-type': message_type = value elif name.lower() == 'location': message_location = value elif name.lower() not in ['content-type', 'brokerproperties', 'transfer-encoding', 'server', 'location', 'date']: if '"' in value: value = value[1:-1] try: custom_properties[name] = datetime.strptime( value, '%a, %d %b %Y %H:%M:%S GMT') except ValueError: custom_properties[name] = value else: # only int, float or boolean if value.lower() == 'true': custom_properties[name] = True elif value.lower() == 'false': custom_properties[name] = False # int('3.1') doesn't work so need to get float('3.14') first elif str(int(float(value))) == value: custom_properties[name] = int(value) else: custom_properties[name] = float(value) if message_type == None: message = Message( respbody, service_instance, message_location, custom_properties, 'application/atom+xml;type=entry;charset=utf-8', broker_properties) else: message = Message(respbody, service_instance, message_location, custom_properties, message_type, broker_properties) return message # convert functions def _convert_response_to_rule(response): return _convert_xml_to_rule(response.body) def _convert_xml_to_rule(xmlstr): ''' Converts response xml to rule object. 
The format of xml for rule: <entry xmlns='http://www.w3.org/2005/Atom'> <content type='application/xml'> <RuleDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect"> <Filter i:type="SqlFilterExpression"> <SqlExpression>MyProperty='XYZ'</SqlExpression> </Filter> <Action i:type="SqlFilterAction"> <SqlExpression>set MyProperty2 = 'ABC'</SqlExpression> </Action> </RuleDescription> </content> </entry> ''' xmldoc = minidom.parseString(xmlstr) rule = Rule() for rule_desc in _get_children_from_path(xmldoc, 'entry', 'content', 'RuleDescription'): for xml_filter in _get_child_nodes(rule_desc, 'Filter'): filter_type = xml_filter.getAttributeNS( XML_SCHEMA_NAMESPACE, 'type') setattr(rule, 'filter_type', str(filter_type)) if xml_filter.childNodes: for expr in _get_child_nodes(xml_filter, 'SqlExpression'): setattr(rule, 'filter_expression', expr.firstChild.nodeValue) for xml_action in _get_child_nodes(rule_desc, 'Action'): action_type = xml_action.getAttributeNS( XML_SCHEMA_NAMESPACE, 'type') setattr(rule, 'action_type', str(action_type)) if xml_action.childNodes: action_expression = xml_action.childNodes[0].firstChild if action_expression: setattr(rule, 'action_expression', action_expression.nodeValue) # extract id, updated and name value from feed entry and set them of rule. for name, value in _get_entry_properties(xmlstr, True, '/rules').items(): setattr(rule, name, value) return rule def _convert_response_to_queue(response): return _convert_xml_to_queue(response.body) def _parse_bool(value): if value.lower() == 'true': return True return False def _convert_xml_to_queue(xmlstr): ''' Converts xml response to queue object. 
The format of xml response for queue: <QueueDescription xmlns=\"http://schemas.microsoft.com/netservices/2010/10/servicebus/connect\"> <MaxSizeInBytes>10000</MaxSizeInBytes> <DefaultMessageTimeToLive>PT5M</DefaultMessageTimeToLive> <LockDuration>PT2M</LockDuration> <RequiresGroupedReceives>False</RequiresGroupedReceives> <SupportsDuplicateDetection>False</SupportsDuplicateDetection> ... </QueueDescription> ''' xmldoc = minidom.parseString(xmlstr) queue = Queue() invalid_queue = True # get node for each attribute in Queue class, if nothing found then the # response is not valid xml for Queue. for desc in _get_children_from_path(xmldoc, 'entry', 'content', 'QueueDescription'): node_value = _get_first_child_node_value(desc, 'LockDuration') if node_value is not None: queue.lock_duration = node_value invalid_queue = False node_value = _get_first_child_node_value(desc, 'MaxSizeInMegabytes') if node_value is not None: queue.max_size_in_megabytes = int(node_value) invalid_queue = False node_value = _get_first_child_node_value( desc, 'RequiresDuplicateDetection') if node_value is not None: queue.requires_duplicate_detection = _parse_bool(node_value) invalid_queue = False node_value = _get_first_child_node_value(desc, 'RequiresSession') if node_value is not None: queue.requires_session = _parse_bool(node_value) invalid_queue = False node_value = _get_first_child_node_value( desc, 'DefaultMessageTimeToLive') if node_value is not None: queue.default_message_time_to_live = node_value invalid_queue = False node_value = _get_first_child_node_value( desc, 'DeadLetteringOnMessageExpiration') if node_value is not None: queue.dead_lettering_on_message_expiration = _parse_bool(node_value) invalid_queue = False node_value = _get_first_child_node_value( desc, 'DuplicateDetectionHistoryTimeWindow') if node_value is not None: queue.duplicate_detection_history_time_window = node_value invalid_queue = False node_value = _get_first_child_node_value( desc, 'EnableBatchedOperations') if 
node_value is not None: queue.enable_batched_operations = _parse_bool(node_value) invalid_queue = False node_value = _get_first_child_node_value(desc, 'MaxDeliveryCount') if node_value is not None: queue.max_delivery_count = int(node_value) invalid_queue = False node_value = _get_first_child_node_value(desc, 'MessageCount') if node_value is not None: queue.message_count = int(node_value) invalid_queue = False node_value = _get_first_child_node_value(desc, 'SizeInBytes') if node_value is not None: queue.size_in_bytes = int(node_value) invalid_queue = False if invalid_queue: raise WindowsAzureError(_ERROR_QUEUE_NOT_FOUND) # extract id, updated and name value from feed entry and set them of queue. for name, value in _get_entry_properties(xmlstr, True).items(): setattr(queue, name, value) return queue def _convert_response_to_topic(response): return _convert_xml_to_topic(response.body) def _convert_xml_to_topic(xmlstr): '''Converts xml response to topic The xml format for topic: <entry xmlns='http://www.w3.org/2005/Atom'> <content type='application/xml'> <TopicDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect"> <DefaultMessageTimeToLive>P10675199DT2H48M5.4775807S</DefaultMessageTimeToLive> <MaxSizeInMegabytes>1024</MaxSizeInMegabytes> <RequiresDuplicateDetection>false</RequiresDuplicateDetection> <DuplicateDetectionHistoryTimeWindow>P7D</DuplicateDetectionHistoryTimeWindow> <DeadLetteringOnFilterEvaluationExceptions>true</DeadLetteringOnFilterEvaluationExceptions> </TopicDescription> </content> </entry> ''' xmldoc = minidom.parseString(xmlstr) topic = Topic() invalid_topic = True # get node for each attribute in Topic class, if nothing found then the # response is not valid xml for Topic. 
for desc in _get_children_from_path(xmldoc, 'entry', 'content', 'TopicDescription'): invalid_topic = True node_value = _get_first_child_node_value( desc, 'DefaultMessageTimeToLive') if node_value is not None: topic.default_message_time_to_live = node_value invalid_topic = False node_value = _get_first_child_node_value(desc, 'MaxSizeInMegabytes') if node_value is not None: topic.max_size_in_megabytes = int(node_value) invalid_topic = False node_value = _get_first_child_node_value( desc, 'RequiresDuplicateDetection') if node_value is not None: topic.requires_duplicate_detection = _parse_bool(node_value) invalid_topic = False node_value = _get_first_child_node_value( desc, 'DuplicateDetectionHistoryTimeWindow') if node_value is not None: topic.duplicate_detection_history_time_window = node_value invalid_topic = False node_value = _get_first_child_node_value( desc, 'EnableBatchedOperations') if node_value is not None: topic.enable_batched_operations = _parse_bool(node_value) invalid_topic = False node_value = _get_first_child_node_value(desc, 'SizeInBytes') if node_value is not None: topic.size_in_bytes = int(node_value) invalid_topic = False if invalid_topic: raise WindowsAzureError(_ERROR_TOPIC_NOT_FOUND) # extract id, updated and name value from feed entry and set them of topic. 
for name, value in _get_entry_properties(xmlstr, True).items(): setattr(topic, name, value) return topic def _convert_response_to_subscription(response): return _convert_xml_to_subscription(response.body) def _convert_xml_to_subscription(xmlstr): '''Converts xml response to subscription The xml format for subscription: <entry xmlns='http://www.w3.org/2005/Atom'> <content type='application/xml'> <SubscriptionDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect"> <LockDuration>PT5M</LockDuration> <RequiresSession>false</RequiresSession> <DefaultMessageTimeToLive>P10675199DT2H48M5.4775807S</DefaultMessageTimeToLive> <DeadLetteringOnMessageExpiration>false</DeadLetteringOnMessageExpiration> <DeadLetteringOnFilterEvaluationExceptions>true</DeadLetteringOnFilterEvaluationExceptions> </SubscriptionDescription> </content> </entry> ''' xmldoc = minidom.parseString(xmlstr) subscription = Subscription() for desc in _get_children_from_path(xmldoc, 'entry', 'content', 'SubscriptionDescription'): node_value = _get_first_child_node_value(desc, 'LockDuration') if node_value is not None: subscription.lock_duration = node_value node_value = _get_first_child_node_value( desc, 'RequiresSession') if node_value is not None: subscription.requires_session = _parse_bool(node_value) node_value = _get_first_child_node_value( desc, 'DefaultMessageTimeToLive') if node_value is not None: subscription.default_message_time_to_live = node_value node_value = _get_first_child_node_value( desc, 'DeadLetteringOnFilterEvaluationExceptions') if node_value is not None: subscription.dead_lettering_on_filter_evaluation_exceptions = \ _parse_bool(node_value) node_value = _get_first_child_node_value( desc, 'DeadLetteringOnMessageExpiration') if node_value is not None: subscription.dead_lettering_on_message_expiration = \ _parse_bool(node_value) node_value = _get_first_child_node_value( desc, 'EnableBatchedOperations') if 
node_value is not None: subscription.enable_batched_operations = _parse_bool(node_value) node_value = _get_first_child_node_value( desc, 'MaxDeliveryCount') if node_value is not None: subscription.max_delivery_count = int(node_value) node_value = _get_first_child_node_value( desc, 'MessageCount') if node_value is not None: subscription.message_count = int(node_value) for name, value in _get_entry_properties(xmlstr, True, '/subscriptions').items(): setattr(subscription, name, value) return subscription def _convert_subscription_to_xml(subscription): ''' Converts a subscription object to xml to send. The order of each field of subscription in xml is very important so we can't simple call convert_class_to_xml. subscription: the subsciption object to be converted. ''' subscription_body = '<SubscriptionDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">' if subscription: if subscription.lock_duration is not None: subscription_body += ''.join( ['<LockDuration>', str(subscription.lock_duration), '</LockDuration>']) if subscription.requires_session is not None: subscription_body += ''.join( ['<RequiresSession>', str(subscription.requires_session).lower(), '</RequiresSession>']) if subscription.default_message_time_to_live is not None: subscription_body += ''.join( ['<DefaultMessageTimeToLive>', str(subscription.default_message_time_to_live), '</DefaultMessageTimeToLive>']) if subscription.dead_lettering_on_message_expiration is not None: subscription_body += ''.join( ['<DeadLetteringOnMessageExpiration>', str(subscription.dead_lettering_on_message_expiration).lower(), '</DeadLetteringOnMessageExpiration>']) if subscription.dead_lettering_on_filter_evaluation_exceptions is not None: subscription_body += ''.join( ['<DeadLetteringOnFilterEvaluationExceptions>', str(subscription.dead_lettering_on_filter_evaluation_exceptions).lower(), '</DeadLetteringOnFilterEvaluationExceptions>']) if 
subscription.enable_batched_operations is not None: subscription_body += ''.join( ['<EnableBatchedOperations>', str(subscription.enable_batched_operations).lower(), '</EnableBatchedOperations>']) if subscription.max_delivery_count is not None: subscription_body += ''.join( ['<MaxDeliveryCount>', str(subscription.max_delivery_count), '</MaxDeliveryCount>']) if subscription.message_count is not None: subscription_body += ''.join( ['<MessageCount>', str(subscription.message_count), '</MessageCount>']) subscription_body += '</SubscriptionDescription>' return _create_entry(subscription_body) def _convert_rule_to_xml(rule): ''' Converts a rule object to xml to send. The order of each field of rule in xml is very important so we cann't simple call convert_class_to_xml. rule: the rule object to be converted. ''' rule_body = '<RuleDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">' if rule: if rule.filter_type: rule_body += ''.join( ['<Filter i:type="', xml_escape(rule.filter_type), '">']) if rule.filter_type == 'CorrelationFilter': rule_body += ''.join( ['<CorrelationId>', xml_escape(rule.filter_expression), '</CorrelationId>']) else: rule_body += ''.join( ['<SqlExpression>', xml_escape(rule.filter_expression), '</SqlExpression>']) rule_body += '<CompatibilityLevel>20</CompatibilityLevel>' rule_body += '</Filter>' if rule.action_type: rule_body += ''.join( ['<Action i:type="', xml_escape(rule.action_type), '">']) if rule.action_type == 'SqlRuleAction': rule_body += ''.join( ['<SqlExpression>', xml_escape(rule.action_expression), '</SqlExpression>']) rule_body += '<CompatibilityLevel>20</CompatibilityLevel>' rule_body += '</Action>' rule_body += '</RuleDescription>' return _create_entry(rule_body) def _convert_topic_to_xml(topic): ''' Converts a topic object to xml to send. The order of each field of topic in xml is very important so we cann't simple call convert_class_to_xml. 
topic: the topic object to be converted. ''' topic_body = '<TopicDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">' if topic: if topic.default_message_time_to_live is not None: topic_body += ''.join( ['<DefaultMessageTimeToLive>', str(topic.default_message_time_to_live), '</DefaultMessageTimeToLive>']) if topic.max_size_in_megabytes is not None: topic_body += ''.join( ['<MaxSizeInMegabytes>', str(topic.max_size_in_megabytes), '</MaxSizeInMegabytes>']) if topic.requires_duplicate_detection is not None: topic_body += ''.join( ['<RequiresDuplicateDetection>', str(topic.requires_duplicate_detection).lower(), '</RequiresDuplicateDetection>']) if topic.duplicate_detection_history_time_window is not None: topic_body += ''.join( ['<DuplicateDetectionHistoryTimeWindow>', str(topic.duplicate_detection_history_time_window), '</DuplicateDetectionHistoryTimeWindow>']) if topic.enable_batched_operations is not None: topic_body += ''.join( ['<EnableBatchedOperations>', str(topic.enable_batched_operations).lower(), '</EnableBatchedOperations>']) if topic.size_in_bytes is not None: topic_body += ''.join( ['<SizeInBytes>', str(topic.size_in_bytes), '</SizeInBytes>']) topic_body += '</TopicDescription>' return _create_entry(topic_body) def _convert_queue_to_xml(queue): ''' Converts a queue object to xml to send. The order of each field of queue in xml is very important so we cann't simple call convert_class_to_xml. queue: the queue object to be converted. 
''' queue_body = '<QueueDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">' if queue: if queue.lock_duration: queue_body += ''.join( ['<LockDuration>', str(queue.lock_duration), '</LockDuration>']) if queue.max_size_in_megabytes is not None: queue_body += ''.join( ['<MaxSizeInMegabytes>', str(queue.max_size_in_megabytes), '</MaxSizeInMegabytes>']) if queue.requires_duplicate_detection is not None: queue_body += ''.join( ['<RequiresDuplicateDetection>', str(queue.requires_duplicate_detection).lower(), '</RequiresDuplicateDetection>']) if queue.requires_session is not None: queue_body += ''.join( ['<RequiresSession>', str(queue.requires_session).lower(), '</RequiresSession>']) if queue.default_message_time_to_live is not None: queue_body += ''.join( ['<DefaultMessageTimeToLive>', str(queue.default_message_time_to_live), '</DefaultMessageTimeToLive>']) if queue.dead_lettering_on_message_expiration is not None: queue_body += ''.join( ['<DeadLetteringOnMessageExpiration>', str(queue.dead_lettering_on_message_expiration).lower(), '</DeadLetteringOnMessageExpiration>']) if queue.duplicate_detection_history_time_window is not None: queue_body += ''.join( ['<DuplicateDetectionHistoryTimeWindow>', str(queue.duplicate_detection_history_time_window), '</DuplicateDetectionHistoryTimeWindow>']) if queue.max_delivery_count is not None: queue_body += ''.join( ['<MaxDeliveryCount>', str(queue.max_delivery_count), '</MaxDeliveryCount>']) if queue.enable_batched_operations is not None: queue_body += ''.join( ['<EnableBatchedOperations>', str(queue.enable_batched_operations).lower(), '</EnableBatchedOperations>']) if queue.size_in_bytes is not None: queue_body += ''.join( ['<SizeInBytes>', str(queue.size_in_bytes), '</SizeInBytes>']) if queue.message_count is not None: queue_body += ''.join( ['<MessageCount>', str(queue.message_count), '</MessageCount>']) queue_body += '</QueueDescription>' 
return _create_entry(queue_body) def _service_bus_error_handler(http_error): ''' Simple error handler for service bus service. ''' return _general_error_handler(http_error) from azure.servicebus.servicebusservice import ServiceBusService ================================================ FILE: DSC/azure/servicebus/servicebusservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#-------------------------------------------------------------------------- import datetime import os import time from azure import ( WindowsAzureError, SERVICE_BUS_HOST_BASE, _convert_response_to_feeds, _dont_fail_not_exist, _dont_fail_on_exist, _encode_base64, _get_request_body, _get_request_body_bytes_only, _int_or_none, _sign_string, _str, _unicode_type, _update_request_uri_query, url_quote, url_unquote, _validate_not_none, ) from azure.http import ( HTTPError, HTTPRequest, ) from azure.http.httpclient import _HTTPClient from azure.servicebus import ( AZURE_SERVICEBUS_NAMESPACE, AZURE_SERVICEBUS_ACCESS_KEY, AZURE_SERVICEBUS_ISSUER, _convert_topic_to_xml, _convert_response_to_topic, _convert_queue_to_xml, _convert_response_to_queue, _convert_subscription_to_xml, _convert_response_to_subscription, _convert_rule_to_xml, _convert_response_to_rule, _convert_xml_to_queue, _convert_xml_to_topic, _convert_xml_to_subscription, _convert_xml_to_rule, _create_message, _service_bus_error_handler, ) class ServiceBusService(object): def __init__(self, service_namespace=None, account_key=None, issuer=None, x_ms_version='2011-06-01', host_base=SERVICE_BUS_HOST_BASE, shared_access_key_name=None, shared_access_key_value=None, authentication=None): ''' Initializes the service bus service for a namespace with the specified authentication settings (SAS or ACS). service_namespace: Service bus namespace, required for all operations. If None, the value is set to the AZURE_SERVICEBUS_NAMESPACE env variable. account_key: ACS authentication account key. If None, the value is set to the AZURE_SERVICEBUS_ACCESS_KEY env variable. Note that if both SAS and ACS settings are specified, SAS is used. issuer: ACS authentication issuer. If None, the value is set to the AZURE_SERVICEBUS_ISSUER env variable. Note that if both SAS and ACS settings are specified, SAS is used. x_ms_version: Unused. Kept for backwards compatibility. host_base: Optional. Live host base url. Defaults to Azure url. 
Override this for on-premise. shared_access_key_name: SAS authentication key name. Note that if both SAS and ACS settings are specified, SAS is used. shared_access_key_value: SAS authentication key value. Note that if both SAS and ACS settings are specified, SAS is used. authentication: Instance of authentication class. If this is specified, then ACS and SAS parameters are ignored. ''' self.requestid = None self.service_namespace = service_namespace self.host_base = host_base if not self.service_namespace: self.service_namespace = os.environ.get(AZURE_SERVICEBUS_NAMESPACE) if not self.service_namespace: raise WindowsAzureError('You need to provide servicebus namespace') if authentication: self.authentication = authentication else: if not account_key: account_key = os.environ.get(AZURE_SERVICEBUS_ACCESS_KEY) if not issuer: issuer = os.environ.get(AZURE_SERVICEBUS_ISSUER) if shared_access_key_name and shared_access_key_value: self.authentication = ServiceBusSASAuthentication( shared_access_key_name, shared_access_key_value) elif account_key and issuer: self.authentication = ServiceBusWrapTokenAuthentication( account_key, issuer) else: raise WindowsAzureError( 'You need to provide servicebus access key and Issuer OR shared access key and value') self._httpclient = _HTTPClient(service_instance=self) self._filter = self._httpclient.perform_request # Backwards compatibility: # account_key and issuer used to be stored on the service class, they are # now stored on the authentication class. @property def account_key(self): return self.authentication.account_key @account_key.setter def account_key(self, value): self.authentication.account_key = value @property def issuer(self): return self.authentication.issuer @issuer.setter def issuer(self, value): self.authentication.issuer = value def with_filter(self, filter): ''' Returns a new service which will process requests with the specified filter. Filtering operations can include logging, automatic retrying, etc... 
The filter is a lambda which receives the HTTPRequest and another lambda. The filter can perform any pre-processing on the request, pass it off to the next lambda, and then perform any post-processing on the response. ''' res = ServiceBusService( service_namespace=self.service_namespace, authentication=self.authentication) old_filter = self._filter def new_filter(request): return filter(request, old_filter) res._filter = new_filter return res def set_proxy(self, host, port, user=None, password=None): ''' Sets the proxy server host and port for the HTTP CONNECT Tunnelling. host: Address of the proxy. Ex: '192.168.0.100' port: Port of the proxy. Ex: 6000 user: User for proxy authorization. password: Password for proxy authorization. ''' self._httpclient.set_proxy(host, port, user, password) def create_queue(self, queue_name, queue=None, fail_on_exist=False): ''' Creates a new queue. Once created, this queue's resource manifest is immutable. queue_name: Name of the queue to create. queue: Queue object to create. fail_on_exist: Specify whether to throw an exception when the queue exists. ''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(queue_name) + '' request.body = _get_request_body(_convert_queue_to_xml(queue)) request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) if not fail_on_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_on_exist(ex) return False else: self._perform_request(request) return True def delete_queue(self, queue_name, fail_not_exist=False): ''' Deletes an existing queue. This operation will also remove all associated state including messages in the queue. queue_name: Name of the queue to delete. fail_not_exist: Specify whether to throw an exception if the queue doesn't exist. 
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_not_exist:
            # Report "does not exist" through the return value, not an error.
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_not_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def get_queue(self, queue_name):
        '''
        Retrieves an existing queue.

        queue_name: Name of the queue.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_queue(response)

    def list_queues(self):
        '''
        Enumerates the queues in the service namespace.
        '''
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/$Resources/Queues'
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_feeds(response, _convert_xml_to_queue)

    def create_topic(self, topic_name, topic=None, fail_on_exist=False):
        '''
        Creates a new topic. Once created, this topic resource manifest is
        immutable.

        topic_name: Name of the topic to create.
        topic: Topic object to create.
        fail_on_exist:
            Specify whether to throw an exception when the topic exists.
        '''
        _validate_not_none('topic_name', topic_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + ''
        request.body = _get_request_body(_convert_topic_to_xml(topic))
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_on_exist:
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_on_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def delete_topic(self, topic_name, fail_not_exist=False):
        '''
        Deletes an existing topic. This operation will also remove all
        associated state including associated subscriptions.

        topic_name: Name of the topic to delete.
        fail_not_exist:
            Specify whether throw exception when topic doesn't exist.
        '''
        _validate_not_none('topic_name', topic_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_not_exist:
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_not_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def get_topic(self, topic_name):
        '''
        Retrieves the description for the specified topic.

        topic_name: Name of the topic.
        '''
        _validate_not_none('topic_name', topic_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_topic(response)

    def list_topics(self):
        '''
        Retrieves the topics in the service namespace.
        '''
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/$Resources/Topics'
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_feeds(response, _convert_xml_to_topic)

    def create_rule(self, topic_name, subscription_name, rule_name, rule=None,
                    fail_on_exist=False):
        '''
        Creates a new rule. Once created, this rule's resource manifest is
        immutable.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        rule_name: Name of the rule.
        fail_on_exist:
            Specify whether to throw an exception when the rule exists.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        _validate_not_none('rule_name', rule_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + '/subscriptions/' + \
            _str(subscription_name) + \
            '/rules/' + _str(rule_name) + ''
        request.body = _get_request_body(_convert_rule_to_xml(rule))
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_on_exist:
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_on_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def delete_rule(self, topic_name, subscription_name, rule_name,
                    fail_not_exist=False):
        '''
        Deletes an existing rule.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        rule_name:
            Name of the rule to delete. DEFAULT_RULE_NAME=$Default.
            Use DEFAULT_RULE_NAME to delete default rule for the subscription.
        fail_not_exist:
            Specify whether throw exception when rule doesn't exist.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        _validate_not_none('rule_name', rule_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + '/subscriptions/' + \
            _str(subscription_name) + \
            '/rules/' + _str(rule_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_not_exist:
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_not_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def get_rule(self, topic_name, subscription_name, rule_name):
        '''
        Retrieves the description for the specified rule.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        rule_name: Name of the rule.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        _validate_not_none('rule_name', rule_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + '/subscriptions/' + \
            _str(subscription_name) + \
            '/rules/' + _str(rule_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_rule(response)

    def list_rules(self, topic_name, subscription_name):
        '''
        Retrieves the rules that exist under the specified subscription.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(topic_name) + '/subscriptions/' + \
            _str(subscription_name) + '/rules/'
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_feeds(response, _convert_xml_to_rule)

    def create_subscription(self, topic_name, subscription_name,
                            subscription=None, fail_on_exist=False):
        '''
        Creates a new subscription. Once created, this subscription resource
        manifest is immutable.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        fail_on_exist:
            Specify whether throw exception when subscription exists.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(topic_name) + '/subscriptions/' + _str(subscription_name) + ''
        request.body = _get_request_body(
            _convert_subscription_to_xml(subscription))
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_on_exist:
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_on_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def delete_subscription(self, topic_name, subscription_name,
                            fail_not_exist=False):
        '''
        Deletes an existing subscription.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription to delete.
        fail_not_exist:
            Specify whether to throw an exception when the subscription
            doesn't exist.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + \
            _str(topic_name) + '/subscriptions/' + _str(subscription_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_not_exist:
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_not_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def get_subscription(self, topic_name, subscription_name):
        '''
        Gets an existing subscription.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(topic_name) + '/subscriptions/' + _str(subscription_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_subscription(response)

    def list_subscriptions(self, topic_name):
        '''
        Retrieves the subscriptions in the specified topic.

        topic_name: Name of the topic.
        '''
        _validate_not_none('topic_name', topic_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + '/subscriptions/'
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_feeds(response,
                                          _convert_xml_to_subscription)

    def send_topic_message(self, topic_name, message=None):
        '''
        Enqueues a message into the specified topic.

        The limit to the number of messages which may be present in the topic
        is governed by the message size in MaxTopicSizeInBytes. If this
        message causes the topic to exceed its quota, a quota exceeded error
        is returned and the message will be rejected.

        topic_name: Name of the topic.
        message: Message object containing message body and properties.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('message', message)
        request = HTTPRequest()
        request.method = 'POST'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + '/messages'
        # Broker properties/custom properties travel as headers on the POST.
        request.headers = message.add_headers(request)
        request.body = _get_request_body_bytes_only(
            'message.body', message.body)
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        self._perform_request(request)

    def peek_lock_subscription_message(self, topic_name, subscription_name,
                                       timeout='60'):
        '''
        This operation is used to atomically retrieve and lock a message for
        processing. The message is guaranteed not to be delivered to other
        receivers during the lock duration period specified in buffer
        description. Once the lock expires, the message will be available to
        other receivers (on the same subscription only) during the lock
        duration period specified in the topic description. Once the lock
        expires, the message will be available to other receivers. In order
        to complete processing of the message, the receiver should issue a
        delete command with the lock ID received from this operation. To
        abandon processing of the message and unlock it for other receivers,
        an Unlock Message command should be issued, or the lock duration
        period can expire.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        timeout: Optional. The timeout parameter is expressed in seconds.
''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/' + \ _str(topic_name) + '/subscriptions/' + \ _str(subscription_name) + '/messages/head' request.query = [('timeout', _int_or_none(timeout))] request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _create_message(response, self) def unlock_subscription_message(self, topic_name, subscription_name, sequence_number, lock_token): ''' Unlock a message for processing by other receivers on a given subscription. This operation deletes the lock object, causing the message to be unlocked. A message must have first been locked by a receiver before this operation is called. topic_name: Name of the topic. subscription_name: Name of the subscription. sequence_number: The sequence number of the message to be unlocked as returned in BrokerProperties['SequenceNumber'] by the Peek Message operation. lock_token: The ID of the lock as returned by the Peek Message operation in BrokerProperties['LockToken'] ''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) _validate_not_none('sequence_number', sequence_number) _validate_not_none('lock_token', lock_token) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(topic_name) + \ '/subscriptions/' + str(subscription_name) + \ '/messages/' + _str(sequence_number) + \ '/' + _str(lock_token) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) self._perform_request(request) def read_delete_subscription_message(self, topic_name, subscription_name, timeout='60'): ''' Read and delete a message from a subscription as an atomic operation. 
        This operation should be used when a best-effort guarantee is
        sufficient for an application; that is, using this operation it is
        possible for messages to be lost if processing fails.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        timeout: Optional. The timeout parameter is expressed in seconds.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + \
            '/subscriptions/' + _str(subscription_name) + \
            '/messages/head'
        request.query = [('timeout', _int_or_none(timeout))]
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _create_message(response, self)

    def delete_subscription_message(self, topic_name, subscription_name,
                                    sequence_number, lock_token):
        '''
        Completes processing on a locked message and delete it from the
        subscription. This operation should only be called after processing a
        previously locked message is successful to maintain At-Least-Once
        delivery assurances.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        sequence_number:
            The sequence number of the message to be deleted as returned in
            BrokerProperties['SequenceNumber'] by the Peek Message operation.
        lock_token:
            The ID of the lock as returned by the Peek Message operation in
            BrokerProperties['LockToken']
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        _validate_not_none('sequence_number', sequence_number)
        _validate_not_none('lock_token', lock_token)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + \
            '/subscriptions/' + _str(subscription_name) + \
            '/messages/' + _str(sequence_number) + \
            '/' + _str(lock_token) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        self._perform_request(request)

    def send_queue_message(self, queue_name, message=None):
        '''
        Sends a message into the specified queue. The limit to the number of
        messages which may be present in the topic is governed by the message
        size the MaxTopicSizeInMegaBytes. If this message will cause the queue
        to exceed its quota, a quota exceeded error is returned and the
        message will be rejected.

        queue_name: Name of the queue.
        message: Message object containing message body and properties.
        '''
        _validate_not_none('queue_name', queue_name)
        _validate_not_none('message', message)
        request = HTTPRequest()
        request.method = 'POST'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + '/messages'
        # Broker properties/custom properties travel as headers on the POST.
        request.headers = message.add_headers(request)
        request.body = _get_request_body_bytes_only('message.body',
                                                    message.body)
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        self._perform_request(request)

    def peek_lock_queue_message(self, queue_name, timeout='60'):
        '''
        Atomically retrieves and locks a message from a queue for processing.
        The message is guaranteed not to be delivered to other receivers (on
        the same subscription only) during the lock duration period specified
        in the queue description.

        Once the lock expires, the message will be available to other
        receivers. In order to complete processing of the message, the
        receiver should issue a delete command with the lock ID received from
        this operation. To abandon processing of the message and unlock it
        for other receivers, an Unlock Message command should be issued, or
        the lock duration period can expire.

        queue_name: Name of the queue.
        timeout: Optional. The timeout parameter is expressed in seconds.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'POST'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + '/messages/head'
        request.query = [('timeout', _int_or_none(timeout))]
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _create_message(response, self)

    def unlock_queue_message(self, queue_name, sequence_number, lock_token):
        '''
        Unlocks a message for processing by other receivers on a given
        subscription. This operation deletes the lock object, causing the
        message to be unlocked. A message must have first been locked by a
        receiver before this operation is called.

        queue_name: Name of the queue.
        sequence_number:
            The sequence number of the message to be unlocked as returned in
            BrokerProperties['SequenceNumber'] by the Peek Message operation.
        lock_token:
            The ID of the lock as returned by the Peek Message operation in
            BrokerProperties['LockToken']
        '''
        _validate_not_none('queue_name', queue_name)
        _validate_not_none('sequence_number', sequence_number)
        _validate_not_none('lock_token', lock_token)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + \
            '/messages/' + _str(sequence_number) + \
            '/' + _str(lock_token) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        self._perform_request(request)

    def read_delete_queue_message(self, queue_name, timeout='60'):
        '''
        Reads and deletes a message from a queue as an atomic operation. This
        operation should be used when a best-effort guarantee is sufficient
        for an application; that is, using this operation it is possible for
        messages to be lost if processing fails.

        queue_name: Name of the queue.
        timeout: Optional. The timeout parameter is expressed in seconds.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + '/messages/head'
        request.query = [('timeout', _int_or_none(timeout))]
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _create_message(response, self)

    def delete_queue_message(self, queue_name, sequence_number, lock_token):
        '''
        Completes processing on a locked message and delete it from the
        queue. This operation should only be called after processing a
        previously locked message is successful to maintain At-Least-Once
        delivery assurances.

        queue_name: Name of the queue.
        sequence_number:
            The sequence number of the message to be deleted as returned in
            BrokerProperties['SequenceNumber'] by the Peek Message operation.
        lock_token:
            The ID of the lock as returned by the Peek Message operation in
            BrokerProperties['LockToken']
        '''
        _validate_not_none('queue_name', queue_name)
        _validate_not_none('sequence_number', sequence_number)
        _validate_not_none('lock_token', lock_token)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + \
            '/messages/' + _str(sequence_number) + \
            '/' + _str(lock_token) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        self._perform_request(request)

    def receive_queue_message(self, queue_name, peek_lock=True, timeout=60):
        '''
        Receive a message from a queue for processing.

        queue_name: Name of the queue.
        peek_lock:
            Optional. True to retrieve and lock the message. False to read
            and delete the message. Default is True (lock).
        timeout: Optional. The timeout parameter is expressed in seconds.
        '''
        # Convenience dispatcher over the two lower-level receive modes.
        if peek_lock:
            return self.peek_lock_queue_message(queue_name, timeout)
        else:
            return self.read_delete_queue_message(queue_name, timeout)

    def receive_subscription_message(self, topic_name, subscription_name,
                                     peek_lock=True, timeout=60):
        '''
        Receive a message from a subscription for processing.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        peek_lock:
            Optional. True to retrieve and lock the message. False to read
            and delete the message. Default is True (lock).
        timeout: Optional. The timeout parameter is expressed in seconds.
        '''
        # Convenience dispatcher over the two lower-level receive modes.
        if peek_lock:
            return self.peek_lock_subscription_message(topic_name,
                                                       subscription_name,
                                                       timeout)
        else:
            return self.read_delete_subscription_message(topic_name,
                                                         subscription_name,
                                                         timeout)

    def _get_host(self):
        # Full service host, e.g. '<namespace>' + '.servicebus.windows.net'.
        return self.service_namespace + self.host_base

    def _perform_request(self, request):
        # Runs the request through the filter pipeline and translates HTTP
        # errors into WindowsAzure* exceptions.
        try:
            resp = self._filter(request)
        except HTTPError as ex:
            return _service_bus_error_handler(ex)

        return resp

    def _update_service_bus_header(self, request):
        ''' Add additional headers for service bus. '''
        if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']:
            request.headers.append(('Content-Length', str(len(request.body))))

        # if it is not GET or HEAD request, must set content-type.
        if not request.method in ['GET', 'HEAD']:
            for name, _ in request.headers:
                if 'content-type' == name.lower():
                    break
            else:
                # for..else: only runs when no Content-Type header was found.
                request.headers.append(
                    ('Content-Type',
                     'application/atom+xml;type=entry;charset=utf-8'))

        # Adds authorization header for authentication.
        self.authentication.sign_request(request, self._httpclient)

        return request.headers


# Token cache for Authentication
# Shared by the different instances of ServiceBusWrapTokenAuthentication
_tokens = {}


class ServiceBusWrapTokenAuthentication:
    # ACS (WRAP v0.9) token authentication for Service Bus requests.

    def __init__(self, account_key, issuer):
        self.account_key = account_key
        self.issuer = issuer

    def sign_request(self, request, httpclient):
        request.headers.append(
            ('Authorization', self._get_authorization(request, httpclient)))

    def _get_authorization(self, request, httpclient):
        ''' return the signed string with token. '''
        return 'WRAP access_token="' + \
            self._get_token(request.host, request.path, httpclient) + '"'

    def _token_is_expired(self, token):
        ''' Check if token expires or not. '''
        # The expiry is embedded in the token as '...ExpiresOn=<epoch>&...'.
        time_pos_begin = token.find('ExpiresOn=') + len('ExpiresOn=')
        time_pos_end = token.find('&', time_pos_begin)
        token_expire_time = int(token[time_pos_begin:time_pos_end])
        time_now = time.mktime(time.localtime())

        # Adding 30 seconds so the token wouldn't be expired when we send the
        # token to server.
        return (token_expire_time - time_now) < 30

    def _get_token(self, host, path, httpclient):
        '''
        Returns token for the request.

        host: the service bus service request.
        path: the service bus service request.
        '''
        wrap_scope = 'http://' + host + path + self.issuer + self.account_key

        # Check whether has unexpired cache, return cached token if it is
        # still usable.
        if wrap_scope in _tokens:
            token = _tokens[wrap_scope]
            if not self._token_is_expired(token):
                return token

        # Get a new token from the access control (ACS) server.
        request = HTTPRequest()
        request.protocol_override = 'https'
        request.host = host.replace('.servicebus.', '-sb.accesscontrol.')
        request.method = 'POST'
        request.path = '/WRAPv0.9'
        request.body = ('wrap_name=' + url_quote(self.issuer) +
                        '&wrap_password=' + url_quote(self.account_key) +
                        '&wrap_scope=' +
                        url_quote('http://' + host + path)).encode('utf-8')
        request.headers.append(('Content-Length', str(len(request.body))))
        resp = httpclient.perform_request(request)

        token = resp.body.decode('utf-8')
        token = url_unquote(token[token.find('=') + 1:token.rfind('&')])
        _tokens[wrap_scope] = token

        return token


class ServiceBusSASAuthentication:
    # Shared Access Signature (SAS) authentication for Service Bus requests.

    def __init__(self, key_name, key_value):
        self.key_name = key_name
        self.key_value = key_value

    def sign_request(self, request, httpclient):
        request.headers.append(
            ('Authorization', self._get_authorization(request, httpclient)))

    def _get_authorization(self, request, httpclient):
        # Sign '<lower-cased, url-encoded uri>\n<expiry>' with the SAS key.
        uri = httpclient.get_uri(request)
        uri = url_quote(uri, '').lower()
        expiry = str(self._get_expiry())

        to_sign = uri + '\n' + expiry
        signature = url_quote(_sign_string(self.key_value, to_sign, False), '')

        auth_format = 'SharedAccessSignature sig={0}&se={1}&skn={2}&sr={3}'
        auth = auth_format.format(signature, expiry, self.key_name, uri)

        return auth

    def _get_expiry(self):
        '''Returns the UTC datetime, in seconds since Epoch, when this signed
        request expires (5 minutes from now).'''
        return int(round(time.time() + 300))


================================================
FILE:
DSC/azure/servicemanagement/__init__.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from xml.dom import minidom
from azure import (
    WindowsAzureData,
    _Base64String,
    _create_entry,
    _dict_of,
    _encode_base64,
    _general_error_handler,
    _get_children_from_path,
    _get_first_child_node_value,
    _list_of,
    _scalar_list_of,
    _str,
    _xml_attribute,
    )

#-----------------------------------------------------------------------------
# Constants for Azure app environment settings.
AZURE_MANAGEMENT_CERTFILE = 'AZURE_MANAGEMENT_CERTFILE'
AZURE_MANAGEMENT_SUBSCRIPTIONID = 'AZURE_MANAGEMENT_SUBSCRIPTIONID'

# x-ms-version for service management.
X_MS_VERSION = '2013-06-01'

#-----------------------------------------------------------------------------
# Data classes.  Each models one element of the Service Management XML
# responses; *s-suffixed collection classes delegate iteration/indexing to
# their internal _list_of container so they behave like sequences.


class StorageServices(WindowsAzureData):

    def __init__(self):
        self.storage_services = _list_of(StorageService)

    def __iter__(self):
        return iter(self.storage_services)

    def __len__(self):
        return len(self.storage_services)

    def __getitem__(self, index):
        return self.storage_services[index]


class StorageService(WindowsAzureData):

    def __init__(self):
        self.url = ''
        self.service_name = ''
        self.storage_service_properties = StorageAccountProperties()
        self.storage_service_keys = StorageServiceKeys()
        self.extended_properties = _dict_of(
            'ExtendedProperty', 'Name', 'Value')
        self.capabilities = _scalar_list_of(str, 'Capability')


class StorageAccountProperties(WindowsAzureData):

    def __init__(self):
        self.description = u''
        self.affinity_group = u''
        self.location = u''
        self.label = _Base64String()
        self.status = u''
        self.endpoints = _scalar_list_of(str, 'Endpoint')
        self.geo_replication_enabled = False
        self.geo_primary_region = u''
        self.status_of_primary = u''
        self.geo_secondary_region = u''
        self.status_of_secondary = u''
        self.last_geo_failover_time = u''
        self.creation_time = u''


class StorageServiceKeys(WindowsAzureData):

    def __init__(self):
        self.primary = u''
        self.secondary = u''


class Locations(WindowsAzureData):

    def __init__(self):
        self.locations = _list_of(Location)

    def __iter__(self):
        return iter(self.locations)

    def __len__(self):
        return len(self.locations)

    def __getitem__(self, index):
        return self.locations[index]


class Location(WindowsAzureData):

    def __init__(self):
        self.name = u''
        self.display_name = u''
        self.available_services = _scalar_list_of(str, 'AvailableService')


class AffinityGroup(WindowsAzureData):

    def __init__(self):
        self.name = ''
        self.label = _Base64String()
        self.description = u''
        self.location = u''
        self.hosted_services = HostedServices()
        self.storage_services = StorageServices()
        self.capabilities = _scalar_list_of(str, 'Capability')


class AffinityGroups(WindowsAzureData):

    def __init__(self):
        self.affinity_groups = _list_of(AffinityGroup)

    def __iter__(self):
        return iter(self.affinity_groups)

    def __len__(self):
        return len(self.affinity_groups)

    def __getitem__(self, index):
        return self.affinity_groups[index]


class HostedServices(WindowsAzureData):

    def __init__(self):
        self.hosted_services = _list_of(HostedService)

    def __iter__(self):
        return iter(self.hosted_services)

    def __len__(self):
        return len(self.hosted_services)

    def __getitem__(self, index):
        return self.hosted_services[index]


class HostedService(WindowsAzureData):

    def __init__(self):
        self.url = u''
        self.service_name = u''
        self.hosted_service_properties = HostedServiceProperties()
        self.deployments = Deployments()


class HostedServiceProperties(WindowsAzureData):

    def __init__(self):
        self.description = u''
        self.location = u''
        self.affinity_group = u''
        self.label = _Base64String()
        self.status = u''
        self.date_created = u''
        self.date_last_modified = u''
        self.extended_properties = _dict_of(
            'ExtendedProperty', 'Name', 'Value')


class VirtualNetworkSites(WindowsAzureData):

    def __init__(self):
        self.virtual_network_sites = _list_of(VirtualNetworkSite)

    def __iter__(self):
        return iter(self.virtual_network_sites)

    def __len__(self):
        return len(self.virtual_network_sites)

    def __getitem__(self, index):
        return self.virtual_network_sites[index]


class VirtualNetworkSite(WindowsAzureData):

    def __init__(self):
        self.name = u''
        self.id = u''
        self.affinity_group = u''
        self.subnets = Subnets()


class Subnets(WindowsAzureData):

    def __init__(self):
        self.subnets = _list_of(Subnet)

    def __iter__(self):
        return iter(self.subnets)

    def __len__(self):
        return len(self.subnets)

    def __getitem__(self, index):
        return self.subnets[index]


class Subnet(WindowsAzureData):

    def __init__(self):
        self.name = u''
        self.address_prefix = u''


class Deployments(WindowsAzureData):

    def __init__(self):
        self.deployments = _list_of(Deployment)

    def __iter__(self):
        return iter(self.deployments)

    def __len__(self):
        return len(self.deployments)

    def __getitem__(self, index):
        return self.deployments[index]


class Deployment(WindowsAzureData):

    def __init__(self):
        self.name = u''
        self.deployment_slot = u''
        self.private_id = u''
        self.status = u''
        self.label = _Base64String()
        self.url = u''
        self.configuration = _Base64String()
        self.role_instance_list = RoleInstanceList()
        self.upgrade_status = UpgradeStatus()
        self.upgrade_domain_count = u''
        self.role_list = RoleList()
        self.sdk_version = u''
        self.input_endpoint_list = InputEndpoints()
        self.locked = False
        self.rollback_allowed = False
        self.persistent_vm_downtime_info = PersistentVMDowntimeInfo()
        self.created_time = u''
        self.virtual_network_name = u''
        self.last_modified_time = u''
        self.extended_properties = _dict_of(
            'ExtendedProperty', 'Name', 'Value')


class RoleInstanceList(WindowsAzureData):

    def __init__(self):
        self.role_instances = _list_of(RoleInstance)

    def __iter__(self):
        return iter(self.role_instances)

    def __len__(self):
        return len(self.role_instances)

    def __getitem__(self, index):
        return self.role_instances[index]


class RoleInstance(WindowsAzureData):

    def __init__(self):
        self.role_name = u''
        self.instance_name = u''
        self.instance_status = u''
        self.instance_upgrade_domain = 0
        self.instance_fault_domain = 0
        self.instance_size = u''
        self.instance_state_details = u''
        self.instance_error_code = u''
        self.ip_address = u''
        self.instance_endpoints = InstanceEndpoints()
        self.power_state = u''
        self.fqdn = u''
        self.host_name = u''


class InstanceEndpoints(WindowsAzureData):

    def __init__(self):
        self.instance_endpoints = _list_of(InstanceEndpoint)

    def __iter__(self):
        return iter(self.instance_endpoints)

    def __len__(self):
        return len(self.instance_endpoints)

    def __getitem__(self, index):
        return self.instance_endpoints[index]


class InstanceEndpoint(WindowsAzureData):

    def __init__(self):
        self.name = u''
        self.vip = u''
        self.public_port = u''
        self.local_port = u''
        self.protocol = u''


class UpgradeStatus(WindowsAzureData):

    def __init__(self):
        self.upgrade_type = u''
        self.current_upgrade_domain_state = u''
        self.current_upgrade_domain = u''


class InputEndpoints(WindowsAzureData):

    def __init__(self):
        self.input_endpoints = _list_of(InputEndpoint)

    def __iter__(self):
        return iter(self.input_endpoints)

    def __len__(self):
        return len(self.input_endpoints)

    def __getitem__(self, index):
        return self.input_endpoints[index]


class InputEndpoint(WindowsAzureData):

    def __init__(self):
        self.role_name = u''
        self.vip = u''
        self.port = u''


class RoleList(WindowsAzureData):

    def __init__(self):
        self.roles = _list_of(Role)

    def __iter__(self):
        return iter(self.roles)

    def __len__(self):
        return len(self.roles)

    def __getitem__(self, index):
        return self.roles[index]


class Role(WindowsAzureData):

    def __init__(self):
        self.role_name = u''
        self.role_type = u''
        self.os_version = u''
        self.configuration_sets = ConfigurationSets()
        self.availability_set_name = u''
        self.data_virtual_hard_disks = DataVirtualHardDisks()
        self.os_virtual_hard_disk = OSVirtualHardDisk()
        self.role_size = u''
        self.default_win_rm_certificate_thumbprint = u''


class PersistentVMDowntimeInfo(WindowsAzureData):

    def __init__(self):
        self.start_time = u''
        self.end_time = u''
        self.status = u''


class Certificates(WindowsAzureData):

    def __init__(self):
        self.certificates = _list_of(Certificate)

    def __iter__(self):
        return iter(self.certificates)

    def __len__(self):
        return len(self.certificates)

    def __getitem__(self, index):
        return self.certificates[index]


class Certificate(WindowsAzureData):

    def __init__(self):
        self.certificate_url = u''
        self.thumbprint = u''
        self.thumbprint_algorithm = u''
        self.data = u''


class OperationError(WindowsAzureData):

    def __init__(self):
        self.code = u''
        self.message = u''


class Operation(WindowsAzureData):

    def __init__(self):
        self.id = u''
        self.status = u''
        self.http_status_code = u''
        self.error = OperationError()


class OperatingSystem(WindowsAzureData):

    def __init__(self):
        self.version = u''
        self.label = _Base64String()
self.is_default = True self.is_active = True self.family = 0 self.family_label = _Base64String() class OperatingSystems(WindowsAzureData): def __init__(self): self.operating_systems = _list_of(OperatingSystem) def __iter__(self): return iter(self.operating_systems) def __len__(self): return len(self.operating_systems) def __getitem__(self, index): return self.operating_systems[index] class OperatingSystemFamily(WindowsAzureData): def __init__(self): self.name = u'' self.label = _Base64String() self.operating_systems = OperatingSystems() class OperatingSystemFamilies(WindowsAzureData): def __init__(self): self.operating_system_families = _list_of(OperatingSystemFamily) def __iter__(self): return iter(self.operating_system_families) def __len__(self): return len(self.operating_system_families) def __getitem__(self, index): return self.operating_system_families[index] class Subscription(WindowsAzureData): def __init__(self): self.subscription_id = u'' self.subscription_name = u'' self.subscription_status = u'' self.account_admin_live_email_id = u'' self.service_admin_live_email_id = u'' self.max_core_count = 0 self.max_storage_accounts = 0 self.max_hosted_services = 0 self.current_core_count = 0 self.current_hosted_services = 0 self.current_storage_accounts = 0 self.max_virtual_network_sites = 0 self.max_local_network_sites = 0 self.max_dns_servers = 0 class AvailabilityResponse(WindowsAzureData): def __init__(self): self.result = False class SubscriptionCertificates(WindowsAzureData): def __init__(self): self.subscription_certificates = _list_of(SubscriptionCertificate) def __iter__(self): return iter(self.subscription_certificates) def __len__(self): return len(self.subscription_certificates) def __getitem__(self, index): return self.subscription_certificates[index] class SubscriptionCertificate(WindowsAzureData): def __init__(self): self.subscription_certificate_public_key = u'' self.subscription_certificate_thumbprint = u'' self.subscription_certificate_data = u'' 
self.created = u'' class Images(WindowsAzureData): def __init__(self): self.images = _list_of(OSImage) def __iter__(self): return iter(self.images) def __len__(self): return len(self.images) def __getitem__(self, index): return self.images[index] class OSImage(WindowsAzureData): def __init__(self): self.affinity_group = u'' self.category = u'' self.location = u'' self.logical_size_in_gb = 0 self.label = u'' self.media_link = u'' self.name = u'' self.os = u'' self.eula = u'' self.description = u'' class Disks(WindowsAzureData): def __init__(self): self.disks = _list_of(Disk) def __iter__(self): return iter(self.disks) def __len__(self): return len(self.disks) def __getitem__(self, index): return self.disks[index] class Disk(WindowsAzureData): def __init__(self): self.affinity_group = u'' self.attached_to = AttachedTo() self.has_operating_system = u'' self.is_corrupted = u'' self.location = u'' self.logical_disk_size_in_gb = 0 self.label = u'' self.media_link = u'' self.name = u'' self.os = u'' self.source_image_name = u'' class AttachedTo(WindowsAzureData): def __init__(self): self.hosted_service_name = u'' self.deployment_name = u'' self.role_name = u'' class PersistentVMRole(WindowsAzureData): def __init__(self): self.role_name = u'' self.role_type = u'' self.os_version = u'' # undocumented self.configuration_sets = ConfigurationSets() self.availability_set_name = u'' self.data_virtual_hard_disks = DataVirtualHardDisks() self.os_virtual_hard_disk = OSVirtualHardDisk() self.role_size = u'' self.default_win_rm_certificate_thumbprint = u'' class ConfigurationSets(WindowsAzureData): def __init__(self): self.configuration_sets = _list_of(ConfigurationSet) def __iter__(self): return iter(self.configuration_sets) def __len__(self): return len(self.configuration_sets) def __getitem__(self, index): return self.configuration_sets[index] class ConfigurationSet(WindowsAzureData): def __init__(self): self.configuration_set_type = u'NetworkConfiguration' self.role_type = u'' 
self.input_endpoints = ConfigurationSetInputEndpoints() self.subnet_names = _scalar_list_of(str, 'SubnetName') class ConfigurationSetInputEndpoints(WindowsAzureData): def __init__(self): self.input_endpoints = _list_of( ConfigurationSetInputEndpoint, 'InputEndpoint') def __iter__(self): return iter(self.input_endpoints) def __len__(self): return len(self.input_endpoints) def __getitem__(self, index): return self.input_endpoints[index] class ConfigurationSetInputEndpoint(WindowsAzureData): ''' Initializes a network configuration input endpoint. name: Specifies the name for the external endpoint. protocol: Specifies the protocol to use to inspect the virtual machine availability status. Possible values are: HTTP, TCP. port: Specifies the external port to use for the endpoint. local_port: Specifies the internal port on which the virtual machine is listening to serve the endpoint. load_balanced_endpoint_set_name: Specifies a name for a set of load-balanced endpoints. Specifying this element for a given endpoint adds it to the set. If you are setting an endpoint to use to connect to the virtual machine via the Remote Desktop, do not set this property. enable_direct_server_return: Specifies whether direct server return load balancing is enabled. 
''' def __init__(self, name=u'', protocol=u'', port=u'', local_port=u'', load_balanced_endpoint_set_name=u'', enable_direct_server_return=False): self.enable_direct_server_return = enable_direct_server_return self.load_balanced_endpoint_set_name = load_balanced_endpoint_set_name self.local_port = local_port self.name = name self.port = port self.load_balancer_probe = LoadBalancerProbe() self.protocol = protocol class WindowsConfigurationSet(WindowsAzureData): def __init__(self, computer_name=None, admin_password=None, reset_password_on_first_logon=None, enable_automatic_updates=None, time_zone=None, admin_username=None): self.configuration_set_type = u'WindowsProvisioningConfiguration' self.computer_name = computer_name self.admin_password = admin_password self.admin_username = admin_username self.reset_password_on_first_logon = reset_password_on_first_logon self.enable_automatic_updates = enable_automatic_updates self.time_zone = time_zone self.domain_join = DomainJoin() self.stored_certificate_settings = StoredCertificateSettings() self.win_rm = WinRM() class DomainJoin(WindowsAzureData): def __init__(self): self.credentials = Credentials() self.join_domain = u'' self.machine_object_ou = u'' class Credentials(WindowsAzureData): def __init__(self): self.domain = u'' self.username = u'' self.password = u'' class StoredCertificateSettings(WindowsAzureData): def __init__(self): self.stored_certificate_settings = _list_of(CertificateSetting) def __iter__(self): return iter(self.stored_certificate_settings) def __len__(self): return len(self.stored_certificate_settings) def __getitem__(self, index): return self.stored_certificate_settings[index] class CertificateSetting(WindowsAzureData): ''' Initializes a certificate setting. thumbprint: Specifies the thumbprint of the certificate to be provisioned. The thumbprint must specify an existing service certificate. store_name: Specifies the name of the certificate store from which retrieve certificate. 
store_location: Specifies the target certificate store location on the virtual machine. The only supported value is LocalMachine. ''' def __init__(self, thumbprint=u'', store_name=u'', store_location=u''): self.thumbprint = thumbprint self.store_name = store_name self.store_location = store_location class WinRM(WindowsAzureData): ''' Contains configuration settings for the Windows Remote Management service on the Virtual Machine. ''' def __init__(self): self.listeners = Listeners() class Listeners(WindowsAzureData): def __init__(self): self.listeners = _list_of(Listener) def __iter__(self): return iter(self.listeners) def __len__(self): return len(self.listeners) def __getitem__(self, index): return self.listeners[index] class Listener(WindowsAzureData): ''' Specifies the protocol and certificate information for the listener. protocol: Specifies the protocol of listener. Possible values are: Http, Https. The value is case sensitive. certificate_thumbprint: Optional. Specifies the certificate thumbprint for the secure connection. If this value is not specified, a self-signed certificate is generated and used for the Virtual Machine. 
''' def __init__(self, protocol=u'', certificate_thumbprint=u''): self.protocol = protocol self.certificate_thumbprint = certificate_thumbprint class LinuxConfigurationSet(WindowsAzureData): def __init__(self, host_name=None, user_name=None, user_password=None, disable_ssh_password_authentication=None): self.configuration_set_type = u'LinuxProvisioningConfiguration' self.host_name = host_name self.user_name = user_name self.user_password = user_password self.disable_ssh_password_authentication =\ disable_ssh_password_authentication self.ssh = SSH() class SSH(WindowsAzureData): def __init__(self): self.public_keys = PublicKeys() self.key_pairs = KeyPairs() class PublicKeys(WindowsAzureData): def __init__(self): self.public_keys = _list_of(PublicKey) def __iter__(self): return iter(self.public_keys) def __len__(self): return len(self.public_keys) def __getitem__(self, index): return self.public_keys[index] class PublicKey(WindowsAzureData): def __init__(self, fingerprint=u'', path=u''): self.fingerprint = fingerprint self.path = path class KeyPairs(WindowsAzureData): def __init__(self): self.key_pairs = _list_of(KeyPair) def __iter__(self): return iter(self.key_pairs) def __len__(self): return len(self.key_pairs) def __getitem__(self, index): return self.key_pairs[index] class KeyPair(WindowsAzureData): def __init__(self, fingerprint=u'', path=u''): self.fingerprint = fingerprint self.path = path class LoadBalancerProbe(WindowsAzureData): def __init__(self): self.path = u'' self.port = u'' self.protocol = u'' class DataVirtualHardDisks(WindowsAzureData): def __init__(self): self.data_virtual_hard_disks = _list_of(DataVirtualHardDisk) def __iter__(self): return iter(self.data_virtual_hard_disks) def __len__(self): return len(self.data_virtual_hard_disks) def __getitem__(self, index): return self.data_virtual_hard_disks[index] class DataVirtualHardDisk(WindowsAzureData): def __init__(self): self.host_caching = u'' self.disk_label = u'' self.disk_name = u'' self.lun = 0 
self.logical_disk_size_in_gb = 0 self.media_link = u'' class OSVirtualHardDisk(WindowsAzureData): def __init__(self, source_image_name=None, media_link=None, host_caching=None, disk_label=None, disk_name=None): self.source_image_name = source_image_name self.media_link = media_link self.host_caching = host_caching self.disk_label = disk_label self.disk_name = disk_name self.os = u'' # undocumented, not used when adding a role class AsynchronousOperationResult(WindowsAzureData): def __init__(self, request_id=None): self.request_id = request_id class ServiceBusRegion(WindowsAzureData): def __init__(self): self.code = u'' self.fullname = u'' class ServiceBusNamespace(WindowsAzureData): def __init__(self): self.name = u'' self.region = u'' self.default_key = u'' self.status = u'' self.created_at = u'' self.acs_management_endpoint = u'' self.servicebus_endpoint = u'' self.connection_string = u'' self.subscription_id = u'' self.enabled = False class WebSpaces(WindowsAzureData): def __init__(self): self.web_space = _list_of(WebSpace) def __iter__(self): return iter(self.web_space) def __len__(self): return len(self.web_space) def __getitem__(self, index): return self.web_space[index] class WebSpace(WindowsAzureData): def __init__(self): self.availability_state = u'' self.geo_location = u'' self.geo_region = u'' self.name = u'' self.plan = u'' self.status = u'' self.subscription = u'' class Sites(WindowsAzureData): def __init__(self): self.site = _list_of(Site) def __iter__(self): return iter(self.site) def __len__(self): return len(self.site) def __getitem__(self, index): return self.site[index] class Site(WindowsAzureData): def __init__(self): self.admin_enabled = False self.availability_state = '' self.compute_mode = '' self.enabled = False self.enabled_host_names = _scalar_list_of(str, 'a:string') self.host_name_ssl_states = HostNameSslStates() self.host_names = _scalar_list_of(str, 'a:string') self.last_modified_time_utc = '' self.name = '' self.repository_site_name = 
'' self.self_link = '' self.server_farm = '' self.site_mode = '' self.state = '' self.storage_recovery_default_state = '' self.usage_state = '' self.web_space = '' class HostNameSslStates(WindowsAzureData): def __init__(self): self.host_name_ssl_state = _list_of(HostNameSslState) def __iter__(self): return iter(self.host_name_ssl_state) def __len__(self): return len(self.host_name_ssl_state) def __getitem__(self, index): return self.host_name_ssl_state[index] class HostNameSslState(WindowsAzureData): def __init__(self): self.name = u'' self.ssl_state = u'' class PublishData(WindowsAzureData): _xml_name = 'publishData' def __init__(self): self.publish_profiles = _list_of(PublishProfile, 'publishProfile') class PublishProfile(WindowsAzureData): def __init__(self): self.profile_name = _xml_attribute('profileName') self.publish_method = _xml_attribute('publishMethod') self.publish_url = _xml_attribute('publishUrl') self.msdeploysite = _xml_attribute('msdeploySite') self.user_name = _xml_attribute('userName') self.user_pwd = _xml_attribute('userPWD') self.destination_app_url = _xml_attribute('destinationAppUrl') self.sql_server_db_connection_string = _xml_attribute('SQLServerDBConnectionString') self.my_sqldb_connection_string = _xml_attribute('mySQLDBConnectionString') self.hosting_provider_forum_link = _xml_attribute('hostingProviderForumLink') self.control_panel_link = _xml_attribute('controlPanelLink') class QueueDescription(WindowsAzureData): def __init__(self): self.lock_duration = u'' self.max_size_in_megabytes = 0 self.requires_duplicate_detection = False self.requires_session = False self.default_message_time_to_live = u'' self.dead_lettering_on_message_expiration = False self.duplicate_detection_history_time_window = u'' self.max_delivery_count = 0 self.enable_batched_operations = False self.size_in_bytes = 0 self.message_count = 0 self.is_anonymous_accessible = False self.authorization_rules = AuthorizationRules() self.status = u'' self.created_at = u'' 
self.updated_at = u'' self.accessed_at = u'' self.support_ordering = False self.auto_delete_on_idle = u'' self.count_details = CountDetails() self.entity_availability_status = u'' class TopicDescription(WindowsAzureData): def __init__(self): self.default_message_time_to_live = u'' self.max_size_in_megabytes = 0 self.requires_duplicate_detection = False self.duplicate_detection_history_time_window = u'' self.enable_batched_operations = False self.size_in_bytes = 0 self.filtering_messages_before_publishing = False self.is_anonymous_accessible = False self.authorization_rules = AuthorizationRules() self.status = u'' self.created_at = u'' self.updated_at = u'' self.accessed_at = u'' self.support_ordering = False self.count_details = CountDetails() self.subscription_count = 0 class CountDetails(WindowsAzureData): def __init__(self): self.active_message_count = 0 self.dead_letter_message_count = 0 self.scheduled_message_count = 0 self.transfer_message_count = 0 self.transfer_dead_letter_message_count = 0 class NotificationHubDescription(WindowsAzureData): def __init__(self): self.registration_ttl = u'' self.authorization_rules = AuthorizationRules() class AuthorizationRules(WindowsAzureData): def __init__(self): self.authorization_rule = _list_of(AuthorizationRule) def __iter__(self): return iter(self.authorization_rule) def __len__(self): return len(self.authorization_rule) def __getitem__(self, index): return self.authorization_rule[index] class AuthorizationRule(WindowsAzureData): def __init__(self): self.claim_type = u'' self.claim_value = u'' self.rights = _scalar_list_of(str, 'AccessRights') self.created_time = u'' self.modified_time = u'' self.key_name = u'' self.primary_key = u'' self.secondary_keu = u'' class RelayDescription(WindowsAzureData): def __init__(self): self.path = u'' self.listener_type = u'' self.listener_count = 0 self.created_at = u'' self.updated_at = u'' class MetricResponses(WindowsAzureData): def __init__(self): self.metric_response = 
_list_of(MetricResponse) def __iter__(self): return iter(self.metric_response) def __len__(self): return len(self.metric_response) def __getitem__(self, index): return self.metric_response[index] class MetricResponse(WindowsAzureData): def __init__(self): self.code = u'' self.data = Data() self.message = u'' class Data(WindowsAzureData): def __init__(self): self.display_name = u'' self.end_time = u'' self.name = u'' self.primary_aggregation_type = u'' self.start_time = u'' self.time_grain = u'' self.unit = u'' self.values = Values() class Values(WindowsAzureData): def __init__(self): self.metric_sample = _list_of(MetricSample) def __iter__(self): return iter(self.metric_sample) def __len__(self): return len(self.metric_sample) def __getitem__(self, index): return self.metric_sample[index] class MetricSample(WindowsAzureData): def __init__(self): self.count = 0 self.time_created = u'' self.total = 0 class MetricDefinitions(WindowsAzureData): def __init__(self): self.metric_definition = _list_of(MetricDefinition) def __iter__(self): return iter(self.metric_definition) def __len__(self): return len(self.metric_definition) def __getitem__(self, index): return self.metric_definition[index] class MetricDefinition(WindowsAzureData): def __init__(self): self.display_name = u'' self.metric_availabilities = MetricAvailabilities() self.name = u'' self.primary_aggregation_type = u'' self.unit = u'' class MetricAvailabilities(WindowsAzureData): def __init__(self): self.metric_availability = _list_of(MetricAvailability, 'MetricAvailabilily') def __iter__(self): return iter(self.metric_availability) def __len__(self): return len(self.metric_availability) def __getitem__(self, index): return self.metric_availability[index] class MetricAvailability(WindowsAzureData): def __init__(self): self.retention = u'' self.time_grain = u'' class Servers(WindowsAzureData): def __init__(self): self.server = _list_of(Server) def __iter__(self): return iter(self.server) def __len__(self): return 
len(self.server) def __getitem__(self, index): return self.server[index] class Server(WindowsAzureData): def __init__(self): self.name = u'' self.administrator_login = u'' self.location = u'' self.fully_qualified_domain_name = u'' self.version = u'' class Database(WindowsAzureData): def __init__(self): self.name = u'' self.type = u'' self.state = u'' self.self_link = u'' self.parent_link = u'' self.id = 0 self.edition = u'' self.collation_name = u'' self.creation_date = u'' self.is_federation_root = False self.is_system_object = False self.max_size_bytes = 0 def _update_management_header(request): ''' Add additional headers for management. ''' if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']: request.headers.append(('Content-Length', str(len(request.body)))) # append additional headers base on the service request.headers.append(('x-ms-version', X_MS_VERSION)) # if it is not GET or HEAD request, must set content-type. if not request.method in ['GET', 'HEAD']: for name, _ in request.headers: if 'content-type' == name.lower(): break else: request.headers.append( ('Content-Type', 'application/atom+xml;type=entry;charset=utf-8')) return request.headers def _parse_response_for_async_op(response): ''' Extracts request id from response header. ''' if response is None: return None result = AsynchronousOperationResult() if response.headers: for name, value in response.headers: if name.lower() == 'x-ms-request-id': result.request_id = value return result def _management_error_handler(http_error): ''' Simple error handler for management service. 
''' return _general_error_handler(http_error) def _lower(text): return text.lower() class _XmlSerializer(object): @staticmethod def create_storage_service_input_to_xml(service_name, description, label, affinity_group, location, geo_replication_enabled, extended_properties): return _XmlSerializer.doc_from_data( 'CreateStorageServiceInput', [('ServiceName', service_name), ('Description', description), ('Label', label, _encode_base64), ('AffinityGroup', affinity_group), ('Location', location), ('GeoReplicationEnabled', geo_replication_enabled, _lower)], extended_properties) @staticmethod def update_storage_service_input_to_xml(description, label, geo_replication_enabled, extended_properties): return _XmlSerializer.doc_from_data( 'UpdateStorageServiceInput', [('Description', description), ('Label', label, _encode_base64), ('GeoReplicationEnabled', geo_replication_enabled, _lower)], extended_properties) @staticmethod def regenerate_keys_to_xml(key_type): return _XmlSerializer.doc_from_data('RegenerateKeys', [('KeyType', key_type)]) @staticmethod def update_hosted_service_to_xml(label, description, extended_properties): return _XmlSerializer.doc_from_data('UpdateHostedService', [('Label', label, _encode_base64), ('Description', description)], extended_properties) @staticmethod def create_hosted_service_to_xml(service_name, label, description, location, affinity_group, extended_properties): return _XmlSerializer.doc_from_data( 'CreateHostedService', [('ServiceName', service_name), ('Label', label, _encode_base64), ('Description', description), ('Location', location), ('AffinityGroup', affinity_group)], extended_properties) @staticmethod def create_deployment_to_xml(name, package_url, label, configuration, start_deployment, treat_warnings_as_error, extended_properties): return _XmlSerializer.doc_from_data( 'CreateDeployment', [('Name', name), ('PackageUrl', package_url), ('Label', label, _encode_base64), ('Configuration', configuration), ('StartDeployment', 
start_deployment, _lower), ('TreatWarningsAsError', treat_warnings_as_error, _lower)], extended_properties) @staticmethod def swap_deployment_to_xml(production, source_deployment): return _XmlSerializer.doc_from_data( 'Swap', [('Production', production), ('SourceDeployment', source_deployment)]) @staticmethod def update_deployment_status_to_xml(status): return _XmlSerializer.doc_from_data( 'UpdateDeploymentStatus', [('Status', status)]) @staticmethod def change_deployment_to_xml(configuration, treat_warnings_as_error, mode, extended_properties): return _XmlSerializer.doc_from_data( 'ChangeConfiguration', [('Configuration', configuration), ('TreatWarningsAsError', treat_warnings_as_error, _lower), ('Mode', mode)], extended_properties) @staticmethod def upgrade_deployment_to_xml(mode, package_url, configuration, label, role_to_upgrade, force, extended_properties): return _XmlSerializer.doc_from_data( 'UpgradeDeployment', [('Mode', mode), ('PackageUrl', package_url), ('Configuration', configuration), ('Label', label, _encode_base64), ('RoleToUpgrade', role_to_upgrade), ('Force', force, _lower)], extended_properties) @staticmethod def rollback_upgrade_to_xml(mode, force): return _XmlSerializer.doc_from_data( 'RollbackUpdateOrUpgrade', [('Mode', mode), ('Force', force, _lower)]) @staticmethod def walk_upgrade_domain_to_xml(upgrade_domain): return _XmlSerializer.doc_from_data( 'WalkUpgradeDomain', [('UpgradeDomain', upgrade_domain)]) @staticmethod def certificate_file_to_xml(data, certificate_format, password): return _XmlSerializer.doc_from_data( 'CertificateFile', [('Data', data), ('CertificateFormat', certificate_format), ('Password', password)]) @staticmethod def create_affinity_group_to_xml(name, label, description, location): return _XmlSerializer.doc_from_data( 'CreateAffinityGroup', [('Name', name), ('Label', label, _encode_base64), ('Description', description), ('Location', location)]) @staticmethod def update_affinity_group_to_xml(label, description): return 
_XmlSerializer.doc_from_data( 'UpdateAffinityGroup', [('Label', label, _encode_base64), ('Description', description)]) @staticmethod def subscription_certificate_to_xml(public_key, thumbprint, data): return _XmlSerializer.doc_from_data( 'SubscriptionCertificate', [('SubscriptionCertificatePublicKey', public_key), ('SubscriptionCertificateThumbprint', thumbprint), ('SubscriptionCertificateData', data)]) @staticmethod def os_image_to_xml(label, media_link, name, os): return _XmlSerializer.doc_from_data( 'OSImage', [('Label', label), ('MediaLink', media_link), ('Name', name), ('OS', os)]) @staticmethod def data_virtual_hard_disk_to_xml(host_caching, disk_label, disk_name, lun, logical_disk_size_in_gb, media_link, source_media_link): return _XmlSerializer.doc_from_data( 'DataVirtualHardDisk', [('HostCaching', host_caching), ('DiskLabel', disk_label), ('DiskName', disk_name), ('Lun', lun), ('LogicalDiskSizeInGB', logical_disk_size_in_gb), ('MediaLink', media_link), ('SourceMediaLink', source_media_link)]) @staticmethod def disk_to_xml(has_operating_system, label, media_link, name, os): return _XmlSerializer.doc_from_data( 'Disk', [('HasOperatingSystem', has_operating_system, _lower), ('Label', label), ('MediaLink', media_link), ('Name', name), ('OS', os)]) @staticmethod def restart_role_operation_to_xml(): return _XmlSerializer.doc_from_xml( 'RestartRoleOperation', '<OperationType>RestartRoleOperation</OperationType>') @staticmethod def shutdown_role_operation_to_xml(post_shutdown_action): xml = _XmlSerializer.data_to_xml( [('OperationType', 'ShutdownRoleOperation'), ('PostShutdownAction', post_shutdown_action)]) return _XmlSerializer.doc_from_xml('ShutdownRoleOperation', xml) @staticmethod def shutdown_roles_operation_to_xml(role_names, post_shutdown_action): xml = _XmlSerializer.data_to_xml( [('OperationType', 'ShutdownRolesOperation')]) xml += '<Roles>' for role_name in role_names: xml += _XmlSerializer.data_to_xml([('Name', role_name)]) xml += '</Roles>' xml += 
_XmlSerializer.data_to_xml( [('PostShutdownAction', post_shutdown_action)]) return _XmlSerializer.doc_from_xml('ShutdownRolesOperation', xml) @staticmethod def start_role_operation_to_xml(): return _XmlSerializer.doc_from_xml( 'StartRoleOperation', '<OperationType>StartRoleOperation</OperationType>') @staticmethod def start_roles_operation_to_xml(role_names): xml = _XmlSerializer.data_to_xml( [('OperationType', 'StartRolesOperation')]) xml += '<Roles>' for role_name in role_names: xml += _XmlSerializer.data_to_xml([('Name', role_name)]) xml += '</Roles>' return _XmlSerializer.doc_from_xml('StartRolesOperation', xml) @staticmethod def windows_configuration_to_xml(configuration): xml = _XmlSerializer.data_to_xml( [('ConfigurationSetType', configuration.configuration_set_type), ('ComputerName', configuration.computer_name), ('AdminPassword', configuration.admin_password), ('ResetPasswordOnFirstLogon', configuration.reset_password_on_first_logon, _lower), ('EnableAutomaticUpdates', configuration.enable_automatic_updates, _lower), ('TimeZone', configuration.time_zone)]) if configuration.domain_join is not None: xml += '<DomainJoin>' xml += '<Credentials>' xml += _XmlSerializer.data_to_xml( [('Domain', configuration.domain_join.credentials.domain), ('Username', configuration.domain_join.credentials.username), ('Password', configuration.domain_join.credentials.password)]) xml += '</Credentials>' xml += _XmlSerializer.data_to_xml( [('JoinDomain', configuration.domain_join.join_domain), ('MachineObjectOU', configuration.domain_join.machine_object_ou)]) xml += '</DomainJoin>' if configuration.stored_certificate_settings is not None: xml += '<StoredCertificateSettings>' for cert in configuration.stored_certificate_settings: xml += '<CertificateSetting>' xml += _XmlSerializer.data_to_xml( [('StoreLocation', cert.store_location), ('StoreName', cert.store_name), ('Thumbprint', cert.thumbprint)]) xml += '</CertificateSetting>' xml += '</StoredCertificateSettings>' if 
configuration.win_rm is not None: xml += '<WinRM><Listeners>' for listener in configuration.win_rm.listeners: xml += '<Listener>' xml += _XmlSerializer.data_to_xml( [('Protocol', listener.protocol), ('CertificateThumbprint', listener.certificate_thumbprint)]) xml += '</Listener>' xml += '</Listeners></WinRM>' xml += _XmlSerializer.data_to_xml( [('AdminUsername', configuration.admin_username)]) return xml @staticmethod def linux_configuration_to_xml(configuration): xml = _XmlSerializer.data_to_xml( [('ConfigurationSetType', configuration.configuration_set_type), ('HostName', configuration.host_name), ('UserName', configuration.user_name), ('UserPassword', configuration.user_password), ('DisableSshPasswordAuthentication', configuration.disable_ssh_password_authentication, _lower)]) if configuration.ssh is not None: xml += '<SSH>' xml += '<PublicKeys>' for key in configuration.ssh.public_keys: xml += '<PublicKey>' xml += _XmlSerializer.data_to_xml( [('Fingerprint', key.fingerprint), ('Path', key.path)]) xml += '</PublicKey>' xml += '</PublicKeys>' xml += '<KeyPairs>' for key in configuration.ssh.key_pairs: xml += '<KeyPair>' xml += _XmlSerializer.data_to_xml( [('Fingerprint', key.fingerprint), ('Path', key.path)]) xml += '</KeyPair>' xml += '</KeyPairs>' xml += '</SSH>' return xml @staticmethod def network_configuration_to_xml(configuration): xml = _XmlSerializer.data_to_xml( [('ConfigurationSetType', configuration.configuration_set_type)]) xml += '<InputEndpoints>' for endpoint in configuration.input_endpoints: xml += '<InputEndpoint>' xml += _XmlSerializer.data_to_xml( [('LoadBalancedEndpointSetName', endpoint.load_balanced_endpoint_set_name), ('LocalPort', endpoint.local_port), ('Name', endpoint.name), ('Port', endpoint.port)]) if endpoint.load_balancer_probe.path or\ endpoint.load_balancer_probe.port or\ endpoint.load_balancer_probe.protocol: xml += '<LoadBalancerProbe>' xml += _XmlSerializer.data_to_xml( [('Path', endpoint.load_balancer_probe.path), ('Port', 
endpoint.load_balancer_probe.port), ('Protocol', endpoint.load_balancer_probe.protocol)]) xml += '</LoadBalancerProbe>' xml += _XmlSerializer.data_to_xml( [('Protocol', endpoint.protocol), ('EnableDirectServerReturn', endpoint.enable_direct_server_return, _lower)]) xml += '</InputEndpoint>' xml += '</InputEndpoints>' xml += '<SubnetNames>' for name in configuration.subnet_names: xml += _XmlSerializer.data_to_xml([('SubnetName', name)]) xml += '</SubnetNames>' return xml @staticmethod def role_to_xml(availability_set_name, data_virtual_hard_disks, network_configuration_set, os_virtual_hard_disk, role_name, role_size, role_type, system_configuration_set): xml = _XmlSerializer.data_to_xml([('RoleName', role_name), ('RoleType', role_type)]) xml += '<ConfigurationSets>' if system_configuration_set is not None: xml += '<ConfigurationSet>' if isinstance(system_configuration_set, WindowsConfigurationSet): xml += _XmlSerializer.windows_configuration_to_xml( system_configuration_set) elif isinstance(system_configuration_set, LinuxConfigurationSet): xml += _XmlSerializer.linux_configuration_to_xml( system_configuration_set) xml += '</ConfigurationSet>' if network_configuration_set is not None: xml += '<ConfigurationSet>' xml += _XmlSerializer.network_configuration_to_xml( network_configuration_set) xml += '</ConfigurationSet>' xml += '</ConfigurationSets>' if availability_set_name is not None: xml += _XmlSerializer.data_to_xml( [('AvailabilitySetName', availability_set_name)]) if data_virtual_hard_disks is not None: xml += '<DataVirtualHardDisks>' for hd in data_virtual_hard_disks: xml += '<DataVirtualHardDisk>' xml += _XmlSerializer.data_to_xml( [('HostCaching', hd.host_caching), ('DiskLabel', hd.disk_label), ('DiskName', hd.disk_name), ('Lun', hd.lun), ('LogicalDiskSizeInGB', hd.logical_disk_size_in_gb), ('MediaLink', hd.media_link)]) xml += '</DataVirtualHardDisk>' xml += '</DataVirtualHardDisks>' if os_virtual_hard_disk is not None: xml += '<OSVirtualHardDisk>' xml += 
_XmlSerializer.data_to_xml( [('HostCaching', os_virtual_hard_disk.host_caching), ('DiskLabel', os_virtual_hard_disk.disk_label), ('DiskName', os_virtual_hard_disk.disk_name), ('MediaLink', os_virtual_hard_disk.media_link), ('SourceImageName', os_virtual_hard_disk.source_image_name)]) xml += '</OSVirtualHardDisk>' if role_size is not None: xml += _XmlSerializer.data_to_xml([('RoleSize', role_size)]) return xml @staticmethod def add_role_to_xml(role_name, system_configuration_set, os_virtual_hard_disk, role_type, network_configuration_set, availability_set_name, data_virtual_hard_disks, role_size): xml = _XmlSerializer.role_to_xml( availability_set_name, data_virtual_hard_disks, network_configuration_set, os_virtual_hard_disk, role_name, role_size, role_type, system_configuration_set) return _XmlSerializer.doc_from_xml('PersistentVMRole', xml) @staticmethod def update_role_to_xml(role_name, os_virtual_hard_disk, role_type, network_configuration_set, availability_set_name, data_virtual_hard_disks, role_size): xml = _XmlSerializer.role_to_xml( availability_set_name, data_virtual_hard_disks, network_configuration_set, os_virtual_hard_disk, role_name, role_size, role_type, None) return _XmlSerializer.doc_from_xml('PersistentVMRole', xml) @staticmethod def capture_role_to_xml(post_capture_action, target_image_name, target_image_label, provisioning_configuration): xml = _XmlSerializer.data_to_xml( [('OperationType', 'CaptureRoleOperation'), ('PostCaptureAction', post_capture_action)]) if provisioning_configuration is not None: xml += '<ProvisioningConfiguration>' if isinstance(provisioning_configuration, WindowsConfigurationSet): xml += _XmlSerializer.windows_configuration_to_xml( provisioning_configuration) elif isinstance(provisioning_configuration, LinuxConfigurationSet): xml += _XmlSerializer.linux_configuration_to_xml( provisioning_configuration) xml += '</ProvisioningConfiguration>' xml += _XmlSerializer.data_to_xml( [('TargetImageLabel', target_image_label), 
('TargetImageName', target_image_name)]) return _XmlSerializer.doc_from_xml('CaptureRoleOperation', xml) @staticmethod def virtual_machine_deployment_to_xml(deployment_name, deployment_slot, label, role_name, system_configuration_set, os_virtual_hard_disk, role_type, network_configuration_set, availability_set_name, data_virtual_hard_disks, role_size, virtual_network_name): xml = _XmlSerializer.data_to_xml([('Name', deployment_name), ('DeploymentSlot', deployment_slot), ('Label', label)]) xml += '<RoleList>' xml += '<Role>' xml += _XmlSerializer.role_to_xml( availability_set_name, data_virtual_hard_disks, network_configuration_set, os_virtual_hard_disk, role_name, role_size, role_type, system_configuration_set) xml += '</Role>' xml += '</RoleList>' if virtual_network_name is not None: xml += _XmlSerializer.data_to_xml( [('VirtualNetworkName', virtual_network_name)]) return _XmlSerializer.doc_from_xml('Deployment', xml) @staticmethod def create_website_to_xml(webspace_name, website_name, geo_region, plan, host_names, compute_mode, server_farm, site_mode): xml = '<HostNames xmlns:a="http://schemas.microsoft.com/2003/10/Serialization/Arrays">' for host_name in host_names: xml += '<a:string>{0}</a:string>'.format(host_name) xml += '</HostNames>' xml += _XmlSerializer.data_to_xml( [('Name', website_name), ('ComputeMode', compute_mode), ('ServerFarm', server_farm), ('SiteMode', site_mode)]) xml += '<WebSpaceToCreate>' xml += _XmlSerializer.data_to_xml( [('GeoRegion', geo_region), ('Name', webspace_name), ('Plan', plan)]) xml += '</WebSpaceToCreate>' return _XmlSerializer.doc_from_xml('Site', xml) @staticmethod def data_to_xml(data): '''Creates an xml fragment from the specified data. 
data: Array of tuples, where first: xml element name second: xml element text third: conversion function ''' xml = '' for element in data: name = element[0] val = element[1] if len(element) > 2: converter = element[2] else: converter = None if val is not None: if converter is not None: text = _str(converter(_str(val))) else: text = _str(val) xml += ''.join(['<', name, '>', text, '</', name, '>']) return xml @staticmethod def doc_from_xml(document_element_name, inner_xml): '''Wraps the specified xml in an xml root element with default azure namespaces''' xml = ''.join(['<', document_element_name, ' xmlns:i="http://www.w3.org/2001/XMLSchema-instance"', ' xmlns="http://schemas.microsoft.com/windowsazure">']) xml += inner_xml xml += ''.join(['</', document_element_name, '>']) return xml @staticmethod def doc_from_data(document_element_name, data, extended_properties=None): xml = _XmlSerializer.data_to_xml(data) if extended_properties is not None: xml += _XmlSerializer.extended_properties_dict_to_xml_fragment( extended_properties) return _XmlSerializer.doc_from_xml(document_element_name, xml) @staticmethod def extended_properties_dict_to_xml_fragment(extended_properties): xml = '' if extended_properties is not None and len(extended_properties) > 0: xml += '<ExtendedProperties>' for key, val in extended_properties.items(): xml += ''.join(['<ExtendedProperty>', '<Name>', _str(key), '</Name>', '<Value>', _str(val), '</Value>', '</ExtendedProperty>']) xml += '</ExtendedProperties>' return xml def _parse_bool(value): if value.lower() == 'true': return True return False class _ServiceBusManagementXmlSerializer(object): @staticmethod def namespace_to_xml(region): '''Converts a service bus namespace description to xml The xml format: <?xml version="1.0" encoding="utf-8" standalone="yes"?> <entry xmlns="http://www.w3.org/2005/Atom"> <content type="application/xml"> <NamespaceDescription xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect"> <Region>West 
US</Region> </NamespaceDescription> </content> </entry> ''' body = '<NamespaceDescription xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">' body += ''.join(['<Region>', region, '</Region>']) body += '</NamespaceDescription>' return _create_entry(body) @staticmethod def xml_to_namespace(xmlstr): '''Converts xml response to service bus namespace The xml format for namespace: <entry> <id>uuid:00000000-0000-0000-0000-000000000000;id=0000000</id> <title type="text">myunittests 2012-08-22T16:48:10Z myunittests West US 0000000000000000000000000000000000000000000= Active 2012-08-22T16:48:10.217Z https://myunittests-sb.accesscontrol.windows.net/ https://myunittests.servicebus.windows.net/ Endpoint=sb://myunittests.servicebus.windows.net/;SharedSecretIssuer=owner;SharedSecretValue=0000000000000000000000000000000000000000000= 00000000000000000000000000000000 true ''' xmldoc = minidom.parseString(xmlstr) namespace = ServiceBusNamespace() mappings = ( ('Name', 'name', None), ('Region', 'region', None), ('DefaultKey', 'default_key', None), ('Status', 'status', None), ('CreatedAt', 'created_at', None), ('AcsManagementEndpoint', 'acs_management_endpoint', None), ('ServiceBusEndpoint', 'servicebus_endpoint', None), ('ConnectionString', 'connection_string', None), ('SubscriptionId', 'subscription_id', None), ('Enabled', 'enabled', _parse_bool), ) for desc in _get_children_from_path(xmldoc, 'entry', 'content', 'NamespaceDescription'): for xml_name, field_name, conversion_func in mappings: node_value = _get_first_child_node_value(desc, xml_name) if node_value is not None: if conversion_func is not None: node_value = conversion_func(node_value) setattr(namespace, field_name, node_value) return namespace @staticmethod def xml_to_region(xmlstr): '''Converts xml response to service bus region The xml format for region: uuid:157c311f-081f-4b4a-a0ba-a8f990ffd2a3;id=1756759 2013-04-10T18:25:29Z East Asia East Asia ''' xmldoc = minidom.parseString(xmlstr) region = 
ServiceBusRegion() for desc in _get_children_from_path(xmldoc, 'entry', 'content', 'RegionCodeDescription'): node_value = _get_first_child_node_value(desc, 'Code') if node_value is not None: region.code = node_value node_value = _get_first_child_node_value(desc, 'FullName') if node_value is not None: region.fullname = node_value return region @staticmethod def xml_to_namespace_availability(xmlstr): '''Converts xml response to service bus namespace availability The xml format: uuid:9fc7c652-1856-47ab-8d74-cd31502ea8e6;id=3683292 2013-04-16T03:03:37Z false ''' xmldoc = minidom.parseString(xmlstr) availability = AvailabilityResponse() for desc in _get_children_from_path(xmldoc, 'entry', 'content', 'NamespaceAvailability'): node_value = _get_first_child_node_value(desc, 'Result') if node_value is not None: availability.result = _parse_bool(node_value) return availability from azure.servicemanagement.servicemanagementservice import ( ServiceManagementService) from azure.servicemanagement.servicebusmanagementservice import ( ServiceBusManagementService) from azure.servicemanagement.websitemanagementservice import ( WebsiteManagementService) ================================================ FILE: DSC/azure/servicemanagement/servicebusmanagementservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class ServiceBusManagementService(_ServiceManagementClient):

    '''Management-plane client for service bus namespaces and the queues,
    topics, notification hubs and relays they contain.'''

    def __init__(self, subscription_id=None, cert_file=None,
                 host=MANAGEMENT_HOST):
        super(ServiceBusManagementService, self).__init__(
            subscription_id, cert_file, host)

    #--Operations for service bus ----------------------------------------
    def get_regions(self):
        '''
        Get list of available service bus regions.
        '''
        response = self._perform_get(
            self._get_path('services/serviceBus/Regions/', None),
            None)

        return _convert_response_to_feeds(
            response,
            _ServiceBusManagementXmlSerializer.xml_to_region)

    def list_namespaces(self):
        '''
        List the service bus namespaces defined on the account.
        '''
        response = self._perform_get(
            self._get_path('services/serviceBus/Namespaces/', None),
            None)

        return _convert_response_to_feeds(
            response,
            _ServiceBusManagementXmlSerializer.xml_to_namespace)

    def get_namespace(self, name):
        '''
        Get details about a specific namespace.

        name: Name of the service bus namespace.
        '''
        response = self._perform_get(
            self._get_path('services/serviceBus/Namespaces', name),
            None)

        return _ServiceBusManagementXmlSerializer.xml_to_namespace(
            response.body)

    def create_namespace(self, name, region):
        '''
        Create a new service bus namespace.

        name: Name of the service bus namespace to create.
        region: Region to create the namespace in.
        '''
        _validate_not_none('name', name)

        return self._perform_put(
            self._get_path('services/serviceBus/Namespaces', name),
            _ServiceBusManagementXmlSerializer.namespace_to_xml(region))

    def delete_namespace(self, name):
        '''
        Delete a service bus namespace.

        name: Name of the service bus namespace to delete.
        '''
        _validate_not_none('name', name)

        return self._perform_delete(
            self._get_path('services/serviceBus/Namespaces', name),
            None)

    def check_namespace_availability(self, name):
        '''
        Checks to see if the specified service bus namespace is available, or
        if it has already been taken.

        name: Name of the service bus namespace to validate.
        '''
        _validate_not_none('name', name)

        response = self._perform_get(
            self._get_path('services/serviceBus/CheckNamespaceAvailability',
                           None) + '/?namespace=' + _str(name), None)

        return _ServiceBusManagementXmlSerializer.xml_to_namespace_availability(
            response.body)

    def list_queues(self, name):
        '''
        Enumerates the queues in the service namespace.

        name: Name of the service bus namespace.
        '''
        _validate_not_none('name', name)

        response = self._perform_get(
            self._get_list_queues_path(name),
            None)

        return _convert_response_to_feeds(response, QueueDescription)

    def list_topics(self, name):
        '''
        Retrieves the topics in the service namespace.

        name: Name of the service bus namespace.
        '''
        # FIX: validate the namespace name for consistency with list_queues;
        # previously a None name produced a malformed request path.
        _validate_not_none('name', name)

        response = self._perform_get(
            self._get_list_topics_path(name),
            None)

        return _convert_response_to_feeds(response, TopicDescription)

    def list_notification_hubs(self, name):
        '''
        Retrieves the notification hubs in the service namespace.

        name: Name of the service bus namespace.
        '''
        # FIX: validate the namespace name for consistency with list_queues.
        _validate_not_none('name', name)

        response = self._perform_get(
            self._get_list_notification_hubs_path(name),
            None)

        return _convert_response_to_feeds(response, NotificationHubDescription)

    def list_relays(self, name):
        '''
        Retrieves the relays in the service namespace.

        name: Name of the service bus namespace.
        '''
        # FIX: validate the namespace name for consistency with list_queues.
        _validate_not_none('name', name)

        response = self._perform_get(
            self._get_list_relays_path(name),
            None)

        return _convert_response_to_feeds(response, RelayDescription)

    #--Helper functions --------------------------------------------------
    def _get_list_queues_path(self, namespace_name):
        # Build '/<sub-id>/services/serviceBus/Namespaces/<ns>/Queues'
        return self._get_path('services/serviceBus/Namespaces/',
                              namespace_name) + '/Queues'

    def _get_list_topics_path(self, namespace_name):
        return self._get_path('services/serviceBus/Namespaces/',
                              namespace_name) + '/Topics'

    def _get_list_notification_hubs_path(self, namespace_name):
        return self._get_path('services/serviceBus/Namespaces/',
                              namespace_name) + '/NotificationHubs'

    def _get_list_relays_path(self, namespace_name):
        return self._get_path('services/serviceBus/Namespaces/',
                              namespace_name) + '/Relays'
class _ServiceManagementClient(object):
    # Base client shared by the service-management services. Owns the HTTP
    # client, credential resolution, request construction, and error-to-
    # exception translation.
    #
    # NOTE(review): the _perform_put/_perform_post/_perform_delete methods
    # use a parameter literally named 'async', which is a reserved word in
    # Python 3 — this code is Python 2 only as written.

    def __init__(self, subscription_id=None, cert_file=None,
                 host=MANAGEMENT_HOST):
        self.requestid = None
        self.subscription_id = subscription_id
        self.cert_file = cert_file
        self.host = host

        # Fall back to environment variables when credentials are not
        # passed explicitly.
        if not self.cert_file:
            if AZURE_MANAGEMENT_CERTFILE in os.environ:
                self.cert_file = os.environ[AZURE_MANAGEMENT_CERTFILE]

        if not self.subscription_id:
            if AZURE_MANAGEMENT_SUBSCRIPTIONID in os.environ:
                self.subscription_id = os.environ[
                    AZURE_MANAGEMENT_SUBSCRIPTIONID]

        # Both pieces are mandatory for authenticated management calls.
        if not self.cert_file or not self.subscription_id:
            raise WindowsAzureError(
                'You need to provide subscription id and certificate file')

        self._httpclient = _HTTPClient(
            service_instance=self, cert_file=self.cert_file)
        # _filter is the request pipeline entry point; with_filter() wraps it.
        self._filter = self._httpclient.perform_request

    def with_filter(self, filter):
        '''Returns a new service which will process requests with the
        specified filter.  Filtering operations can include logging, automatic
        retrying, etc...  The filter is a lambda which receives the HTTPRequest
        and another lambda.  The filter can perform any pre-processing on the
        request, pass it off to the next lambda, and then perform any
        post-processing on the response.'''
        # Create a sibling client and chain the new filter in front of the
        # current pipeline; the original client is left untouched.
        res = type(self)(self.subscription_id, self.cert_file, self.host)
        old_filter = self._filter

        def new_filter(request):
            return filter(request, old_filter)

        res._filter = new_filter
        return res

    def set_proxy(self, host, port, user=None, password=None):
        '''
        Sets the proxy server host and port for the HTTP CONNECT Tunnelling.

        host: Address of the proxy. Ex: '192.168.0.100'
        port: Port of the proxy. Ex: 6000
        user: User for proxy authorization.
        password: Password for proxy authorization.
        '''
        self._httpclient.set_proxy(host, port, user, password)

    #--Helper functions --------------------------------------------------
    def _perform_request(self, request):
        # Run the request through the filter pipeline, translating HTTP
        # errors into management-service exceptions.
        try:
            resp = self._filter(request)
        except HTTPError as ex:
            return _management_error_handler(ex)

        return resp

    def _perform_get(self, path, response_type):
        # Issue a GET; when response_type is given, parse the body into
        # that model type, otherwise return the raw response.
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self.host
        request.path = path
        request.path, request.query = _update_request_uri_query(request)
        request.headers = _update_management_header(request)
        response = self._perform_request(request)

        if response_type is not None:
            return _parse_response(response, response_type)

        return response

    def _perform_put(self, path, body, async=False):
        # Issue a PUT; when async is truthy, return the async-operation
        # descriptor parsed from the response, otherwise None.
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self.host
        request.path = path
        request.body = _get_request_body(body)
        request.path, request.query = _update_request_uri_query(request)
        request.headers = _update_management_header(request)
        response = self._perform_request(request)

        if async:
            return _parse_response_for_async_op(response)

        return None

    def _perform_post(self, path, body, response_type=None, async=False):
        # Issue a POST; response_type parsing takes precedence over the
        # async-operation result when both are requested.
        request = HTTPRequest()
        request.method = 'POST'
        request.host = self.host
        request.path = path
        request.body = _get_request_body(body)
        request.path, request.query = _update_request_uri_query(request)
        request.headers = _update_management_header(request)
        response = self._perform_request(request)

        if response_type is not None:
            return _parse_response(response, response_type)

        if async:
            return _parse_response_for_async_op(response)

        return None

    def _perform_delete(self, path, async=False):
        # Issue a DELETE; optionally return the async-operation descriptor.
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self.host
        request.path = path
        request.path, request.query = _update_request_uri_query(request)
        request.headers = _update_management_header(request)
        response = self._perform_request(request)

        if async:
            return _parse_response_for_async_op(response)

        return None

    def _get_path(self, resource, name):
        # Build '/<subscription-id>/<resource>[/<name>]'.
        path = '/' + self.subscription_id + '/' + resource
        if name is not None:
            path += '/' + _str(name)
        return path
#-------------------------------------------------------------------------- from azure import ( WindowsAzureError, MANAGEMENT_HOST, _str, _validate_not_none, ) from azure.servicemanagement import ( AffinityGroups, AffinityGroup, AvailabilityResponse, Certificate, Certificates, DataVirtualHardDisk, Deployment, Disk, Disks, Locations, Operation, HostedService, HostedServices, Images, OperatingSystems, OperatingSystemFamilies, OSImage, PersistentVMRole, StorageService, StorageServices, Subscription, SubscriptionCertificate, SubscriptionCertificates, VirtualNetworkSites, _XmlSerializer, ) from azure.servicemanagement.servicemanagementclient import ( _ServiceManagementClient, ) class ServiceManagementService(_ServiceManagementClient): def __init__(self, subscription_id=None, cert_file=None, host=MANAGEMENT_HOST): super(ServiceManagementService, self).__init__( subscription_id, cert_file, host) #--Operations for storage accounts ----------------------------------- def list_storage_accounts(self): ''' Lists the storage accounts available under the current subscription. ''' return self._perform_get(self._get_storage_service_path(), StorageServices) def get_storage_account_properties(self, service_name): ''' Returns system properties for the specified storage account. service_name: Name of the storage service account. ''' _validate_not_none('service_name', service_name) return self._perform_get(self._get_storage_service_path(service_name), StorageService) def get_storage_account_keys(self, service_name): ''' Returns the primary and secondary access keys for the specified storage account. service_name: Name of the storage service account. ''' _validate_not_none('service_name', service_name) return self._perform_get( self._get_storage_service_path(service_name) + '/keys', StorageService) def regenerate_storage_account_keys(self, service_name, key_type): ''' Regenerates the primary or secondary access key for the specified storage account. 
service_name: Name of the storage service account. key_type: Specifies which key to regenerate. Valid values are: Primary, Secondary ''' _validate_not_none('service_name', service_name) _validate_not_none('key_type', key_type) return self._perform_post( self._get_storage_service_path( service_name) + '/keys?action=regenerate', _XmlSerializer.regenerate_keys_to_xml( key_type), StorageService) def create_storage_account(self, service_name, description, label, affinity_group=None, location=None, geo_replication_enabled=True, extended_properties=None): ''' Creates a new storage account in Windows Azure. service_name: A name for the storage account that is unique within Windows Azure. Storage account names must be between 3 and 24 characters in length and use numbers and lower-case letters only. description: A description for the storage account. The description may be up to 1024 characters in length. label: A name for the storage account. The name may be up to 100 characters in length. The name can be used to identify the storage account for your tracking purposes. affinity_group: The name of an existing affinity group in the specified subscription. You can specify either a location or affinity_group, but not both. location: The location where the storage account is created. You can specify either a location or affinity_group, but not both. geo_replication_enabled: Specifies whether the storage account is created with the geo-replication enabled. If the element is not included in the request body, the default value is true. If set to true, the data in the storage account is replicated across more than one geographic location so as to enable resilience in the face of catastrophic service loss. extended_properties: Dictionary containing name/value pairs of storage account properties. You can have a maximum of 50 extended property name/value pairs. 
The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. ''' _validate_not_none('service_name', service_name) _validate_not_none('description', description) _validate_not_none('label', label) if affinity_group is None and location is None: raise WindowsAzureError( 'location or affinity_group must be specified') if affinity_group is not None and location is not None: raise WindowsAzureError( 'Only one of location or affinity_group needs to be specified') return self._perform_post( self._get_storage_service_path(), _XmlSerializer.create_storage_service_input_to_xml( service_name, description, label, affinity_group, location, geo_replication_enabled, extended_properties), async=True) def update_storage_account(self, service_name, description=None, label=None, geo_replication_enabled=None, extended_properties=None): ''' Updates the label, the description, and enables or disables the geo-replication status for a storage account in Windows Azure. service_name: Name of the storage service account. description: A description for the storage account. The description may be up to 1024 characters in length. label: A name for the storage account. The name may be up to 100 characters in length. The name can be used to identify the storage account for your tracking purposes. geo_replication_enabled: Specifies whether the storage account is created with the geo-replication enabled. If the element is not included in the request body, the default value is true. If set to true, the data in the storage account is replicated across more than one geographic location so as to enable resilience in the face of catastrophic service loss. extended_properties: Dictionary containing name/value pairs of storage account properties. You can have a maximum of 50 extended property name/value pairs. 
The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. ''' _validate_not_none('service_name', service_name) return self._perform_put( self._get_storage_service_path(service_name), _XmlSerializer.update_storage_service_input_to_xml( description, label, geo_replication_enabled, extended_properties)) def delete_storage_account(self, service_name): ''' Deletes the specified storage account from Windows Azure. service_name: Name of the storage service account. ''' _validate_not_none('service_name', service_name) return self._perform_delete( self._get_storage_service_path(service_name)) def check_storage_account_name_availability(self, service_name): ''' Checks to see if the specified storage account name is available, or if it has already been taken. service_name: Name of the storage service account. ''' _validate_not_none('service_name', service_name) return self._perform_get( self._get_storage_service_path() + '/operations/isavailable/' + _str(service_name) + '', AvailabilityResponse) #--Operations for hosted services ------------------------------------ def list_hosted_services(self): ''' Lists the hosted services available under the current subscription. ''' return self._perform_get(self._get_hosted_service_path(), HostedServices) def get_hosted_service_properties(self, service_name, embed_detail=False): ''' Retrieves system properties for the specified hosted service. These properties include the service name and service type; the name of the affinity group to which the service belongs, or its location if it is not part of an affinity group; and optionally, information on the service's deployments. service_name: Name of the hosted service. embed_detail: When True, the management service returns properties for all deployments of the service, as well as for the service itself. 
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('embed_detail', embed_detail)
        # embed-detail is a boolean query parameter; the API expects 'true'/'false'.
        return self._perform_get(
            self._get_hosted_service_path(service_name) +
            '?embed-detail=' +
            _str(embed_detail).lower(),
            HostedService)

    def create_hosted_service(self, service_name, label, description=None,
                              location=None, affinity_group=None,
                              extended_properties=None):
        '''
        Creates a new hosted service in Windows Azure.

        service_name: A name for the hosted service that is unique within
            Windows Azure. This name is the DNS prefix name and can be used to
            access the hosted service.
        label: A name for the hosted service. The name can be up to 100
            characters in length. The name can be used to identify the storage
            account for your tracking purposes.
        description: A description for the hosted service. The description can
            be up to 1024 characters in length.
        location: The location where the hosted service will be created. You
            can specify either a location or affinity_group, but not both.
        affinity_group: The name of an existing affinity group associated with
            this subscription. This name is a GUID and can be retrieved by
            examining the name element of the response body returned by
            list_affinity_groups. You can specify either a location or
            affinity_group, but not both.
        extended_properties: Dictionary containing name/value pairs of storage
            account properties. You can have a maximum of 50 extended property
            name/value pairs. The maximum length of the Name element is 64
            characters, only alphanumeric characters and underscores are valid
            in the Name, and the name must start with a letter. The value has
            a maximum length of 255 characters.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('label', label)
        # Exactly one of location / affinity_group must be supplied.
        if affinity_group is None and location is None:
            raise WindowsAzureError(
                'location or affinity_group must be specified')
        if affinity_group is not None and location is not None:
            raise WindowsAzureError(
                'Only one of location or affinity_group needs to be specified')
        return self._perform_post(self._get_hosted_service_path(),
                                  _XmlSerializer.create_hosted_service_to_xml(
                                      service_name,
                                      label,
                                      description,
                                      location,
                                      affinity_group,
                                      extended_properties))

    def update_hosted_service(self, service_name, label=None, description=None,
                              extended_properties=None):
        '''
        Updates the label and/or the description for a hosted service in
        Windows Azure.

        service_name: Name of the hosted service.
        label: A name for the hosted service. The name may be up to 100
            characters in length. You must specify a value for either Label or
            Description, or for both. It is recommended that the label be
            unique within the subscription. The name can be used identify the
            hosted service for your tracking purposes.
        description: A description for the hosted service. The description may
            be up to 1024 characters in length. You must specify a value for
            either Label or Description, or for both.
        extended_properties: Dictionary containing name/value pairs of storage
            account properties. You can have a maximum of 50 extended property
            name/value pairs. The maximum length of the Name element is 64
            characters, only alphanumeric characters and underscores are valid
            in the Name, and the name must start with a letter. The value has
            a maximum length of 255 characters.
        '''
        # NOTE(review): the label-or-description requirement stated above is
        # enforced by the service, not validated locally.
        _validate_not_none('service_name', service_name)
        return self._perform_put(self._get_hosted_service_path(service_name),
                                 _XmlSerializer.update_hosted_service_to_xml(
                                     label,
                                     description,
                                     extended_properties))

    def delete_hosted_service(self, service_name):
        '''
        Deletes the specified hosted service from Windows Azure.

        service_name: Name of the hosted service.
        '''
        _validate_not_none('service_name', service_name)
        return self._perform_delete(self._get_hosted_service_path(service_name))

    def get_deployment_by_slot(self, service_name, deployment_slot):
        '''
        Returns configuration information, status, and system properties for
        a deployment.

        service_name: Name of the hosted service.
        deployment_slot: The environment to which the hosted service is
            deployed. Valid values are: staging, production
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_slot', deployment_slot)
        return self._perform_get(
            self._get_deployment_path_using_slot(
                service_name, deployment_slot),
            Deployment)

    def get_deployment_by_name(self, service_name, deployment_name):
        '''
        Returns configuration information, status, and system properties for a
        deployment.

        service_name: Name of the hosted service.
        deployment_name: The name of the deployment.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        return self._perform_get(
            self._get_deployment_path_using_name(
                service_name, deployment_name),
            Deployment)

    def create_deployment(self, service_name, deployment_slot, name,
                          package_url, label, configuration,
                          start_deployment=False,
                          treat_warnings_as_error=False,
                          extended_properties=None):
        '''
        Uploads a new service package and creates a new deployment on staging
        or production.

        service_name: Name of the hosted service.
        deployment_slot: The environment to which the hosted service is
            deployed. Valid values are: staging, production
        name: The name for the deployment. The deployment name must be unique
            among other deployments for the hosted service.
        package_url: A URL that refers to the location of the service package
            in the Blob service. The service package can be located either in
            a storage account beneath the same subscription or a Shared Access
            Signature (SAS) URI from any storage account.
        label: A name for the hosted service. The name can be up to 100
            characters in length.
            It is recommended that the label be unique within the
            subscription. The name can be used to identify the hosted service
            for your tracking purposes.
        configuration: The base-64 encoded service configuration file for the
            deployment.
        start_deployment: Indicates whether to start the deployment
            immediately after it is created. If false, the service model is
            still deployed to the virtual machines but the code is not run
            immediately. Instead, the service is Suspended until you call
            Update Deployment Status and set the status to Running, at which
            time the service will be started. A deployed service still incurs
            charges, even if it is suspended.
        treat_warnings_as_error: Indicates whether to treat package validation
            warnings as errors. If set to true, the Created Deployment
            operation fails if there are validation warnings on the service
            package.
        extended_properties: Dictionary containing name/value pairs of storage
            account properties. You can have a maximum of 50 extended property
            name/value pairs. The maximum length of the Name element is 64
            characters, only alphanumeric characters and underscores are valid
            in the Name, and the name must start with a letter. The value has
            a maximum length of 255 characters.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_slot', deployment_slot)
        _validate_not_none('name', name)
        _validate_not_none('package_url', package_url)
        _validate_not_none('label', label)
        _validate_not_none('configuration', configuration)
        # async=True: the operation returns a request id to poll via
        # get_operation_status. ('async' is a Py2-era keyword argument; it is
        # a reserved word on Python 3.7+.)
        return self._perform_post(
            self._get_deployment_path_using_slot(
                service_name, deployment_slot),
            _XmlSerializer.create_deployment_to_xml(
                name,
                package_url,
                label,
                configuration,
                start_deployment,
                treat_warnings_as_error,
                extended_properties), async=True)

    def delete_deployment(self, service_name, deployment_name):
        '''
        Deletes the specified deployment.

        service_name: Name of the hosted service.
        deployment_name: The name of the deployment.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        return self._perform_delete(
            self._get_deployment_path_using_name(
                service_name, deployment_name), async=True)

    def swap_deployment(self, service_name, production, source_deployment):
        '''
        Initiates a virtual IP swap between the staging and production
        deployment environments for a service. If the service is currently
        running in the staging environment, it will be swapped to the
        production environment. If it is running in the production
        environment, it will be swapped to staging.

        service_name: Name of the hosted service.
        production: The name of the production deployment.
        source_deployment: The name of the source deployment.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('production', production)
        _validate_not_none('source_deployment', source_deployment)
        return self._perform_post(self._get_hosted_service_path(service_name),
                                  _XmlSerializer.swap_deployment_to_xml(
                                      production, source_deployment),
                                  async=True)

    def change_deployment_configuration(self, service_name, deployment_name,
                                        configuration,
                                        treat_warnings_as_error=False,
                                        mode='Auto', extended_properties=None):
        '''
        Initiates a change to the deployment configuration.

        service_name: Name of the hosted service.
        deployment_name: The name of the deployment.
        configuration: The base-64 encoded service configuration file for the
            deployment.
        treat_warnings_as_error: Indicates whether to treat package validation
            warnings as errors. If set to true, the Created Deployment
            operation fails if there are validation warnings on the service
            package.
        mode: If set to Manual, WalkUpgradeDomain must be called to apply the
            update. If set to Auto, the Windows Azure platform will
            automatically apply the update To each upgrade domain for the
            service. Possible values are: Auto, Manual
        extended_properties: Dictionary containing name/value pairs of storage
            account properties.
            You can have a maximum of 50 extended property name/value pairs.
            The maximum length of the Name element is 64 characters, only
            alphanumeric characters and underscores are valid in the Name, and
            the name must start with a letter. The value has a maximum length
            of 255 characters.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('configuration', configuration)
        return self._perform_post(
            self._get_deployment_path_using_name(
                service_name, deployment_name) + '/?comp=config',
            _XmlSerializer.change_deployment_to_xml(
                configuration,
                treat_warnings_as_error,
                mode,
                extended_properties), async=True)

    def update_deployment_status(self, service_name, deployment_name, status):
        '''
        Initiates a change in deployment status.

        service_name: Name of the hosted service.
        deployment_name: The name of the deployment.
        status: The change to initiate to the deployment status. Possible
            values include: Running, Suspended
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('status', status)
        return self._perform_post(
            self._get_deployment_path_using_name(
                service_name, deployment_name) + '/?comp=status',
            _XmlSerializer.update_deployment_status_to_xml(
                status), async=True)

    def upgrade_deployment(self, service_name, deployment_name, mode,
                           package_url, configuration, label, force,
                           role_to_upgrade=None, extended_properties=None):
        '''
        Initiates an upgrade.

        service_name: Name of the hosted service.
        deployment_name: The name of the deployment.
        mode: If set to Manual, WalkUpgradeDomain must be called to apply the
            update. If set to Auto, the Windows Azure platform will
            automatically apply the update To each upgrade domain for the
            service. Possible values are: Auto, Manual
        package_url: A URL that refers to the location of the service package
            in the Blob service.
            The service package can be located either in a storage account
            beneath the same subscription or a Shared Access Signature (SAS)
            URI from any storage account.
        configuration: The base-64 encoded service configuration file for the
            deployment.
        label: A name for the hosted service. The name can be up to 100
            characters in length. It is recommended that the label be unique
            within the subscription. The name can be used to identify the
            hosted service for your tracking purposes.
        force: Specifies whether the rollback should proceed even when it will
            cause local data to be lost from some role instances. True if the
            rollback should proceed; otherwise false if the rollback should
            fail.
        role_to_upgrade: The name of the specific role to upgrade.
        extended_properties: Dictionary containing name/value pairs of storage
            account properties. You can have a maximum of 50 extended property
            name/value pairs. The maximum length of the Name element is 64
            characters, only alphanumeric characters and underscores are valid
            in the Name, and the name must start with a letter. The value has
            a maximum length of 255 characters.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('mode', mode)
        _validate_not_none('package_url', package_url)
        _validate_not_none('configuration', configuration)
        _validate_not_none('label', label)
        _validate_not_none('force', force)
        # async=True: returns a request id to poll via get_operation_status.
        return self._perform_post(
            self._get_deployment_path_using_name(
                service_name, deployment_name) + '/?comp=upgrade',
            _XmlSerializer.upgrade_deployment_to_xml(
                mode,
                package_url,
                configuration,
                label,
                role_to_upgrade,
                force,
                extended_properties), async=True)

    def walk_upgrade_domain(self, service_name, deployment_name,
                            upgrade_domain):
        '''
        Specifies the next upgrade domain to be walked during manual in-place
        upgrade or configuration change.

        service_name: Name of the hosted service.
        deployment_name: The name of the deployment.
        upgrade_domain: An integer value that identifies the upgrade domain to
            walk. Upgrade domains are identified with a zero-based index: the
            first upgrade domain has an ID of 0, the second has an ID of 1,
            and so on.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('upgrade_domain', upgrade_domain)
        return self._perform_post(
            self._get_deployment_path_using_name(
                service_name, deployment_name) + '/?comp=walkupgradedomain',
            _XmlSerializer.walk_upgrade_domain_to_xml(
                upgrade_domain), async=True)

    def rollback_update_or_upgrade(self, service_name, deployment_name, mode,
                                   force):
        '''
        Cancels an in progress configuration change (update) or upgrade and
        returns the deployment to its state before the upgrade or
        configuration change was started.

        service_name: Name of the hosted service.
        deployment_name: The name of the deployment.
        mode: Specifies whether the rollback should proceed automatically.
            auto - The rollback proceeds without further user input.
            manual - You must call the Walk Upgrade Domain operation to apply
                the rollback to each upgrade domain.
        force: Specifies whether the rollback should proceed even when it will
            cause local data to be lost from some role instances. True if the
            rollback should proceed; otherwise false if the rollback should
            fail.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('mode', mode)
        _validate_not_none('force', force)
        return self._perform_post(
            self._get_deployment_path_using_name(
                service_name, deployment_name) + '/?comp=rollback',
            _XmlSerializer.rollback_upgrade_to_xml(
                mode, force), async=True)

    def reboot_role_instance(self, service_name, deployment_name,
                             role_instance_name):
        '''
        Requests a reboot of a role instance that is running in a deployment.

        service_name: Name of the hosted service.
        deployment_name: The name of the deployment.
        role_instance_name: The name of the role instance.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_instance_name', role_instance_name)
        # Empty request body; the operation is selected by the comp= query.
        return self._perform_post(
            self._get_deployment_path_using_name(
                service_name, deployment_name) + \
                    '/roleinstances/' + _str(role_instance_name) + \
                    '?comp=reboot', '', async=True)

    def reimage_role_instance(self, service_name, deployment_name,
                              role_instance_name):
        '''
        Requests a reimage of a role instance that is running in a deployment.

        service_name: Name of the hosted service.
        deployment_name: The name of the deployment.
        role_instance_name: The name of the role instance.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_instance_name', role_instance_name)
        return self._perform_post(
            self._get_deployment_path_using_name(
                service_name, deployment_name) + \
                    '/roleinstances/' + _str(role_instance_name) + \
                    '?comp=reimage', '', async=True)

    def check_hosted_service_name_availability(self, service_name):
        '''
        Checks to see if the specified hosted service name is available, or if
        it has already been taken.

        service_name: Name of the hosted service.
        '''
        _validate_not_none('service_name', service_name)
        return self._perform_get(
            '/' + self.subscription_id +
            '/services/hostedservices/operations/isavailable/' +
            _str(service_name) + '',
            AvailabilityResponse)

    #--Operations for service certificates -------------------------------
    def list_service_certificates(self, service_name):
        '''
        Lists all of the service certificates associated with the specified
        hosted service.

        service_name: Name of the hosted service.
        '''
        _validate_not_none('service_name', service_name)
        return self._perform_get(
            '/' + self.subscription_id + '/services/hostedservices/' +
            _str(service_name) + '/certificates',
            Certificates)

    def get_service_certificate(self, service_name, thumbalgorithm,
                                thumbprint):
        '''
        Returns the public data for the specified X.509 certificate associated
        with a hosted service.

        service_name: Name of the hosted service.
        thumbalgorithm: The algorithm for the certificate's thumbprint.
        thumbprint: The hexadecimal representation of the thumbprint.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('thumbalgorithm', thumbalgorithm)
        _validate_not_none('thumbprint', thumbprint)
        # Certificates are addressed as '<algorithm>-<thumbprint>'.
        return self._perform_get(
            '/' + self.subscription_id + '/services/hostedservices/' +
            _str(service_name) + '/certificates/' +
            _str(thumbalgorithm) + '-' + _str(thumbprint) + '',
            Certificate)

    def add_service_certificate(self, service_name, data, certificate_format,
                                password):
        '''
        Adds a certificate to a hosted service.

        service_name: Name of the hosted service.
        data: The base-64 encoded form of the pfx file.
        certificate_format: The service certificate format. The only supported
            value is pfx.
        password: The certificate password.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('data', data)
        _validate_not_none('certificate_format', certificate_format)
        _validate_not_none('password', password)
        return self._perform_post(
            '/' + self.subscription_id + '/services/hostedservices/' +
            _str(service_name) + '/certificates',
            _XmlSerializer.certificate_file_to_xml(
                data, certificate_format, password), async=True)

    def delete_service_certificate(self, service_name, thumbalgorithm,
                                   thumbprint):
        '''
        Deletes a service certificate from the certificate store of a hosted
        service.

        service_name: Name of the hosted service.
        thumbalgorithm: The algorithm for the certificate's thumbprint.
        thumbprint: The hexadecimal representation of the thumbprint.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('thumbalgorithm', thumbalgorithm)
        _validate_not_none('thumbprint', thumbprint)
        return self._perform_delete(
            '/' + self.subscription_id + '/services/hostedservices/' +
            _str(service_name) + '/certificates/' +
            _str(thumbalgorithm) + '-' + _str(thumbprint),
            async=True)

    #--Operations for management certificates ----------------------------
    def list_management_certificates(self):
        '''
        The List Management Certificates operation lists and returns basic
        information about all of the management certificates associated with
        the specified subscription. Management certificates, which are also
        known as subscription certificates, authenticate clients attempting to
        connect to resources associated with your Windows Azure subscription.
        '''
        return self._perform_get('/' + self.subscription_id + '/certificates',
                                 SubscriptionCertificates)

    def get_management_certificate(self, thumbprint):
        '''
        The Get Management Certificate operation retrieves information about
        the management certificate with the specified thumbprint. Management
        certificates, which are also known as subscription certificates,
        authenticate clients attempting to connect to resources associated
        with your Windows Azure subscription.

        thumbprint: The thumbprint value of the certificate.
        '''
        _validate_not_none('thumbprint', thumbprint)
        return self._perform_get(
            '/' + self.subscription_id + '/certificates/' + _str(thumbprint),
            SubscriptionCertificate)

    def add_management_certificate(self, public_key, thumbprint, data):
        '''
        The Add Management Certificate operation adds a certificate to the
        list of management certificates. Management certificates, which are
        also known as subscription certificates, authenticate clients
        attempting to connect to resources associated with your Windows Azure
        subscription.

        public_key: A base64 representation of the management certificate
            public key.
        thumbprint: The thumb print that uniquely identifies the management
            certificate.
        data: The certificate's raw data in base-64 encoded .cer format.
        '''
        _validate_not_none('public_key', public_key)
        _validate_not_none('thumbprint', thumbprint)
        _validate_not_none('data', data)
        return self._perform_post(
            '/' + self.subscription_id + '/certificates',
            _XmlSerializer.subscription_certificate_to_xml(
                public_key, thumbprint, data))

    def delete_management_certificate(self, thumbprint):
        '''
        The Delete Management Certificate operation deletes a certificate from
        the list of management certificates. Management certificates, which
        are also known as subscription certificates, authenticate clients
        attempting to connect to resources associated with your Windows Azure
        subscription.

        thumbprint: The thumb print that uniquely identifies the management
            certificate.
        '''
        _validate_not_none('thumbprint', thumbprint)
        return self._perform_delete(
            '/' + self.subscription_id + '/certificates/' + _str(thumbprint))

    #--Operations for affinity groups ------------------------------------
    def list_affinity_groups(self):
        '''
        Lists the affinity groups associated with the specified subscription.
        '''
        return self._perform_get(
            '/' + self.subscription_id + '/affinitygroups',
            AffinityGroups)

    def get_affinity_group_properties(self, affinity_group_name):
        '''
        Returns the system properties associated with the specified affinity
        group.

        affinity_group_name: The name of the affinity group.
        '''
        _validate_not_none('affinity_group_name', affinity_group_name)
        return self._perform_get(
            '/' + self.subscription_id + '/affinitygroups/' +
            _str(affinity_group_name) + '',
            AffinityGroup)

    def create_affinity_group(self, name, label, location, description=None):
        '''
        Creates a new affinity group for the specified subscription.

        name: A name for the affinity group that is unique to the
            subscription.
        label: A name for the affinity group. The name can be up to 100
            characters in length.
        location: The data center location where the affinity group will be
            created.
            To list available locations, use the list_location function.
        description: A description for the affinity group. The description can
            be up to 1024 characters in length.
        '''
        _validate_not_none('name', name)
        _validate_not_none('label', label)
        _validate_not_none('location', location)
        return self._perform_post(
            '/' + self.subscription_id + '/affinitygroups',
            _XmlSerializer.create_affinity_group_to_xml(name,
                                                        label,
                                                        description,
                                                        location))

    def update_affinity_group(self, affinity_group_name, label,
                              description=None):
        '''
        Updates the label and/or the description for an affinity group for the
        specified subscription.

        affinity_group_name: The name of the affinity group.
        label: A name for the affinity group. The name can be up to 100
            characters in length.
        description: A description for the affinity group. The description can
            be up to 1024 characters in length.
        '''
        _validate_not_none('affinity_group_name', affinity_group_name)
        _validate_not_none('label', label)
        return self._perform_put(
            '/' + self.subscription_id + '/affinitygroups/' +
            _str(affinity_group_name),
            _XmlSerializer.update_affinity_group_to_xml(label, description))

    def delete_affinity_group(self, affinity_group_name):
        '''
        Deletes an affinity group in the specified subscription.

        affinity_group_name: The name of the affinity group.
        '''
        _validate_not_none('affinity_group_name', affinity_group_name)
        return self._perform_delete('/' + self.subscription_id + \
                                    '/affinitygroups/' + \
                                    _str(affinity_group_name))

    #--Operations for locations ------------------------------------------
    def list_locations(self):
        '''
        Lists all of the data center locations that are valid for your
        subscription.
        '''
        return self._perform_get('/' + self.subscription_id + '/locations',
                                 Locations)

    #--Operations for tracking asynchronous requests ---------------------
    def get_operation_status(self, request_id):
        '''
        Returns the status of the specified operation. After calling an
        asynchronous operation, you can call Get Operation Status to determine
        whether the operation has succeeded, failed, or is still in progress.

        request_id: The request ID for the request you wish to track.
        '''
        _validate_not_none('request_id', request_id)
        return self._perform_get(
            '/' + self.subscription_id + '/operations/' + _str(request_id),
            Operation)

    #--Operations for retrieving operating system information ------------
    def list_operating_systems(self):
        '''
        Lists the versions of the guest operating system that are currently
        available in Windows Azure.
        '''
        return self._perform_get(
            '/' + self.subscription_id + '/operatingsystems',
            OperatingSystems)

    def list_operating_system_families(self):
        '''
        Lists the guest operating system families available in Windows Azure,
        and also lists the operating system versions available for each
        family.
        '''
        return self._perform_get(
            '/' + self.subscription_id + '/operatingsystemfamilies',
            OperatingSystemFamilies)

    #--Operations for retrieving subscription history --------------------
    def get_subscription(self):
        '''
        Returns account and resource allocation information on the specified
        subscription.
        '''
        return self._perform_get('/' + self.subscription_id + '',
                                 Subscription)

    #--Operations for virtual machines -----------------------------------
    def get_role(self, service_name, deployment_name, role_name):
        '''
        Retrieves the specified virtual machine.

        service_name: The name of the service.
        deployment_name: The name of the deployment.
        role_name: The name of the role.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        return self._perform_get(
            self._get_role_path(service_name, deployment_name, role_name),
            PersistentVMRole)

    def create_virtual_machine_deployment(self, service_name, deployment_name,
                                          deployment_slot, label, role_name,
                                          system_config, os_virtual_hard_disk,
                                          network_config=None,
                                          availability_set_name=None,
                                          data_virtual_hard_disks=None,
                                          role_size=None,
                                          role_type='PersistentVMRole',
                                          virtual_network_name=None):
        '''
        Provisions a virtual machine based on the supplied configuration.

        service_name: Name of the hosted service.
        deployment_name: The name for the deployment. The deployment name must
            be unique among other deployments for the hosted service.
        deployment_slot: The environment to which the hosted service is
            deployed. Valid values are: staging, production
        label: Specifies an identifier for the deployment. The label can be up
            to 100 characters long. The label can be used for tracking
            purposes.
        role_name: The name of the role.
        system_config: Contains the metadata required to provision a virtual
            machine from a Windows or Linux OS image. Use an instance of
            WindowsConfigurationSet or LinuxConfigurationSet.
        os_virtual_hard_disk: Contains the parameters Windows Azure uses to
            create the operating system disk for the virtual machine.
        network_config: Encapsulates the metadata required to create the
            virtual network configuration for a virtual machine. If you do not
            include a network configuration set you will not be able to access
            the VM through VIPs over the internet. If your virtual machine
            belongs to a virtual network you can not specify which subnet
            address space it resides under.
        availability_set_name: Specifies the name of an availability set to
            which to add the virtual machine. This value controls the virtual
            machine allocation in the Windows Azure environment. Virtual
            machines specified in the same availability set are allocated to
            different nodes to maximize availability.
        data_virtual_hard_disks: Contains the parameters Windows Azure uses to
            create a data disk for a virtual machine.
        role_size: The size of the virtual machine to allocate. The default
            value is Small. Possible values are: ExtraSmall, Small, Medium,
            Large, ExtraLarge. The specified value must be compatible with the
            disk selected in the OSVirtualHardDisk values.
        role_type: The type of the role for the virtual machine. The only
            supported value is PersistentVMRole.
        virtual_network_name: Specifies the name of an existing virtual
            network to which the deployment will belong.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('deployment_slot', deployment_slot)
        _validate_not_none('label', label)
        _validate_not_none('role_name', role_name)
        _validate_not_none('system_config', system_config)
        _validate_not_none('os_virtual_hard_disk', os_virtual_hard_disk)
        # POST to the service's deployments collection; the deployment name is
        # carried in the XML payload, not in the URL.
        return self._perform_post(
            self._get_deployment_path_using_name(service_name),
            _XmlSerializer.virtual_machine_deployment_to_xml(
                deployment_name,
                deployment_slot,
                label,
                role_name,
                system_config,
                os_virtual_hard_disk,
                role_type,
                network_config,
                availability_set_name,
                data_virtual_hard_disks,
                role_size,
                virtual_network_name), async=True)

    def add_role(self, service_name, deployment_name, role_name,
                 system_config, os_virtual_hard_disk, network_config=None,
                 availability_set_name=None, data_virtual_hard_disks=None,
                 role_size=None, role_type='PersistentVMRole'):
        '''
        Adds a virtual machine to an existing deployment.

        service_name: The name of the service.
        deployment_name: The name of the deployment.
        role_name: The name of the role.
        system_config: Contains the metadata required to provision a virtual
            machine from a Windows or Linux OS image. Use an instance of
            WindowsConfigurationSet or LinuxConfigurationSet.
        os_virtual_hard_disk: Contains the parameters Windows Azure uses to
            create the operating system disk for the virtual machine.
        network_config: Encapsulates the metadata required to create the
            virtual network configuration for a virtual machine. If you do not
            include a network configuration set you will not be able to access
            the VM through VIPs over the internet. If your virtual machine
            belongs to a virtual network you can not specify which subnet
            address space it resides under.
        availability_set_name: Specifies the name of an availability set to
            which to add the virtual machine. This value controls the virtual
            machine allocation in the Windows Azure environment. Virtual
            machines specified in the same availability set are allocated to
            different nodes to maximize availability.
        data_virtual_hard_disks: Contains the parameters Windows Azure uses to
            create a data disk for a virtual machine.
        role_size: The size of the virtual machine to allocate. The default
            value is Small. Possible values are: ExtraSmall, Small, Medium,
            Large, ExtraLarge. The specified value must be compatible with the
            disk selected in the OSVirtualHardDisk values.
        role_type: The type of the role for the virtual machine. The only
            supported value is PersistentVMRole.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        # POST to the deployment's roles collection; the role name is carried
        # in the XML payload.
        return self._perform_post(
            self._get_role_path(service_name, deployment_name),
            _XmlSerializer.add_role_to_xml(
                role_name,
                system_config,
                os_virtual_hard_disk,
                role_type,
                network_config,
                availability_set_name,
                data_virtual_hard_disks,
                role_size), async=True)

    def update_role(self, service_name, deployment_name, role_name,
                    os_virtual_hard_disk=None, network_config=None,
                    availability_set_name=None, data_virtual_hard_disks=None,
                    role_size=None, role_type='PersistentVMRole'):
        '''
        Updates the specified virtual machine.

        service_name: The name of the service.
        deployment_name: The name of the deployment.
        role_name: The name of the role.
        os_virtual_hard_disk: Contains the parameters Windows Azure uses to
            create the operating system disk for the virtual machine.
        network_config: Encapsulates the metadata required to create the
            virtual network configuration for a virtual machine. If you do not
            include a network configuration set you will not be able to access
            the VM through VIPs over the internet. If your virtual machine
            belongs to a virtual network you can not specify which subnet
            address space it resides under.
        availability_set_name: Specifies the name of an availability set to
            which to add the virtual machine. This value controls the virtual
            machine allocation in the Windows Azure environment. Virtual
            machines specified in the same availability set are allocated to
            different nodes to maximize availability.
        data_virtual_hard_disks: Contains the parameters Windows Azure uses to
            create a data disk for a virtual machine.
        role_size: The size of the virtual machine to allocate. The default
            value is Small. Possible values are: ExtraSmall, Small, Medium,
            Large, ExtraLarge. The specified value must be compatible with the
            disk selected in the OSVirtualHardDisk values.
        role_type: The type of the role for the virtual machine. The only
            supported value is PersistentVMRole.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        return self._perform_put(
            self._get_role_path(service_name, deployment_name, role_name),
            _XmlSerializer.update_role_to_xml(
                role_name,
                os_virtual_hard_disk,
                role_type,
                network_config,
                availability_set_name,
                data_virtual_hard_disks,
                role_size), async=True)

    def delete_role(self, service_name, deployment_name, role_name):
        '''
        Deletes the specified virtual machine.

        service_name: The name of the service.
        deployment_name: The name of the deployment.
        role_name: The name of the role.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        return self._perform_delete(
            self._get_role_path(service_name, deployment_name, role_name),
            async=True)

    def capture_role(self, service_name, deployment_name, role_name,
                     post_capture_action, target_image_name,
                     target_image_label, provisioning_configuration=None):
        '''
        The Capture Role operation captures a virtual machine image to your
        image gallery. From the captured image, you can create additional
        customized virtual machines.

        service_name: The name of the service.
        deployment_name: The name of the deployment.
        role_name: The name of the role.
        post_capture_action: Specifies the action after capture operation
            completes. Possible values are: Delete, Reprovision.
        target_image_name: Specifies the image name of the captured virtual
            machine.
        target_image_label: Specifies the friendly name of the captured
            virtual machine.
        provisioning_configuration: Use an instance of
            WindowsConfigurationSet or LinuxConfigurationSet.
''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) _validate_not_none('post_capture_action', post_capture_action) _validate_not_none('target_image_name', target_image_name) _validate_not_none('target_image_label', target_image_label) return self._perform_post( self._get_role_instance_operations_path( service_name, deployment_name, role_name), _XmlSerializer.capture_role_to_xml( post_capture_action, target_image_name, target_image_label, provisioning_configuration), async=True) def start_role(self, service_name, deployment_name, role_name): ''' Starts the specified virtual machine. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) return self._perform_post( self._get_role_instance_operations_path( service_name, deployment_name, role_name), _XmlSerializer.start_role_operation_to_xml(), async=True) def start_roles(self, service_name, deployment_name, role_names): ''' Starts the specified virtual machines. service_name: The name of the service. deployment_name: The name of the deployment. role_names: The names of the roles, as an enumerable of strings. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_names', role_names) return self._perform_post( self._get_roles_operations_path(service_name, deployment_name), _XmlSerializer.start_roles_operation_to_xml(role_names), async=True) def restart_role(self, service_name, deployment_name, role_name): ''' Restarts the specified virtual machine. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. 
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        return self._perform_post(
            self._get_role_instance_operations_path(
                service_name, deployment_name, role_name),
            _XmlSerializer.restart_role_operation_to_xml(),
            async=True)

    def shutdown_role(self, service_name, deployment_name, role_name,
                      post_shutdown_action='Stopped'):
        '''
        Shuts down the specified virtual machine.

        service_name: The name of the service.
        deployment_name: The name of the deployment.
        role_name: The name of the role.
        post_shutdown_action:
            Specifies how the Virtual Machine should be shut down. Values are:
                Stopped
                    Shuts down the Virtual Machine but retains the compute
                    resources. You will continue to be billed for the
                    resources that the stopped machine uses.
                StoppedDeallocated
                    Shuts down the Virtual Machine and releases the compute
                    resources. You are not billed for the compute resources
                    that this Virtual Machine uses. If a static Virtual
                    Network IP address is assigned to the Virtual Machine,
                    it is reserved.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        _validate_not_none('post_shutdown_action', post_shutdown_action)
        return self._perform_post(
            self._get_role_instance_operations_path(
                service_name, deployment_name, role_name),
            _XmlSerializer.shutdown_role_operation_to_xml(
                post_shutdown_action),
            async=True)

    def shutdown_roles(self, service_name, deployment_name, role_names,
                       post_shutdown_action='Stopped'):
        '''
        Shuts down the specified virtual machines.

        service_name: The name of the service.
        deployment_name: The name of the deployment.
        role_names: The names of the roles, as an enumerable of strings.
        post_shutdown_action:
            Specifies how the Virtual Machines should be shut down; see
            shutdown_role for the meaning of Stopped / StoppedDeallocated.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_names', role_names)
        _validate_not_none('post_shutdown_action', post_shutdown_action)
        return self._perform_post(
            self._get_roles_operations_path(service_name, deployment_name),
            _XmlSerializer.shutdown_roles_operation_to_xml(
                role_names, post_shutdown_action),
            async=True)

    #--Operations for virtual machine images -----------------------------
    def list_os_images(self):
        '''
        Retrieves a list of the OS images from the image repository.
        '''
        return self._perform_get(self._get_image_path(), Images)

    def get_os_image(self, image_name):
        '''
        Retrieves an OS image from the image repository.
        '''
        return self._perform_get(self._get_image_path(image_name), OSImage)

    def add_os_image(self, label, media_link, name, os):
        '''
        Adds an OS image that is currently stored in a storage account in
        your subscription to the image repository.

        label: Specifies the friendly name of the image.
        media_link:
            Specifies the location of the blob in Windows Azure blob store
            where the media for the image is located. The blob location must
            belong to a storage account in the subscription specified by the
            value in the operation call.
            Example: http://example.blob.core.windows.net/disks/mydisk.vhd
        name:
            Specifies a name for the OS image that Windows Azure uses to
            identify the image when creating one or more virtual machines.
        os: The operating system type of the OS image.
            Possible values are: Linux, Windows
        '''
        _validate_not_none('label', label)
        _validate_not_none('media_link', media_link)
        _validate_not_none('name', name)
        _validate_not_none('os', os)
        return self._perform_post(self._get_image_path(),
                                  _XmlSerializer.os_image_to_xml(
                                      label, media_link, name, os),
                                  async=True)

    def update_os_image(self, image_name, label, media_link, name, os):
        '''
        Updates an OS image that is in your image repository.

        image_name: The name of the image to update.
        label:
            Specifies the friendly name of the image to be updated. You
            cannot use this operation to update images provided by the
            Windows Azure platform.
        media_link:
            Specifies the location of the blob in Windows Azure blob store
            where the media for the image is located.
            Example: http://example.blob.core.windows.net/disks/mydisk.vhd
        name:
            Specifies a name for the OS image that Windows Azure uses to
            identify the image when creating one or more VM Roles.
        os: The operating system type of the OS image.
            Possible values are: Linux, Windows
        '''
        _validate_not_none('image_name', image_name)
        _validate_not_none('label', label)
        _validate_not_none('media_link', media_link)
        _validate_not_none('name', name)
        _validate_not_none('os', os)
        return self._perform_put(self._get_image_path(image_name),
                                 _XmlSerializer.os_image_to_xml(
                                     label, media_link, name, os),
                                 async=True)

    def delete_os_image(self, image_name, delete_vhd=False):
        '''
        Deletes the specified OS image from your image repository.

        image_name: The name of the image.
        delete_vhd: Deletes the underlying vhd blob in Azure storage.
        '''
        _validate_not_none('image_name', image_name)
        path = self._get_image_path(image_name)
        if delete_vhd:
            # comp=media asks the service to also delete the backing blob.
            path += '?comp=media'
        return self._perform_delete(path, async=True)

    #--Operations for virtual machine disks ------------------------------
    def get_data_disk(self, service_name, deployment_name, role_name, lun):
        '''
        Retrieves the specified data disk from a virtual machine.

        service_name: The name of the service.
        deployment_name: The name of the deployment.
        role_name: The name of the role.
        lun: The Logical Unit Number (LUN) for the disk.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        _validate_not_none('lun', lun)
        return self._perform_get(
            self._get_data_disk_path(
                service_name, deployment_name, role_name, lun),
            DataVirtualHardDisk)

    def add_data_disk(self, service_name, deployment_name, role_name, lun,
                      host_caching=None, media_link=None, disk_label=None,
                      disk_name=None, logical_disk_size_in_gb=None,
                      source_media_link=None):
        '''
        Adds a data disk to a virtual machine.

        service_name: The name of the service.
        deployment_name: The name of the deployment.
        role_name: The name of the role.
        lun:
            Specifies the Logical Unit Number (LUN) for the disk. The LUN
            specifies the slot in which the data drive appears when mounted
            for usage by the virtual machine. Valid LUN values are 0 through
            15.
        host_caching:
            Specifies the platform caching behavior of the data disk blob
            for read/write efficiency. The default value is ReadOnly.
            Possible values are: None, ReadOnly, ReadWrite
        media_link:
            Specifies the location of the blob in Windows Azure blob store
            where the media for the disk is located.
            Example: http://example.blob.core.windows.net/disks/mydisk.vhd
        disk_label: Specifies the description of the data disk.
        disk_name:
            Specifies the name of the disk. Windows Azure uses the specified
            disk to create the data disk for the machine and populates this
            field with the disk name.
        logical_disk_size_in_gb:
            Specifies the size, in GB, of an empty disk to be attached to
            the role.
        source_media_link:
            Specifies the location of a blob in account storage which is
            mounted as a data disk when the virtual machine is created.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        _validate_not_none('lun', lun)
        return self._perform_post(
            self._get_data_disk_path(service_name, deployment_name,
                                     role_name),
            _XmlSerializer.data_virtual_hard_disk_to_xml(
                host_caching,
                disk_label,
                disk_name,
                lun,
                logical_disk_size_in_gb,
                media_link,
                source_media_link),
            async=True)

    def update_data_disk(self, service_name, deployment_name, role_name, lun,
                         host_caching=None, media_link=None, updated_lun=None,
                         disk_label=None, disk_name=None,
                         logical_disk_size_in_gb=None):
        '''
        Updates the specified data disk attached to the specified virtual
        machine.

        service_name: The name of the service.
        deployment_name: The name of the deployment.
        role_name: The name of the role.
        lun: The current Logical Unit Number (LUN) of the disk (0-15).
        host_caching:
            Platform caching behavior of the data disk blob. The default
            value is ReadOnly. Possible values: None, ReadOnly, ReadWrite.
        media_link:
            Location of the blob in Windows Azure blob store where the
            media for the disk is located.
        updated_lun: The new Logical Unit Number (LUN) for the disk (0-15).
        disk_label: Description of the data disk.
        disk_name: Name of the disk.
        logical_disk_size_in_gb: Size, in GB, of an empty disk to attach.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        _validate_not_none('lun', lun)
        # source_media_link is always None on update.
        return self._perform_put(
            self._get_data_disk_path(
                service_name, deployment_name, role_name, lun),
            _XmlSerializer.data_virtual_hard_disk_to_xml(
                host_caching,
                disk_label,
                disk_name,
                updated_lun,
                logical_disk_size_in_gb,
                media_link,
                None),
            async=True)

    def delete_data_disk(self, service_name, deployment_name, role_name, lun,
                         delete_vhd=False):
        '''
        Removes the specified data disk from a virtual machine.

        service_name: The name of the service.
        deployment_name: The name of the deployment.
        role_name: The name of the role.
        lun: The Logical Unit Number (LUN) for the disk.
        delete_vhd: Deletes the underlying vhd blob in Azure storage.
        '''
        _validate_not_none('service_name', service_name)
        _validate_not_none('deployment_name', deployment_name)
        _validate_not_none('role_name', role_name)
        _validate_not_none('lun', lun)
        path = self._get_data_disk_path(service_name, deployment_name,
                                        role_name, lun)
        if delete_vhd:
            path += '?comp=media'
        return self._perform_delete(path, async=True)

    #--Operations for disks in the image repository ----------------------
    def list_disks(self):
        '''
        Retrieves a list of the disks in your image repository.
        '''
        return self._perform_get(self._get_disk_path(), Disks)

    def get_disk(self, disk_name):
        '''
        Retrieves a disk from your image repository.
        '''
        return self._perform_get(self._get_disk_path(disk_name), Disk)

    def add_disk(self, has_operating_system, label, media_link, name, os):
        '''
        Adds a disk to the user image repository. The disk can be an OS disk
        or a data disk.

        has_operating_system:
            Specifies whether the disk contains an operating system. Only a
            disk with an operating system installed can be mounted as an OS
            Drive.
        label: Specifies the description of the disk.
        media_link:
            Specifies the location of the blob in Windows Azure blob store
            where the media for the disk is located.
            Example: http://example.blob.core.windows.net/disks/mydisk.vhd
        name:
            Specifies a name for the disk. Windows Azure uses the name to
            identify the disk when creating virtual machines from the disk.
        os: The OS type of the disk. Possible values are: Linux, Windows
        '''
        _validate_not_none('has_operating_system', has_operating_system)
        _validate_not_none('label', label)
        _validate_not_none('media_link', media_link)
        _validate_not_none('name', name)
        _validate_not_none('os', os)
        # Unlike the role/image operations above, disk repository operations
        # are synchronous (no async=True).
        return self._perform_post(self._get_disk_path(),
                                  _XmlSerializer.disk_to_xml(
                                      has_operating_system,
                                      label,
                                      media_link,
                                      name,
                                      os))

    def update_disk(self, disk_name, has_operating_system, label, media_link,
                    name, os):
        '''
        Updates an existing disk in your image repository.

        disk_name: The name of the disk to update.
        has_operating_system:
            Specifies whether the disk contains an operating system.
        label: Specifies the description of the disk.
        media_link:
            Specifies the location of the blob in Windows Azure blob store
            where the media for the disk is located.
        name: Specifies a name for the disk.
        os: The OS type of the disk. Possible values are: Linux, Windows
        '''
        _validate_not_none('disk_name', disk_name)
        _validate_not_none('has_operating_system', has_operating_system)
        _validate_not_none('label', label)
        _validate_not_none('media_link', media_link)
        _validate_not_none('name', name)
        _validate_not_none('os', os)
        return self._perform_put(self._get_disk_path(disk_name),
                                 _XmlSerializer.disk_to_xml(
                                     has_operating_system,
                                     label,
                                     media_link,
                                     name,
                                     os))

    def delete_disk(self, disk_name, delete_vhd=False):
        '''
        Deletes the specified data or operating system disk from your image
        repository.

        disk_name: The name of the disk to delete.
        delete_vhd: Deletes the underlying vhd blob in Azure storage.
        '''
        _validate_not_none('disk_name', disk_name)
        path = self._get_disk_path(disk_name)
        if delete_vhd:
            path += '?comp=media'
        return self._perform_delete(path)

    #--Operations for virtual networks ------------------------------
    def list_virtual_network_sites(self):
        '''
        Retrieves a list of the virtual networks.
        '''
        return self._perform_get(self._get_virtual_network_site_path(),
                                 VirtualNetworkSites)

    #--Helper functions --------------------------------------------------
    def _get_virtual_network_site_path(self):
        return self._get_path('services/networking/virtualnetwork', None)

    def _get_storage_service_path(self, service_name=None):
        return self._get_path('services/storageservices', service_name)

    def _get_hosted_service_path(self, service_name=None):
        return self._get_path('services/hostedservices', service_name)

    def _get_deployment_path_using_slot(self, service_name, slot=None):
        return self._get_path('services/hostedservices/' +
                              _str(service_name) + '/deploymentslots', slot)

    def _get_deployment_path_using_name(self, service_name,
                                        deployment_name=None):
        return self._get_path('services/hostedservices/' +
                              _str(service_name) + '/deployments',
                              deployment_name)

    def _get_role_path(self, service_name, deployment_name, role_name=None):
        return self._get_path('services/hostedservices/' +
                              _str(service_name) + '/deployments/' +
                              deployment_name + '/roles', role_name)

    def _get_role_instance_operations_path(self, service_name,
                                           deployment_name, role_name=None):
        return self._get_path('services/hostedservices/' +
                              _str(service_name) + '/deployments/' +
                              deployment_name + '/roleinstances',
                              role_name) + '/Operations'

    def _get_roles_operations_path(self, service_name, deployment_name):
        return self._get_path('services/hostedservices/' +
                              _str(service_name) + '/deployments/' +
                              deployment_name + '/roles/Operations', None)

    def _get_data_disk_path(self, service_name, deployment_name, role_name,
                            lun=None):
        return self._get_path('services/hostedservices/' +
                              _str(service_name) + '/deployments/' +
                              _str(deployment_name) + '/roles/' +
                              _str(role_name) + '/DataDisks', lun)

    def _get_disk_path(self, disk_name=None):
        return self._get_path('services/disks', disk_name)

    def _get_image_path(self, image_name=None):
        return self._get_path('services/images', image_name)


================================================
FILE: DSC/azure/servicemanagement/sqldatabasemanagementservice.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
    MANAGEMENT_HOST,
    _parse_service_resources_response,
)
from azure.servicemanagement import (
    Servers,
    Database,
)
from azure.servicemanagement.servicemanagementclient import (
    _ServiceManagementClient,
)


class SqlDatabaseManagementService(_ServiceManagementClient):

    '''
    Note that this class is preliminary work on SQL Database management.
    Since it lacks many features, the final version can be slightly
    different from the current one.
    '''

    def __init__(self, subscription_id=None, cert_file=None,
                 host=MANAGEMENT_HOST):
        super(SqlDatabaseManagementService, self).__init__(
            subscription_id, cert_file, host)

    #--Operations for sql servers ----------------------------------------
    def list_servers(self):
        '''
        List the SQL servers defined on the account.
        '''
        return self._perform_get(self._get_list_servers_path(), Servers)

    #--Operations for sql databases ----------------------------------------
    def list_databases(self, name):
        '''
        List the SQL databases defined on the specified server name.
        '''
        response = self._perform_get(self._get_list_databases_path(name),
                                     None)
        return _parse_service_resources_response(response, Database)

    #--Helper functions --------------------------------------------------
    def _get_list_servers_path(self):
        return self._get_path('services/sqlservers/servers', None)

    def _get_list_databases_path(self, name):
        # *contentview=generic is mandatory*
        return self._get_path('services/sqlservers/servers/', name) + '/databases?contentview=generic'


================================================
FILE: DSC/azure/servicemanagement/websitemanagementservice.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
    MANAGEMENT_HOST,
    _str,
)
from azure.servicemanagement import (
    WebSpaces,
    WebSpace,
    Sites,
    Site,
    MetricResponses,
    MetricDefinitions,
    PublishData,
    _XmlSerializer,
)
from azure.servicemanagement.servicemanagementclient import (
    _ServiceManagementClient,
)


class WebsiteManagementService(_ServiceManagementClient):

    '''
    Note that this class is preliminary work on WebSite management.
    Since it lacks many features, the final version can be slightly
    different from the current one.
    '''

    def __init__(self, subscription_id=None, cert_file=None,
                 host=MANAGEMENT_HOST):
        super(WebsiteManagementService, self).__init__(
            subscription_id, cert_file, host)

    #--Operations for web sites ----------------------------------------
    def list_webspaces(self):
        '''
        List the webspaces defined on the account.
        '''
        return self._perform_get(self._get_list_webspaces_path(), WebSpaces)

    def get_webspace(self, webspace_name):
        '''
        Get details of a specific webspace.

        webspace_name: The name of the webspace.
        '''
        return self._perform_get(
            self._get_webspace_details_path(webspace_name), WebSpace)

    def list_sites(self, webspace_name):
        '''
        List the web sites defined on this webspace.

        webspace_name: The name of the webspace.
        '''
        return self._perform_get(self._get_sites_path(webspace_name), Sites)

    def get_site(self, webspace_name, website_name):
        '''
        Get details of a web site defined on this webspace.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        '''
        return self._perform_get(
            self._get_sites_details_path(webspace_name, website_name), Site)

    def create_site(self, webspace_name, website_name, geo_region, host_names,
                    plan='VirtualDedicatedPlan', compute_mode='Shared',
                    server_farm=None, site_mode=None):
        '''
        Create a website.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        geo_region:
            The geographical region of the webspace that will be created.
        host_names:
            An array of fully qualified domain names for the website. Only
            one hostname can be specified in the azurewebsites.net domain.
            The hostname should match the name of the website. Custom
            domains can only be specified for Shared or Standard websites.
        plan: This value must be 'VirtualDedicatedPlan'.
        compute_mode:
            'Shared' for the Free or Paid Shared offerings, or 'Dedicated'
            for the Standard offering (the default is 'Shared'). If you set
            it to 'Dedicated', you must specify a value for server_farm.
        server_farm:
            The name of the Server Farm associated with this website;
            required for Standard mode.
        site_mode:
            Can be None, 'Limited' (Free offering) or 'Basic' (Paid Shared
            offering). Standard mode does not use site_mode; it uses
            compute_mode.
        '''
        xml = _XmlSerializer.create_website_to_xml(webspace_name,
                                                   website_name, geo_region,
                                                   plan, host_names,
                                                   compute_mode, server_farm,
                                                   site_mode)
        return self._perform_post(
            self._get_sites_path(webspace_name),
            xml, Site)

    def delete_site(self, webspace_name, website_name,
                    delete_empty_server_farm=False, delete_metrics=False):
        '''
        Delete a website.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        delete_empty_server_farm:
            If the site being deleted is the last web site in a server farm,
            you can delete the server farm by setting this to True.
        delete_metrics:
            To also delete the metrics for the site that you are deleting,
            you can set this to True.
        '''
        path = self._get_sites_details_path(webspace_name, website_name)
        query = ''
        if delete_empty_server_farm:
            query += '&deleteEmptyServerFarm=true'
        if delete_metrics:
            query += '&deleteMetrics=true'
        if query:
            path = path + '?' + query.lstrip('&')
        return self._perform_delete(path)

    def restart_site(self, webspace_name, website_name):
        '''
        Restart a web site.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        '''
        return self._perform_post(
            self._get_restart_path(webspace_name, website_name), '')

    def get_historical_usage_metrics(self, webspace_name, website_name,
                                     metrics=None, start_time=None,
                                     end_time=None, time_grain=None):
        '''
        Get historical usage metrics.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        metrics:
            Optional. List of metrics name. Otherwise, all metrics returned.
        start_time:
            Optional. An ISO8601 date. Otherwise, current hour is used.
        end_time:
            Optional. An ISO8601 date. Otherwise, current time is used.
        time_grain:
            Optional. A rollup name, as P1D. Otherwise, default rollup for
            the metrics is used.

        More information and metrics name at:
        http://msdn.microsoft.com/en-us/library/azure/dn166964.aspx
        '''
        metrics = ('names='+','.join(metrics)) if metrics else ''
        start_time = ('StartTime='+start_time) if start_time else ''
        end_time = ('EndTime='+end_time) if end_time else ''
        time_grain = ('TimeGrain='+time_grain) if time_grain else ''
        parameters = ('&'.join(v for v in (metrics, start_time, end_time,
                                           time_grain) if v))
        parameters = '?'+parameters if parameters else ''
        return self._perform_get(
            self._get_historical_usage_metrics_path(
                webspace_name, website_name) + parameters,
            MetricResponses)

    def get_metric_definitions(self, webspace_name, website_name):
        '''
        Get metric definitions of metrics available of this web site.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        '''
        return self._perform_get(
            self._get_metric_definitions_path(webspace_name, website_name),
            MetricDefinitions)

    def get_publish_profile_xml(self, webspace_name, website_name):
        '''
        Get a site's publish profile as a string.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        '''
        return self._perform_get(
            self._get_publishxml_path(webspace_name, website_name),
            None).body.decode("utf-8")

    def get_publish_profile(self, webspace_name, website_name):
        '''
        Get a site's publish profile as an object.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        '''
        return self._perform_get(
            self._get_publishxml_path(webspace_name, website_name),
            PublishData)

    #--Helper functions --------------------------------------------------
    def _get_list_webspaces_path(self):
        return self._get_path('services/webspaces', None)

    def _get_webspace_details_path(self, webspace_name):
        return self._get_path('services/webspaces/', webspace_name)

    def _get_sites_path(self, webspace_name):
        return self._get_path('services/webspaces/', webspace_name) + '/sites'

    def _get_sites_details_path(self, webspace_name, website_name):
        return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name)

    def _get_restart_path(self, webspace_name, website_name):
        return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) + '/restart/'

    def _get_historical_usage_metrics_path(self, webspace_name, website_name):
        return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) + '/metrics/'

    def _get_metric_definitions_path(self, webspace_name, website_name):
        return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) + '/metricdefinitions/'

    def _get_publishxml_path(self, webspace_name, website_name):
        return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) + '/publishxml/'


================================================
FILE: DSC/azure/storage/__init__.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- import sys import types from datetime import datetime from xml.dom import minidom from azure import (WindowsAzureData, WindowsAzureError, METADATA_NS, xml_escape, _create_entry, _decode_base64_to_text, _decode_base64_to_bytes, _encode_base64, _fill_data_minidom, _fill_instance_element, _get_child_nodes, _get_child_nodesNS, _get_children_from_path, _get_entry_properties, _general_error_handler, _list_of, _parse_response_for_dict, _sign_string, _unicode_type, _ERROR_CANNOT_SERIALIZE_VALUE_TO_ENTITY, ) # x-ms-version for storage service. X_MS_VERSION = '2012-02-12' class EnumResultsBase(object): ''' base class for EnumResults. ''' def __init__(self): self.prefix = u'' self.marker = u'' self.max_results = 0 self.next_marker = u'' class ContainerEnumResults(EnumResultsBase): ''' Blob Container list. ''' def __init__(self): EnumResultsBase.__init__(self) self.containers = _list_of(Container) def __iter__(self): return iter(self.containers) def __len__(self): return len(self.containers) def __getitem__(self, index): return self.containers[index] class Container(WindowsAzureData): ''' Blob container class. ''' def __init__(self): self.name = u'' self.url = u'' self.properties = Properties() self.metadata = {} class Properties(WindowsAzureData): ''' Blob container's properties class. 
''' def __init__(self): self.last_modified = u'' self.etag = u'' class RetentionPolicy(WindowsAzureData): ''' RetentionPolicy in service properties. ''' def __init__(self): self.enabled = False self.__dict__['days'] = None def get_days(self): # convert days to int value return int(self.__dict__['days']) def set_days(self, value): ''' set default days if days is set to empty. ''' self.__dict__['days'] = value days = property(fget=get_days, fset=set_days) class Logging(WindowsAzureData): ''' Logging class in service properties. ''' def __init__(self): self.version = u'1.0' self.delete = False self.read = False self.write = False self.retention_policy = RetentionPolicy() class Metrics(WindowsAzureData): ''' Metrics class in service properties. ''' def __init__(self): self.version = u'1.0' self.enabled = False self.include_apis = None self.retention_policy = RetentionPolicy() class StorageServiceProperties(WindowsAzureData): ''' Storage Service Propeties class. ''' def __init__(self): self.logging = Logging() self.metrics = Metrics() class AccessPolicy(WindowsAzureData): ''' Access Policy class in service properties. ''' def __init__(self, start=u'', expiry=u'', permission='u'): self.start = start self.expiry = expiry self.permission = permission class SignedIdentifier(WindowsAzureData): ''' Signed Identifier class for service properties. ''' def __init__(self): self.id = u'' self.access_policy = AccessPolicy() class SignedIdentifiers(WindowsAzureData): ''' SignedIdentifier list. 
''' def __init__(self): self.signed_identifiers = _list_of(SignedIdentifier) def __iter__(self): return iter(self.signed_identifiers) def __len__(self): return len(self.signed_identifiers) def __getitem__(self, index): return self.signed_identifiers[index] class BlobEnumResults(EnumResultsBase): ''' Blob list.''' def __init__(self): EnumResultsBase.__init__(self) self.blobs = _list_of(Blob) self.prefixes = _list_of(BlobPrefix) self.delimiter = '' def __iter__(self): return iter(self.blobs) def __len__(self): return len(self.blobs) def __getitem__(self, index): return self.blobs[index] class BlobResult(bytes): def __new__(cls, blob, properties): return bytes.__new__(cls, blob if blob else b'') def __init__(self, blob, properties): self.properties = properties class Blob(WindowsAzureData): ''' Blob class. ''' def __init__(self): self.name = u'' self.snapshot = u'' self.url = u'' self.properties = BlobProperties() self.metadata = {} class BlobProperties(WindowsAzureData): ''' Blob Properties ''' def __init__(self): self.last_modified = u'' self.etag = u'' self.content_length = 0 self.content_type = u'' self.content_encoding = u'' self.content_language = u'' self.content_md5 = u'' self.xms_blob_sequence_number = 0 self.blob_type = u'' self.lease_status = u'' self.lease_state = u'' self.lease_duration = u'' self.copy_id = u'' self.copy_source = u'' self.copy_status = u'' self.copy_progress = u'' self.copy_completion_time = u'' self.copy_status_description = u'' class BlobPrefix(WindowsAzureData): ''' BlobPrefix in Blob. ''' def __init__(self): self.name = '' class BlobBlock(WindowsAzureData): ''' BlobBlock class ''' def __init__(self, id=None, size=None): self.id = id self.size = size class BlobBlockList(WindowsAzureData): ''' BlobBlockList class ''' def __init__(self): self.committed_blocks = [] self.uncommitted_blocks = [] class PageRange(WindowsAzureData): ''' Page Range for page blob. 
''' def __init__(self): self.start = 0 self.end = 0 class PageList(object): ''' Page list for page blob. ''' def __init__(self): self.page_ranges = _list_of(PageRange) def __iter__(self): return iter(self.page_ranges) def __len__(self): return len(self.page_ranges) def __getitem__(self, index): return self.page_ranges[index] class QueueEnumResults(EnumResultsBase): ''' Queue list''' def __init__(self): EnumResultsBase.__init__(self) self.queues = _list_of(Queue) def __iter__(self): return iter(self.queues) def __len__(self): return len(self.queues) def __getitem__(self, index): return self.queues[index] class Queue(WindowsAzureData): ''' Queue class ''' def __init__(self): self.name = u'' self.url = u'' self.metadata = {} class QueueMessagesList(WindowsAzureData): ''' Queue message list. ''' def __init__(self): self.queue_messages = _list_of(QueueMessage) def __iter__(self): return iter(self.queue_messages) def __len__(self): return len(self.queue_messages) def __getitem__(self, index): return self.queue_messages[index] class QueueMessage(WindowsAzureData): ''' Queue message class. ''' def __init__(self): self.message_id = u'' self.insertion_time = u'' self.expiration_time = u'' self.pop_receipt = u'' self.time_next_visible = u'' self.dequeue_count = u'' self.message_text = u'' class Entity(WindowsAzureData): ''' Entity class. The attributes of entity will be created dynamically. ''' pass class EntityProperty(WindowsAzureData): ''' Entity property. contains type and value. ''' def __init__(self, type=None, value=None): self.type = type self.value = value class Table(WindowsAzureData): ''' Only for intellicens and telling user the return type. 
    '''
    pass


def _parse_blob_enum_results_list(response):
    # Parses a List Blobs XML response body into a BlobEnumResults object,
    # filling blobs, prefixes and the remaining enumeration attributes.
    respbody = response.body
    return_obj = BlobEnumResults()
    doc = minidom.parseString(respbody)

    for enum_results in _get_child_nodes(doc, 'EnumerationResults'):
        for child in _get_children_from_path(enum_results, 'Blobs', 'Blob'):
            return_obj.blobs.append(_fill_instance_element(child, Blob))

        for child in _get_children_from_path(enum_results,
                                             'Blobs',
                                             'BlobPrefix'):
            return_obj.prefixes.append(
                _fill_instance_element(child, BlobPrefix))

        # Fill the scalar enumeration attributes (marker, maxresults, ...);
        # 'blobs' and 'prefixes' were already populated above.
        for name, value in vars(return_obj).items():
            if name == 'blobs' or name == 'prefixes':
                continue
            value = _fill_data_minidom(enum_results, name, value)
            if value is not None:
                setattr(return_obj, name, value)

    return return_obj


def _update_storage_header(request):
    ''' add additional headers for storage request. '''
    if request.body:
        assert isinstance(request.body, bytes)

    # if it is PUT, POST, MERGE, DELETE, need to add content-length to header.
    if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']:
        request.headers.append(('Content-Length', str(len(request.body))))

    # append additional headers based on the service
    request.headers.append(('x-ms-version', X_MS_VERSION))

    # append x-ms-meta name, values to header: the single
    # 'x-ms-meta-name-values' entry (a dict) is expanded into one
    # 'x-ms-meta-<name>' header per item and then removed.
    for name, value in request.headers:
        if 'x-ms-meta-name-values' in name and value:
            for meta_name, meta_value in value.items():
                request.headers.append(('x-ms-meta-' + meta_name, meta_value))
            request.headers.remove((name, value))
            break
    return request


def _update_storage_blob_header(request, account_name, account_key):
    ''' add additional headers for storage blob request. '''
    request = _update_storage_header(request)
    current_time = datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT')
    request.headers.append(('x-ms-date', current_time))
    request.headers.append(
        ('Content-Type', 'application/octet-stream Charset=UTF-8'))
    # Authorization must be computed last, over the headers appended above.
    request.headers.append(('Authorization',
                            _sign_storage_blob_request(request,
                                                       account_name,
                                                       account_key)))
    return request.headers


def _update_storage_queue_header(request, account_name, account_key):
    ''' add additional headers for storage queue request. '''
    # Queue requests use the same Shared Key scheme as blob requests.
    return _update_storage_blob_header(request, account_name, account_key)


def _update_storage_table_header(request):
    ''' add additional headers for storage table request. '''
    request = _update_storage_header(request)
    # Only add a Content-Type if the caller did not set one (for/else:
    # the else runs when no 'content-type' header was found).
    for name, _ in request.headers:
        if name.lower() == 'content-type':
            break
    else:
        request.headers.append(('Content-Type', 'application/atom+xml'))
    request.headers.append(('DataServiceVersion', '2.0;NetFx'))
    request.headers.append(('MaxDataServiceVersion', '2.0;NetFx'))
    current_time = datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT')
    request.headers.append(('x-ms-date', current_time))
    request.headers.append(('Date', current_time))
    return request.headers


def _sign_storage_blob_request(request, account_name, account_key):
    '''
    Returns the signed string for blob request which is used to set
    Authorization header. This is also used to sign queue request.
    '''
    uri_path = request.path.split('?')[0]

    # method to sign
    string_to_sign = request.method + '\n'

    # get headers to sign (the standard headers, in the fixed order the
    # Shared Key scheme requires)
    headers_to_sign = [
        'content-encoding', 'content-language', 'content-length',
        'content-md5', 'content-type', 'date', 'if-modified-since',
        'if-match', 'if-none-match', 'if-unmodified-since', 'range']
    request_header_dict = dict((name.lower(), value)
                               for name, value in request.headers if value)
    string_to_sign += '\n'.join(request_header_dict.get(x, '')
                                for x in headers_to_sign) + '\n'

    # get x-ms header to sign, lexicographically sorted
    x_ms_headers = []
    for name, value in request.headers:
        if 'x-ms' in name:
            x_ms_headers.append((name.lower(), value))
    x_ms_headers.sort()
    for name, value in x_ms_headers:
        if value:
            string_to_sign += ''.join([name, ':', value, '\n'])

    # get account_name and uri path to sign
    string_to_sign += '/' + account_name + uri_path

    # get query string to sign if it is not table service
    query_to_sign = request.query
    query_to_sign.sort()

    # NOTE(review): current_name is never reassigned inside the loop, so the
    # comma-append branch for repeated query names appears unreachable;
    # preserved as-is to keep the signature byte-identical — confirm upstream.
    current_name = ''
    for name, value in query_to_sign:
        if value:
            if current_name != name:
                string_to_sign += '\n' + name + ':' + value
            else:
                string_to_sign += '\n' + ',' + value

    # sign the request
    auth_string = 'SharedKey ' + account_name + ':' + \
        _sign_string(account_key, string_to_sign)
    return auth_string


def _sign_storage_table_request(request, account_name, account_key):
    # Returns the SharedKey Authorization value for a table request; tables
    # sign fewer headers than blobs/queues and only the 'comp' query on '/'.
    uri_path = request.path.split('?')[0]

    string_to_sign = request.method + '\n'

    headers_to_sign = ['content-md5', 'content-type', 'date']
    request_header_dict = dict((name.lower(), value)
                               for name, value in request.headers if value)
    string_to_sign += '\n'.join(request_header_dict.get(x, '')
                                for x in headers_to_sign) + '\n'

    # get account_name and uri path to sign
    string_to_sign += ''.join(['/', account_name, uri_path])

    for name, value in request.query:
        if name == 'comp' and uri_path == '/':
            string_to_sign += '?comp=' + value
            break

    # sign the request
    auth_string = 'SharedKey ' + account_name + ':' + \
        _sign_string(account_key, string_to_sign)
    return auth_string


def _to_python_bool(value):
    # Converts an Edm.Boolean string ('true'/'false') to a Python bool.
    if value.lower() == 'true':
        return True
    return False


def _to_entity_int(data):
    # Values outside the signed 32-bit range are tagged Edm.Int64.
    int_max = (2 << 30) - 1
    if data > (int_max) or data < (int_max + 1) * (-1):
        return 'Edm.Int64', str(data)
    else:
        return 'Edm.Int32', str(data)


def _to_entity_bool(value):
    if value:
        return 'Edm.Boolean', 'true'
    return 'Edm.Boolean', 'false'


def _to_entity_datetime(value):
    return 'Edm.DateTime', value.strftime('%Y-%m-%dT%H:%M:%S')


def _to_entity_float(value):
    return 'Edm.Double', str(value)


def _to_entity_property(value):
    # EntityProperty carries an explicit Edm type; binary payloads are
    # base64-encoded for transport.
    if value.type == 'Edm.Binary':
        return value.type, _encode_base64(value.value)
    return value.type, str(value.value)


def _to_entity_none(value):
    return None, None


def _to_entity_str(value):
    return 'Edm.String', value


# Tables of conversions to and from entity types. We support specific
# datatypes, and beyond that the user can use an EntityProperty to get
# custom data type support.

def _from_entity_binary(value):
    return EntityProperty('Edm.Binary', _decode_base64_to_bytes(value))


def _from_entity_int(value):
    return int(value)


def _from_entity_datetime(value):
    # The service may or may not include fractional seconds and a trailing
    # 'Z'; build the strptime format to match what was received.
    format = '%Y-%m-%dT%H:%M:%S'
    if '.' in value:
        format = format + '.%f'
    if value.endswith('Z'):
        format = format + 'Z'
    return datetime.strptime(value, format)

_ENTITY_TO_PYTHON_CONVERSIONS = {
    'Edm.Binary': _from_entity_binary,
    'Edm.Int32': _from_entity_int,
    'Edm.Int64': _from_entity_int,
    'Edm.Double': float,
    'Edm.Boolean': _to_python_bool,
    'Edm.DateTime': _from_entity_datetime,
}

# Conversion from Python type to a function which returns a tuple of the
# type string and content string.
_PYTHON_TO_ENTITY_CONVERSIONS = {
    int: _to_entity_int,
    bool: _to_entity_bool,
    datetime: _to_entity_datetime,
    float: _to_entity_float,
    EntityProperty: _to_entity_property,
    str: _to_entity_str,
}

# On Python 2, long/unicode/NoneType also need converters; on Python 3 the
# None case is handled explicitly inside _convert_entity_to_xml.
if sys.version_info < (3,):
    _PYTHON_TO_ENTITY_CONVERSIONS.update({
        long: _to_entity_int,
        types.NoneType: _to_entity_none,
        unicode: _to_entity_str,
    })


def _convert_entity_to_xml(source):
    ''' Converts an entity object to xml to send.

    The entity format is:
    <updated>2008-09-18T23:46:19.3857256Z</updated>
    <author>
      <name />
    </author>
    <id />
    <content type="application/xml">
      <m:properties>
        <d:Address>Mountain View</d:Address>
        <d:Age m:type="Edm.Int32">23</d:Age>
        <d:AmountDue m:type="Edm.Double">200.23</d:AmountDue>
        <d:BinaryData m:type="Edm.Binary" m:null="true" />
        <d:CustomerCode m:type="Edm.Guid">c9da6455-213d-42c9-9a79-3e9149a57833</d:CustomerCode>
        <d:CustomerSince m:type="Edm.DateTime">2008-07-10T00:00:00</d:CustomerSince>
        <d:IsActive m:type="Edm.Boolean">true</d:IsActive>
        <d:NumOfOrders m:type="Edm.Int64">255</d:NumOfOrders>
        <d:PartitionKey>mypartitionkey</d:PartitionKey>
        <d:RowKey>myrowkey1</d:RowKey>
        <d:Timestamp m:type="Edm.DateTime">0001-01-01T00:00:00</d:Timestamp>
      </m:properties>
    </content>
    </entry>
    '''

    # construct the entity body included in <m:properties> and </m:properties>
    entity_body = '<m:properties xml:space="preserve">{properties}</m:properties>'

    if isinstance(source, WindowsAzureData):
        source = vars(source)

    properties_str = ''

    # set properties type for types we know if value has no type info.
    # if value has type info, then set the type to value.type
    for name, value in source.items():
        mtype = ''
        conv = _PYTHON_TO_ENTITY_CONVERSIONS.get(type(value))
        # Python 3 has no NoneType entry in the table, so handle None here.
        if conv is None and sys.version_info >= (3,) and value is None:
            conv = _to_entity_none
        if conv is None:
            raise WindowsAzureError(
                _ERROR_CANNOT_SERIALIZE_VALUE_TO_ENTITY.format(
                    type(value).__name__))

        mtype, value = conv(value)

        # form the property node
        properties_str += ''.join(['<d:', name])
        if value is None:
            properties_str += ' m:null="true" />'
        else:
            if mtype:
                properties_str += ''.join([' m:type="', mtype, '"'])
            properties_str += ''.join(['>',
                                       xml_escape(value),
                                       '</d:', name, '>'])

    if sys.version_info < (3,):
        if isinstance(properties_str, unicode):
            properties_str = properties_str.encode('utf-8')

    # generate the entity_body
    entity_body = entity_body.format(properties=properties_str)
    xmlstr = _create_entry(entity_body)
    return xmlstr


def _convert_table_to_xml(table_name):
    '''
    Create xml to send for a given table name. Since xml format for table is
    the same as entity and the only difference is that table has only one
    property 'TableName', so we just call _convert_entity_to_xml.

    table_name: the name of the table
    '''
    return _convert_entity_to_xml({'TableName': table_name})


def _convert_block_list_to_xml(block_id_list):
    '''
    Convert a block list to xml to send.

    block_id_list:
        a str list containing the block ids that are used in put_block_list.
    Only get block from latest blocks.
    '''
    if block_id_list is None:
        return ''
    xml = '<?xml version="1.0" encoding="utf-8"?><BlockList>'
    for value in block_id_list:
        # Block ids are transmitted base64-encoded.
        xml += '<Latest>{0}</Latest>'.format(_encode_base64(value))

    return xml + '</BlockList>'


def _create_blob_result(response):
    # Wraps the raw response body and its parsed headers in a BlobResult.
    blob_properties = _parse_response_for_dict(response)
    return BlobResult(response.body, blob_properties)


def _convert_response_to_block_list(response):
    '''
    Converts xml response to block list class.
    '''
    blob_block_list = BlobBlockList()

    xmldoc = minidom.parseString(response.body)
    # Committed blocks: <BlockList><CommittedBlocks><Block> entries.
    for xml_block in _get_children_from_path(xmldoc,
                                             'BlockList',
                                             'CommittedBlocks',
                                             'Block'):
        xml_block_id = _decode_base64_to_text(
            _get_child_nodes(xml_block, 'Name')[0].firstChild.nodeValue)
        xml_block_size = int(
            _get_child_nodes(xml_block, 'Size')[0].firstChild.nodeValue)
        blob_block_list.committed_blocks.append(
            BlobBlock(xml_block_id, xml_block_size))

    # Uncommitted blocks follow the same shape under <UncommittedBlocks>.
    for xml_block in _get_children_from_path(xmldoc,
                                             'BlockList',
                                             'UncommittedBlocks',
                                             'Block'):
        xml_block_id = _decode_base64_to_text(
            _get_child_nodes(xml_block, 'Name')[0].firstChild.nodeValue)
        xml_block_size = int(
            _get_child_nodes(xml_block, 'Size')[0].firstChild.nodeValue)
        blob_block_list.uncommitted_blocks.append(
            BlobBlock(xml_block_id, xml_block_size))

    return blob_block_list


def _remove_prefix(name):
    # Strips a namespace prefix ('d:Name' -> 'Name'); returns the name
    # unchanged when there is no colon.
    colon = name.find(':')
    if colon != -1:
        return name[colon + 1:]
    return name


def _convert_response_to_entity(response):
    if response is None:
        return response
    return _convert_xml_to_entity(response.body)


def _convert_xml_to_entity(xmlstr):
    ''' Convert xml response to entity.

    The format of entity:
    <entry xmlns:d="http://schemas.microsoft.com/ado/2007/08/dataservices"
    xmlns:m="http://schemas.microsoft.com/ado/2007/08/dataservices/metadata"
    xmlns="http://www.w3.org/2005/Atom">
      <title />
      <updated>2008-09-18T23:46:19.3857256Z</updated>
      <author>
        <name />
      </author>
      <id />
      <content type="application/xml">
        <m:properties>
          <d:Address>Mountain View</d:Address>
          <d:Age m:type="Edm.Int32">23</d:Age>
          <d:AmountDue m:type="Edm.Double">200.23</d:AmountDue>
          <d:BinaryData m:type="Edm.Binary" m:null="true" />
          <d:CustomerCode m:type="Edm.Guid">c9da6455-213d-42c9-9a79-3e9149a57833</d:CustomerCode>
          <d:CustomerSince m:type="Edm.DateTime">2008-07-10T00:00:00</d:CustomerSince>
          <d:IsActive m:type="Edm.Boolean">true</d:IsActive>
          <d:NumOfOrders m:type="Edm.Int64">255</d:NumOfOrders>
          <d:PartitionKey>mypartitionkey</d:PartitionKey>
          <d:RowKey>myrowkey1</d:RowKey>
          <d:Timestamp m:type="Edm.DateTime">0001-01-01T00:00:00</d:Timestamp>
        </m:properties>
      </content>
    </entry>
    '''
    xmldoc = minidom.parseString(xmlstr)

    xml_properties = None
    for entry in _get_child_nodes(xmldoc, 'entry'):
        for content in _get_child_nodes(entry, 'content'):
            # TODO: Namespace
            xml_properties = _get_child_nodesNS(
                content, METADATA_NS, 'properties')

    if not xml_properties:
        return None

    entity = Entity()
    # extract each property node and get the type from attribute and node value
    for xml_property in xml_properties[0].childNodes:
        name = _remove_prefix(xml_property.nodeName)
        # exclude the Timestamp since it is auto added by azure when
        # inserting entity. We don't want this to mix with real properties
        if name in ['Timestamp']:
            continue

        if xml_property.firstChild:
            value = xml_property.firstChild.nodeValue
        else:
            value = ''

        isnull = xml_property.getAttributeNS(METADATA_NS, 'null')
        mtype = xml_property.getAttributeNS(METADATA_NS, 'type')

        # if not isnull and no type info, then it is a string and we just
        # need the str type to hold the property.
        if not isnull and not mtype:
            _set_entity_attr(entity, name, value)
        elif isnull == 'true':
            # NOTE(review): this branch builds 'property' but never calls
            # _set_entity_attr, so null-marked properties are dropped from
            # the resulting entity — confirm whether that is intentional.
            if mtype:
                property = EntityProperty(mtype, None)
            else:
                property = EntityProperty('Edm.String', None)
        else:
            # need an object to hold the property
            conv = _ENTITY_TO_PYTHON_CONVERSIONS.get(mtype)
            if conv is not None:
                property = conv(value)
            else:
                property = EntityProperty(mtype, value)
            _set_entity_attr(entity, name, property)

    # extract id, updated and name value from feed entry and set them of
    # rule.
    for name, value in _get_entry_properties(xmlstr, True).items():
        if name in ['etag']:
            _set_entity_attr(entity, name, value)

    return entity


def _set_entity_attr(entity, name, value):
    try:
        setattr(entity, name, value)
    except UnicodeEncodeError:
        # Python 2 doesn't support unicode attribute names, so we'll
        # add them and access them directly through the dictionary
        entity.__dict__[name] = value


def _convert_xml_to_table(xmlstr):
    '''
    Converts the xml response to table class.
    Simply call convert_xml_to_entity and extract the table name, and add
    updated and author info
    '''
    table = Table()
    entity = _convert_xml_to_entity(xmlstr)
    setattr(table, 'name', entity.TableName)
    for name, value in _get_entry_properties(xmlstr, False).items():
        setattr(table, name, value)
    return table


def _storage_error_handler(http_error):
    ''' Simple error handler for storage service. '''
    return _general_error_handler(http_error)

# make these available just from storage.
from azure.storage.blobservice import BlobService from azure.storage.queueservice import QueueService from azure.storage.tableservice import TableService from azure.storage.cloudstorageaccount import CloudStorageAccount from azure.storage.sharedaccesssignature import ( SharedAccessSignature, SharedAccessPolicy, Permission, WebResource, ) ================================================ FILE: DSC/azure/storage/blobservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#--------------------------------------------------------------------------
from azure import (
    WindowsAzureError,
    BLOB_SERVICE_HOST_BASE,
    DEV_BLOB_HOST,
    _ERROR_VALUE_NEGATIVE,
    _ERROR_PAGE_BLOB_SIZE_ALIGNMENT,
    _convert_class_to_xml,
    _dont_fail_not_exist,
    _dont_fail_on_exist,
    _encode_base64,
    _get_request_body,
    _get_request_body_bytes_only,
    _int_or_none,
    _parse_enum_results_list,
    _parse_response,
    _parse_response_for_dict,
    _parse_response_for_dict_filter,
    _parse_response_for_dict_prefix,
    _parse_simple_list,
    _str,
    _str_or_none,
    _update_request_uri_query_local_storage,
    _validate_type_bytes,
    _validate_not_none,
    )
from azure.http import HTTPRequest
from azure.storage import (
    Container,
    ContainerEnumResults,
    PageList,
    PageRange,
    SignedIdentifiers,
    StorageServiceProperties,
    _convert_block_list_to_xml,
    _convert_response_to_block_list,
    _create_blob_result,
    _parse_blob_enum_results_list,
    _update_storage_blob_header,
    )
from azure.storage.storageclient import _StorageClient
from os import path
import sys
if sys.version_info >= (3,):
    from io import BytesIO
else:
    from cStringIO import StringIO as BytesIO

# Keep this value sync with _ERROR_PAGE_BLOB_SIZE_ALIGNMENT
_PAGE_SIZE = 512


class BlobService(_StorageClient):

    '''
    This is the main class managing Blob resources.
    '''

    def __init__(self, account_name=None, account_key=None, protocol='https',
                 host_base=BLOB_SERVICE_HOST_BASE, dev_host=DEV_BLOB_HOST):
        '''
        account_name: your storage account name, required for all operations.
        account_key: your storage account key, required for all operations.
        protocol: Optional. Protocol. Defaults to https.
        host_base:
            Optional. Live host base url. Defaults to Azure url. Override this
            for on-premise.
        dev_host: Optional. Dev host url. Defaults to localhost.
        '''
        # Single-shot upload limit and per-chunk size for block uploads.
        self._BLOB_MAX_DATA_SIZE = 64 * 1024 * 1024
        self._BLOB_MAX_CHUNK_DATA_SIZE = 4 * 1024 * 1024
        super(BlobService, self).__init__(
            account_name, account_key, protocol, host_base, dev_host)

    def make_blob_url(self, container_name, blob_name, account_name=None,
                      protocol=None, host_base=None):
        '''
        Creates the url to access a blob.

        container_name: Name of container.
        blob_name: Name of blob.
        account_name:
            Name of the storage account. If not specified, uses the account
            specified when BlobService was initialized.
        protocol:
            Protocol to use: 'http' or 'https'. If not specified, uses the
            protocol specified when BlobService was initialized.
        host_base:
            Live host base url.  If not specified, uses the host base specified
            when BlobService was initialized.
        '''
        # Fall back to the values this service was constructed with.
        if not account_name:
            account_name = self.account_name
        if not protocol:
            protocol = self.protocol
        if not host_base:
            host_base = self.host_base

        return '{0}://{1}{2}/{3}/{4}'.format(protocol,
                                             account_name,
                                             host_base,
                                             container_name,
                                             blob_name)

    def list_containers(self, prefix=None, marker=None, maxresults=None,
                        include=None):
        '''
        The List Containers operation returns a list of the containers under
        the specified account.

        prefix:
            Optional. Filters the results to return only containers whose names
            begin with the specified prefix.
        marker:
            Optional. A string value that identifies the portion of the list to
            be returned with the next list operation.
        maxresults:
            Optional. Specifies the maximum number of containers to return.
        include:
            Optional. Include this parameter to specify that the container's
            metadata be returned as part of the response body. set this
            parameter to string 'metadata' to get container's metadata.
        '''
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/?comp=list'
        request.query = [
            ('prefix', _str_or_none(prefix)),
            ('marker', _str_or_none(marker)),
            ('maxresults', _int_or_none(maxresults)),
            ('include', _str_or_none(include))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        # Signing must happen after the path/query are final.
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        return _parse_enum_results_list(response,
                                        ContainerEnumResults,
                                        "Containers",
                                        Container)

    def create_container(self, container_name, x_ms_meta_name_values=None,
                         x_ms_blob_public_access=None, fail_on_exist=False):
        '''
        Creates a new container under the specified account. If the container
        with the same name already exists, the operation fails.

        container_name: Name of container to create.
        x_ms_meta_name_values:
            Optional. A dict with name_value pairs to associate with the
            container as metadata. Example:{'Category':'test'}
        x_ms_blob_public_access:
            Optional. Possible values include: container, blob
        fail_on_exist:
            specify whether to throw an exception when the container exists.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + _str(container_name) + '?restype=container'
        # The meta-name-values dict is expanded into x-ms-meta-* headers
        # by _update_storage_header during signing.
        request.headers = [
            ('x-ms-meta-name-values', x_ms_meta_name_values),
            ('x-ms-blob-public-access', _str_or_none(x_ms_blob_public_access))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        if not fail_on_exist:
            # Best-effort create: swallow "already exists" errors only.
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_on_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def get_container_properties(self, container_name, x_ms_lease_id=None):
        '''
        Returns all user-defined metadata and system properties for the
        specified container.

        container_name: Name of existing container.
        x_ms_lease_id:
            If specified, get_container_properties only succeeds if the
            container's lease is active and matches this ID.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(container_name) + '?restype=container'
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        return _parse_response_for_dict(response)

    def get_container_metadata(self, container_name, x_ms_lease_id=None):
        '''
        Returns all user-defined metadata for the specified container. The
        metadata will be in returned dictionary['x-ms-meta-(name)'].

        container_name: Name of existing container.
        x_ms_lease_id:
            If specified, get_container_metadata only succeeds if the
            container's lease is active and matches this ID.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '?restype=container&comp=metadata'
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        # Only the x-ms-meta-* response headers are returned.
        return _parse_response_for_dict_prefix(response,
                                               prefixes=['x-ms-meta'])

    def set_container_metadata(self, container_name,
                               x_ms_meta_name_values=None, x_ms_lease_id=None):
        '''
        Sets one or more user-defined name-value pairs for the specified
        container.

        container_name: Name of existing container.
        x_ms_meta_name_values:
            A dict containing name, value for metadata.
            Example: {'category':'test'}
        x_ms_lease_id:
            If specified, set_container_metadata only succeeds if the
            container's lease is active and matches this ID.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '?restype=container&comp=metadata'
        request.headers = [
            ('x-ms-meta-name-values', x_ms_meta_name_values),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def get_container_acl(self, container_name, x_ms_lease_id=None):
        '''
        Gets the permissions for the specified container.

        container_name: Name of existing container.
        x_ms_lease_id:
            If specified, get_container_acl only succeeds if the
            container's lease is active and matches this ID.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '?restype=container&comp=acl'
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        return _parse_response(response, SignedIdentifiers)

    def set_container_acl(self, container_name, signed_identifiers=None,
                          x_ms_blob_public_access=None, x_ms_lease_id=None):
        '''
        Sets the permissions for the specified container.

        container_name: Name of existing container.
        signed_identifiers: SignedIdentifers instance
        x_ms_blob_public_access:
            Optional. Possible values include: container, blob
        x_ms_lease_id:
            If specified, set_container_acl only succeeds if the
            container's lease is active and matches this ID.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '?restype=container&comp=acl'
        request.headers = [
            ('x-ms-blob-public-access', _str_or_none(x_ms_blob_public_access)),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
        ]
        # The signed identifier list is serialized to XML in the body.
        request.body = _get_request_body(
            _convert_class_to_xml(signed_identifiers))
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def delete_container(self, container_name, fail_not_exist=False,
                         x_ms_lease_id=None):
        '''
        Marks the specified container for deletion.

        container_name: Name of container to delete.
        fail_not_exist:
            Specify whether to throw an exception when the container doesn't
            exist.
        x_ms_lease_id: Required if the container has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(container_name) + '?restype=container'
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        if not fail_not_exist:
            # Best-effort delete: swallow "not found" errors only.
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_not_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def lease_container(self, container_name, x_ms_lease_action,
                        x_ms_lease_id=None, x_ms_lease_duration=60,
                        x_ms_lease_break_period=None,
                        x_ms_proposed_lease_id=None):
        '''
        Establishes and manages a lock on a container for delete operations.
        The lock duration can be 15 to 60 seconds, or can be infinite.

        container_name: Name of existing container.
        x_ms_lease_action:
            Required. Possible values: acquire|renew|release|break|change
        x_ms_lease_id: Required if the container has an active lease.
        x_ms_lease_duration:
            Specifies the duration of the lease, in seconds, or negative one
            (-1) for a lease that never expires. A non-infinite lease can be
            between 15 and 60 seconds. A lease duration cannot be changed
            using renew or change. For backwards compatibility, the default is
            60, and the value is only used on an acquire operation.
        x_ms_lease_break_period:
            Optional. For a break operation, this is the proposed duration of
            seconds that the lease should continue before it is broken, between
            0 and 60 seconds. This break period is only used if it is shorter
            than the time remaining on the lease. If longer, the time remaining
            on the lease is used. A new lease will not be available before the
            break period has expired, but the lease may be held for longer than
            the break period. If this header does not appear with a break
            operation, a fixed-duration lease breaks after the remaining lease
            period elapses, and an infinite lease breaks immediately.
        x_ms_proposed_lease_id:
            Optional for acquire, required for change. Proposed lease ID, in a
            GUID string format.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('x_ms_lease_action', x_ms_lease_action)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '?restype=container&comp=lease'
        request.headers = [
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
            ('x-ms-lease-action', _str_or_none(x_ms_lease_action)),
            # Duration is only sent on acquire, per the documented contract.
            ('x-ms-lease-duration', _str_or_none(
                x_ms_lease_duration if x_ms_lease_action == 'acquire'
                else None)),
            ('x-ms-lease-break-period',
             _str_or_none(x_ms_lease_break_period)),
            ('x-ms-proposed-lease-id',
             _str_or_none(x_ms_proposed_lease_id)),
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        return _parse_response_for_dict_filter(
            response,
            filter=['x-ms-lease-id', 'x-ms-lease-time'])

    def list_blobs(self, container_name, prefix=None, marker=None,
                   maxresults=None, include=None, delimiter=None):
        '''
        Returns the list of blobs under the specified container.

        container_name: Name of existing container.
        prefix:
            Optional. Filters the results to return only blobs whose names
            begin with the specified prefix.
        marker:
            Optional. A string value that identifies the portion of the list
            to be returned with the next list operation. The operation returns
            a marker value within the response body if the list returned was
            not complete. The marker value may then be used in a subsequent
            call to request the next set of list items. The marker value is
            opaque to the client.
        maxresults:
            Optional. Specifies the maximum number of blobs to return,
            including all BlobPrefix elements. If the request does not specify
            maxresults or specifies a value greater than 5,000, the server will
            return up to 5,000 items. Setting maxresults to a value less than
            or equal to zero results in error response code 400 (Bad Request).
        include:
            Optional. Specifies one or more datasets to include in the
            response. To specify more than one of these options on the URI,
            you must separate each option with a comma. Valid values are:
                snapshots:
                    Specifies that snapshots should be included in the
                    enumeration. Snapshots are listed from oldest to newest in
                    the response.
                metadata:
                    Specifies that blob metadata be returned in the response.
                uncommittedblobs:
                    Specifies that blobs for which blocks have been uploaded,
                    but which have not been committed using Put Block List
                    (REST API), be included in the response.
                copy:
                    Version 2012-02-12 and newer. Specifies that metadata
                    related to any current or previous Copy Blob operation
                    should be included in the response.
        delimiter:
            Optional. When the request includes this parameter, the operation
            returns a BlobPrefix element in the response body that acts as a
            placeholder for all blobs whose names begin with the same
            substring up to the appearance of the delimiter character. The
            delimiter may be a single character or a string.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '?restype=container&comp=list'
        request.query = [
            ('prefix', _str_or_none(prefix)),
            ('delimiter', _str_or_none(delimiter)),
            ('marker', _str_or_none(marker)),
            ('maxresults', _int_or_none(maxresults)),
            ('include', _str_or_none(include))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        return _parse_blob_enum_results_list(response)

    def set_blob_service_properties(self, storage_service_properties,
                                    timeout=None):
        '''
        Sets the properties of a storage account's Blob service, including
        Windows Azure Storage Analytics. You can also use this operation to
        set the default request version for all incoming requests that do not
        have a version specified.

        storage_service_properties: a StorageServiceProperties object.
        timeout: Optional. The timeout parameter is expressed in seconds.
        '''
        _validate_not_none('storage_service_properties',
                           storage_service_properties)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/?restype=service&comp=properties'
        request.query = [('timeout', _int_or_none(timeout))]
        # Properties object is serialized to the XML request body.
        request.body = _get_request_body(
            _convert_class_to_xml(storage_service_properties))
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def get_blob_service_properties(self, timeout=None):
        '''
        Gets the properties of a storage account's Blob service, including
        Windows Azure Storage Analytics.

        timeout: Optional. The timeout parameter is expressed in seconds.
''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/?restype=service&comp=properties' request.query = [('timeout', _int_or_none(timeout))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response(response, StorageServiceProperties) def get_blob_properties(self, container_name, blob_name, x_ms_lease_id=None): ''' Returns all user-defined metadata, standard HTTP properties, and system properties for the blob. container_name: Name of existing container. blob_name: Name of existing blob. x_ms_lease_id: Required if the blob has an active lease. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'HEAD' request.host = self._get_host() request.path = '/' + _str(container_name) + '/' + _str(blob_name) + '' request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict(response) def set_blob_properties(self, container_name, blob_name, x_ms_blob_cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_md5=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_lease_id=None): ''' Sets system properties on the blob. container_name: Name of existing container. blob_name: Name of existing blob. x_ms_blob_cache_control: Optional. Modifies the cache control string for the blob. x_ms_blob_content_type: Optional. Sets the blob's content type. x_ms_blob_content_md5: Optional. Sets the blob's MD5 hash. x_ms_blob_content_encoding: Optional. 
        Sets the blob's content encoding.
        x_ms_blob_content_language:
            Optional. Sets the blob's content language.
        x_ms_lease_id: Required if the blob has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=properties'
        request.headers = [
            ('x-ms-blob-cache-control', _str_or_none(x_ms_blob_cache_control)),
            ('x-ms-blob-content-type', _str_or_none(x_ms_blob_content_type)),
            ('x-ms-blob-content-md5', _str_or_none(x_ms_blob_content_md5)),
            ('x-ms-blob-content-encoding',
             _str_or_none(x_ms_blob_content_encoding)),
            ('x-ms-blob-content-language',
             _str_or_none(x_ms_blob_content_language)),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def put_blob(self, container_name, blob_name, blob, x_ms_blob_type,
                 content_encoding=None, content_language=None,
                 content_md5=None, cache_control=None,
                 x_ms_blob_content_type=None,
                 x_ms_blob_content_encoding=None,
                 x_ms_blob_content_language=None,
                 x_ms_blob_content_md5=None,
                 x_ms_blob_cache_control=None,
                 x_ms_meta_name_values=None,
                 x_ms_lease_id=None,
                 x_ms_blob_content_length=None,
                 x_ms_blob_sequence_number=None):
        '''
        Creates a new block blob or page blob, or updates the content of an
        existing block blob.

        See put_block_blob_from_* and put_page_blob_from_* for high level
        functions that handle the creation and upload of large blobs with
        automatic chunking and progress notifications.

        container_name: Name of existing container.
        blob_name: Name of blob to create or update.
        blob:
            For BlockBlob: Content of blob as bytes (size < 64MB). For
            larger size, you must call put_block and put_block_list to set
            content of blob.
            For PageBlob: Use None and call put_page to set content of blob.
        x_ms_blob_type: Required. Could be BlockBlob or PageBlob.
        content_encoding:
            Optional. Specifies which content encodings have been applied to
            the blob. This value is returned to the client when the Get Blob
            (REST API) operation is performed on the blob resource. The
            client can use this value when returned to decode the blob
            content.
        content_language:
            Optional. Specifies the natural languages used by this resource.
        content_md5:
            Optional. An MD5 hash of the blob content. This hash is used to
            verify the integrity of the blob during transport. When this
            header is specified, the storage service checks the hash that
            has arrived with the one that was sent. If the two hashes do not
            match, the operation will fail with error code 400 (Bad Request).
        cache_control:
            Optional. The Blob service stores this value but does not use or
            modify it.
        x_ms_blob_content_type: Optional. Set the blob's content type.
        x_ms_blob_content_encoding: Optional. Set the blob's content encoding.
        x_ms_blob_content_language: Optional. Set the blob's content language.
        x_ms_blob_content_md5: Optional. Set the blob's MD5 hash.
        x_ms_blob_cache_control: Optional. Sets the blob's cache control.
        x_ms_meta_name_values: A dict containing name, value for metadata.
        x_ms_lease_id: Required if the blob has an active lease.
        x_ms_blob_content_length:
            Required for page blobs. This header specifies the maximum size
            for the page blob, up to 1 TB. The page blob size must be
            aligned to a 512-byte boundary.
        x_ms_blob_sequence_number:
            Optional. Set for page blobs only. A user-controlled value that
            you can use to track requests, between 0 and 2^63 - 1. The
            default value is 0.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('x_ms_blob_type', x_ms_blob_type)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + _str(container_name) + '/' + _str(blob_name) + ''
        request.headers = [
            ('x-ms-blob-type', _str_or_none(x_ms_blob_type)),
            ('Content-Encoding', _str_or_none(content_encoding)),
            ('Content-Language', _str_or_none(content_language)),
            ('Content-MD5', _str_or_none(content_md5)),
            ('Cache-Control', _str_or_none(cache_control)),
            ('x-ms-blob-content-type', _str_or_none(x_ms_blob_content_type)),
            ('x-ms-blob-content-encoding',
             _str_or_none(x_ms_blob_content_encoding)),
            ('x-ms-blob-content-language',
             _str_or_none(x_ms_blob_content_language)),
            ('x-ms-blob-content-md5', _str_or_none(x_ms_blob_content_md5)),
            ('x-ms-blob-cache-control', _str_or_none(x_ms_blob_cache_control)),
            # Metadata dict is expanded into x-ms-meta-* headers downstream.
            ('x-ms-meta-name-values', x_ms_meta_name_values),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
            ('x-ms-blob-content-length',
             _str_or_none(x_ms_blob_content_length)),
            ('x-ms-blob-sequence-number',
             _str_or_none(x_ms_blob_sequence_number))
        ]
        request.body = _get_request_body_bytes_only('blob', blob)
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def put_block_blob_from_path(self, container_name, blob_name, file_path,
                                 content_encoding=None,
                                 content_language=None,
                                 content_md5=None, cache_control=None,
                                 x_ms_blob_content_type=None,
                                 x_ms_blob_content_encoding=None,
                                 x_ms_blob_content_language=None,
                                 x_ms_blob_content_md5=None,
                                 x_ms_blob_cache_control=None,
                                 x_ms_meta_name_values=None,
                                 x_ms_lease_id=None, progress_callback=None):
        '''
        Creates a new block blob from a file path, or updates the content of
        an existing block blob, with automatic chunking and progress
        notifications.
        container_name: Name of existing container.
        blob_name: Name of blob to create or update.
        file_path: Path of the file to upload as the blob content.
        content_encoding:
            Optional. Specifies which content encodings have been applied to
            the blob. This value is returned to the client when the Get Blob
            (REST API) operation is performed on the blob resource. The
            client can use this value when returned to decode the blob
            content.
        content_language:
            Optional. Specifies the natural languages used by this resource.
        content_md5:
            Optional. An MD5 hash of the blob content. This hash is used to
            verify the integrity of the blob during transport. When this
            header is specified, the storage service checks the hash that
            has arrived with the one that was sent. If the two hashes do not
            match, the operation will fail with error code 400 (Bad Request).
        cache_control:
            Optional. The Blob service stores this value but does not use or
            modify it.
        x_ms_blob_content_type: Optional. Set the blob's content type.
        x_ms_blob_content_encoding: Optional. Set the blob's content encoding.
        x_ms_blob_content_language: Optional. Set the blob's content language.
        x_ms_blob_content_md5: Optional. Set the blob's MD5 hash.
        x_ms_blob_cache_control: Optional. Sets the blob's cache control.
        x_ms_meta_name_values: A dict containing name, value for metadata.
        x_ms_lease_id: Required if the blob has an active lease.
        progress_callback:
            Callback for progress with signature function(current, total)
            where current is the number of bytes transfered so far, and
            total is the size of the blob, or None if the total size is
            unknown.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('file_path', file_path)

        # Size is known up front from the filesystem, which lets the
        # file-based helper pick the single-request fast path when possible.
        count = path.getsize(file_path)
        with open(file_path, 'rb') as stream:
            self.put_block_blob_from_file(container_name,
                                          blob_name,
                                          stream,
                                          count,
                                          content_encoding,
                                          content_language,
                                          content_md5,
                                          cache_control,
                                          x_ms_blob_content_type,
                                          x_ms_blob_content_encoding,
                                          x_ms_blob_content_language,
                                          x_ms_blob_content_md5,
                                          x_ms_blob_cache_control,
                                          x_ms_meta_name_values,
                                          x_ms_lease_id,
                                          progress_callback)

    def put_block_blob_from_file(self, container_name, blob_name, stream,
                                 count=None, content_encoding=None,
                                 content_language=None, content_md5=None,
                                 cache_control=None,
                                 x_ms_blob_content_type=None,
                                 x_ms_blob_content_encoding=None,
                                 x_ms_blob_content_language=None,
                                 x_ms_blob_content_md5=None,
                                 x_ms_blob_cache_control=None,
                                 x_ms_meta_name_values=None,
                                 x_ms_lease_id=None, progress_callback=None):
        '''
        Creates a new block blob from a file/stream, or updates the content
        of an existing block blob, with automatic chunking and progress
        notifications.

        container_name: Name of existing container.
        blob_name: Name of blob to create or update.
        stream: Opened file/stream to upload as the blob content.
        count:
            Number of bytes to read from the stream. This is optional, but
            should be supplied for optimal performance.
        content_encoding:
            Optional. Specifies which content encodings have been applied to
            the blob. This value is returned to the client when the Get Blob
            (REST API) operation is performed on the blob resource. The
            client can use this value when returned to decode the blob
            content.
        content_language:
            Optional. Specifies the natural languages used by this resource.
        content_md5:
            Optional. An MD5 hash of the blob content. This hash is used to
            verify the integrity of the blob during transport. When this
            header is specified, the storage service checks the hash that
            has arrived with the one that was sent. If the two hashes do not
            match, the operation will fail with error code 400 (Bad Request).
        cache_control:
            Optional. The Blob service stores this value but does not use or
            modify it.
        x_ms_blob_content_type: Optional. Set the blob's content type.
        x_ms_blob_content_encoding: Optional. Set the blob's content encoding.
        x_ms_blob_content_language: Optional. Set the blob's content language.
        x_ms_blob_content_md5: Optional. Set the blob's MD5 hash.
        x_ms_blob_cache_control: Optional. Sets the blob's cache control.
        x_ms_meta_name_values: A dict containing name, value for metadata.
        x_ms_lease_id: Required if the blob has an active lease.
        progress_callback:
            Callback for progress with signature function(current, total)
            where current is the number of bytes transfered so far, and
            total is the size of the blob, or None if the total size is
            unknown.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('stream', stream)

        if count and count < self._BLOB_MAX_DATA_SIZE:
            # Small enough for a single Put Blob request.
            if progress_callback:
                progress_callback(0, count)

            data = stream.read(count)
            self.put_blob(container_name,
                          blob_name,
                          data,
                          'BlockBlob',
                          content_encoding,
                          content_language,
                          content_md5,
                          cache_control,
                          x_ms_blob_content_type,
                          x_ms_blob_content_encoding,
                          x_ms_blob_content_language,
                          x_ms_blob_content_md5,
                          x_ms_blob_cache_control,
                          x_ms_meta_name_values,
                          x_ms_lease_id)

            if progress_callback:
                progress_callback(count, count)
        else:
            # Unknown or large size: create an empty block blob, upload
            # the content as numbered blocks, then commit the block list.
            if progress_callback:
                progress_callback(0, count)

            self.put_blob(container_name,
                          blob_name,
                          None,
                          'BlockBlob',
                          content_encoding,
                          content_language,
                          content_md5,
                          cache_control,
                          x_ms_blob_content_type,
                          x_ms_blob_content_encoding,
                          x_ms_blob_content_language,
                          x_ms_blob_content_md5,
                          x_ms_blob_cache_control,
                          x_ms_meta_name_values,
                          x_ms_lease_id)

            remain_bytes = count
            block_ids = []
            block_index = 0
            index = 0
            while True:
                # When the total size is unknown, read full-size chunks
                # until the stream is exhausted.
                request_count = self._BLOB_MAX_CHUNK_DATA_SIZE\
                    if remain_bytes is None else min(
                        remain_bytes,
                        self._BLOB_MAX_CHUNK_DATA_SIZE)
                data = stream.read(request_count)
                if data:
                    length = len(data)
                    index += length
                    remain_bytes = remain_bytes - \
                        length if remain_bytes else None
                    # Zero-padded ids keep the committed blocks ordered.
                    block_id = '{0:08d}'.format(block_index)
                    self.put_block(container_name, blob_name,
                                   data, block_id,
                                   x_ms_lease_id=x_ms_lease_id)
                    block_ids.append(block_id)
                    block_index += 1
                    if progress_callback:
                        progress_callback(index, count)
                else:
                    break

            self.put_block_list(container_name, blob_name, block_ids,
                                content_md5,
                                x_ms_blob_cache_control,
                                x_ms_blob_content_type,
                                x_ms_blob_content_encoding,
                                x_ms_blob_content_language,
                                x_ms_blob_content_md5,
                                x_ms_meta_name_values,
                                x_ms_lease_id)

    def put_block_blob_from_bytes(self, container_name, blob_name, blob,
                                  index=0, count=None, content_encoding=None,
                                  content_language=None, content_md5=None,
                                  cache_control=None,
                                  x_ms_blob_content_type=None,
                                  x_ms_blob_content_encoding=None,
                                  x_ms_blob_content_language=None,
                                  x_ms_blob_content_md5=None,
                                  x_ms_blob_cache_control=None,
                                  x_ms_meta_name_values=None,
                                  x_ms_lease_id=None, progress_callback=None):
        '''
        Creates a new block blob from an array of bytes, or updates the
        content of an existing block blob, with automatic chunking and
        progress notifications.

        container_name: Name of existing container.
        blob_name: Name of blob to create or update.
        blob: Content of blob as an array of bytes.
        index: Start index in the array of bytes.
        count:
            Number of bytes to upload. Set to None or negative value to
            upload all bytes starting from index.
        content_encoding:
            Optional. Specifies which content encodings have been applied to
            the blob. This value is returned to the client when the Get Blob
            (REST API) operation is performed on the blob resource. The
            client can use this value when returned to decode the blob
            content.
        content_language:
            Optional. Specifies the natural languages used by this resource.
        content_md5:
            Optional. An MD5 hash of the blob content. This hash is used to
            verify the integrity of the blob during transport.
        When this header is specified, the storage service checks the hash
        that has arrived with the one that was sent. If the two hashes do
        not match, the operation will fail with error code 400
        (Bad Request).
        cache_control:
            Optional. The Blob service stores this value but does not use or
            modify it.
        x_ms_blob_content_type: Optional. Set the blob's content type.
        x_ms_blob_content_encoding: Optional. Set the blob's content encoding.
        x_ms_blob_content_language: Optional. Set the blob's content language.
        x_ms_blob_content_md5: Optional. Set the blob's MD5 hash.
        x_ms_blob_cache_control: Optional. Sets the blob's cache control.
        x_ms_meta_name_values: A dict containing name, value for metadata.
        x_ms_lease_id: Required if the blob has an active lease.
        progress_callback:
            Callback for progress with signature function(current, total)
            where current is the number of bytes transfered so far, and
            total is the size of the blob, or None if the total size is
            unknown.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('blob', blob)
        _validate_not_none('index', index)
        _validate_type_bytes('blob', blob)

        if index < 0:
            raise TypeError(_ERROR_VALUE_NEGATIVE.format('index'))

        # None or negative count means "everything from index to the end".
        if count is None or count < 0:
            count = len(blob) - index

        if count < self._BLOB_MAX_DATA_SIZE:
            # Small enough for a single Put Blob request.
            if progress_callback:
                progress_callback(0, count)

            data = blob[index: index + count]
            self.put_blob(container_name,
                          blob_name,
                          data,
                          'BlockBlob',
                          content_encoding,
                          content_language,
                          content_md5,
                          cache_control,
                          x_ms_blob_content_type,
                          x_ms_blob_content_encoding,
                          x_ms_blob_content_language,
                          x_ms_blob_content_md5,
                          x_ms_blob_cache_control,
                          x_ms_meta_name_values,
                          x_ms_lease_id)

            if progress_callback:
                progress_callback(count, count)
        else:
            # Delegate chunked upload to the stream-based helper.
            stream = BytesIO(blob)
            stream.seek(index)

            self.put_block_blob_from_file(container_name,
                                          blob_name,
                                          stream,
                                          count,
                                          content_encoding,
                                          content_language,
                                          content_md5,
                                          cache_control,
                                          x_ms_blob_content_type,
                                          x_ms_blob_content_encoding,
                                          x_ms_blob_content_language,
                                          x_ms_blob_content_md5,
                                          x_ms_blob_cache_control,
                                          x_ms_meta_name_values,
                                          x_ms_lease_id,
                                          progress_callback)

    def put_block_blob_from_text(self, container_name, blob_name, text,
                                 text_encoding='utf-8',
                                 content_encoding=None,
                                 content_language=None,
                                 content_md5=None, cache_control=None,
                                 x_ms_blob_content_type=None,
                                 x_ms_blob_content_encoding=None,
                                 x_ms_blob_content_language=None,
                                 x_ms_blob_content_md5=None,
                                 x_ms_blob_cache_control=None,
                                 x_ms_meta_name_values=None,
                                 x_ms_lease_id=None, progress_callback=None):
        '''
        Creates a new block blob from str/unicode, or updates the content of
        an existing block blob, with automatic chunking and progress
        notifications.

        container_name: Name of existing container.
        blob_name: Name of blob to create or update.
        text: Text to upload to the blob.
        text_encoding: Encoding to use to convert the text to bytes.
        content_encoding:
            Optional. Specifies which content encodings have been applied to
            the blob. This value is returned to the client when the Get Blob
            (REST API) operation is performed on the blob resource. The
            client can use this value when returned to decode the blob
            content.
        content_language:
            Optional. Specifies the natural languages used by this resource.
        content_md5:
            Optional. An MD5 hash of the blob content. This hash is used to
            verify the integrity of the blob during transport. When this
            header is specified, the storage service checks the hash that
            has arrived with the one that was sent. If the two hashes do not
            match, the operation will fail with error code 400 (Bad Request).
        cache_control:
            Optional. The Blob service stores this value but does not use or
            modify it.
        x_ms_blob_content_type: Optional. Set the blob's content type.
        x_ms_blob_content_encoding: Optional. Set the blob's content encoding.
        x_ms_blob_content_language: Optional. Set the blob's content language.
        x_ms_blob_content_md5: Optional. Set the blob's MD5 hash.
        x_ms_blob_cache_control: Optional. Sets the blob's cache control.
        x_ms_meta_name_values: A dict containing name, value for metadata.
        x_ms_lease_id: Required if the blob has an active lease.
        progress_callback:
            Callback for progress with signature function(current, total)
            where current is the number of bytes transfered so far, and
            total is the size of the blob, or None if the total size is
            unknown.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('text', text)

        # Encode text to bytes unless the caller already passed bytes.
        if not isinstance(text, bytes):
            _validate_not_none('text_encoding', text_encoding)
            text = text.encode(text_encoding)

        self.put_block_blob_from_bytes(container_name,
                                       blob_name,
                                       text,
                                       0,
                                       len(text),
                                       content_encoding,
                                       content_language,
                                       content_md5,
                                       cache_control,
                                       x_ms_blob_content_type,
                                       x_ms_blob_content_encoding,
                                       x_ms_blob_content_language,
                                       x_ms_blob_content_md5,
                                       x_ms_blob_cache_control,
                                       x_ms_meta_name_values,
                                       x_ms_lease_id,
                                       progress_callback)

    def put_page_blob_from_path(self, container_name, blob_name, file_path,
                                content_encoding=None, content_language=None,
                                content_md5=None, cache_control=None,
                                x_ms_blob_content_type=None,
                                x_ms_blob_content_encoding=None,
                                x_ms_blob_content_language=None,
                                x_ms_blob_content_md5=None,
                                x_ms_blob_cache_control=None,
                                x_ms_meta_name_values=None,
                                x_ms_lease_id=None,
                                x_ms_blob_sequence_number=None,
                                progress_callback=None):
        '''
        Creates a new page blob from a file path, or updates the content of
        an existing page blob, with automatic chunking and progress
        notifications.

        container_name: Name of existing container.
        blob_name: Name of blob to create or update.
        file_path: Path of the file to upload as the blob content.
        content_encoding:
            Optional. Specifies which content encodings have been applied to
            the blob. This value is returned to the client when the Get Blob
            (REST API) operation is performed on the blob resource. The
            client can use this value when returned to decode the blob
            content.
        content_language:
            Optional. Specifies the natural languages used by this resource.
        content_md5:
            Optional. An MD5 hash of the blob content. This hash is used to
            verify the integrity of the blob during transport. When this
            header is specified, the storage service checks the hash that
            has arrived with the one that was sent. If the two hashes do not
            match, the operation will fail with error code 400 (Bad Request).
        cache_control:
            Optional. The Blob service stores this value but does not use or
            modify it.
        x_ms_blob_content_type: Optional. Set the blob's content type.
        x_ms_blob_content_encoding: Optional. Set the blob's content encoding.
        x_ms_blob_content_language: Optional. Set the blob's content language.
        x_ms_blob_content_md5: Optional. Set the blob's MD5 hash.
        x_ms_blob_cache_control: Optional. Sets the blob's cache control.
        x_ms_meta_name_values: A dict containing name, value for metadata.
        x_ms_lease_id: Required if the blob has an active lease.
        x_ms_blob_sequence_number:
            Optional. Set for page blobs only. A user-controlled value that
            you can use to track requests, between 0 and 2^63 - 1. The
            default value is 0.
        progress_callback:
            Callback for progress with signature function(current, total)
            where current is the number of bytes transfered so far, and
            total is the size of the blob, or None if the total size is
            unknown.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('file_path', file_path)

        # Page blobs require the total size up front.
        count = path.getsize(file_path)
        with open(file_path, 'rb') as stream:
            self.put_page_blob_from_file(container_name,
                                         blob_name,
                                         stream,
                                         count,
                                         content_encoding,
                                         content_language,
                                         content_md5,
                                         cache_control,
                                         x_ms_blob_content_type,
                                         x_ms_blob_content_encoding,
                                         x_ms_blob_content_language,
                                         x_ms_blob_content_md5,
                                         x_ms_blob_cache_control,
                                         x_ms_meta_name_values,
                                         x_ms_lease_id,
                                         x_ms_blob_sequence_number,
                                         progress_callback)

    def put_page_blob_from_file(self, container_name, blob_name, stream,
                                count, content_encoding=None,
                                content_language=None, content_md5=None,
                                cache_control=None,
                                x_ms_blob_content_type=None,
                                x_ms_blob_content_encoding=None,
                                x_ms_blob_content_language=None,
                                x_ms_blob_content_md5=None,
                                x_ms_blob_cache_control=None,
                                x_ms_meta_name_values=None,
                                x_ms_lease_id=None,
                                x_ms_blob_sequence_number=None,
                                progress_callback=None):
        '''
        Creates a new page blob from a file/stream, or updates the content
        of an existing page blob, with automatic chunking and progress
        notifications.

        container_name: Name of existing container.
        blob_name: Name of blob to create or update.
        stream: Opened file/stream to upload as the blob content.
        count:
            Number of bytes to read from the stream. This is required, a
            page blob cannot be created if the count is unknown.
        content_encoding:
            Optional. Specifies which content encodings have been applied to
            the blob. This value is returned to the client when the Get Blob
            (REST API) operation is performed on the blob resource. The
            client can use this value when returned to decode the blob
            content.
        content_language:
            Optional. Specifies the natural languages used by this resource.
        content_md5:
            Optional. An MD5 hash of the blob content. This hash is used to
            verify the integrity of the blob during transport. When this
            header is specified, the storage service checks the hash that
            has arrived with the one that was sent. If the two hashes do not
            match, the operation will fail with error code 400 (Bad Request).
        cache_control:
            Optional. The Blob service stores this value but does not use or
            modify it.
        x_ms_blob_content_type: Optional. Set the blob's content type.
        x_ms_blob_content_encoding: Optional. Set the blob's content encoding.
        x_ms_blob_content_language: Optional. Set the blob's content language.
        x_ms_blob_content_md5: Optional. Set the blob's MD5 hash.
        x_ms_blob_cache_control: Optional. Sets the blob's cache control.
        x_ms_meta_name_values: A dict containing name, value for metadata.
        x_ms_lease_id: Required if the blob has an active lease.
        x_ms_blob_sequence_number:
            Optional. Set for page blobs only. A user-controlled value that
            you can use to track requests, between 0 and 2^63 - 1. The
            default value is 0.
        progress_callback:
            Callback for progress with signature function(current, total)
            where current is the number of bytes transfered so far, and
            total is the size of the blob, or None if the total size is
            unknown.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('stream', stream)
        _validate_not_none('count', count)

        if count < 0:
            raise TypeError(_ERROR_VALUE_NEGATIVE.format('count'))

        # Page blob size must be aligned to a 512-byte boundary.
        if count % _PAGE_SIZE != 0:
            raise TypeError(_ERROR_PAGE_BLOB_SIZE_ALIGNMENT.format(count))

        if progress_callback:
            progress_callback(0, count)

        # Create the (empty) page blob at its final size, then fill the
        # pages with put_page 'update' calls.
        self.put_blob(container_name,
                      blob_name,
                      b'',
                      'PageBlob',
                      content_encoding,
                      content_language,
                      content_md5,
                      cache_control,
                      x_ms_blob_content_type,
                      x_ms_blob_content_encoding,
                      x_ms_blob_content_language,
                      x_ms_blob_content_md5,
                      x_ms_blob_cache_control,
                      x_ms_meta_name_values,
                      x_ms_lease_id,
                      count,
                      x_ms_blob_sequence_number)

        remain_bytes = count
        page_start = 0
        while True:
            request_count = min(remain_bytes, self._BLOB_MAX_CHUNK_DATA_SIZE)
            data = stream.read(request_count)
            if data:
                length = len(data)
                remain_bytes = remain_bytes - length
                page_end = page_start + length - 1
                self.put_page(container_name,
                              blob_name,
                              data,
                              'bytes={0}-{1}'.format(page_start, page_end),
                              'update',
                              x_ms_lease_id=x_ms_lease_id)
                page_start = page_start + length
                if progress_callback:
                    progress_callback(page_start, count)
            else:
                break

    def put_page_blob_from_bytes(self, container_name, blob_name, blob,
                                 index=0, count=None, content_encoding=None,
                                 content_language=None, content_md5=None,
                                 cache_control=None,
                                 x_ms_blob_content_type=None,
                                 x_ms_blob_content_encoding=None,
                                 x_ms_blob_content_language=None,
                                 x_ms_blob_content_md5=None,
                                 x_ms_blob_cache_control=None,
                                 x_ms_meta_name_values=None,
                                 x_ms_lease_id=None,
                                 x_ms_blob_sequence_number=None,
                                 progress_callback=None):
        '''
        Creates a new page blob from an array of bytes, or updates the
        content of an existing page blob, with automatic chunking and
        progress notifications.

        container_name: Name of existing container.
        blob_name: Name of blob to create or update.
        blob: Content of blob as an array of bytes.
        index: Start index in the array of bytes.
        count:
            Number of bytes to upload.
Set to None or negative value to upload all bytes starting from index. content_encoding: Optional. Specifies which content encodings have been applied to the blob. This value is returned to the client when the Get Blob (REST API) operation is performed on the blob resource. The client can use this value when returned to decode the blob content. content_language: Optional. Specifies the natural languages used by this resource. content_md5: Optional. An MD5 hash of the blob content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. If the two hashes do not match, the operation will fail with error code 400 (Bad Request). cache_control: Optional. The Blob service stores this value but does not use or modify it. x_ms_blob_content_type: Optional. Set the blob's content type. x_ms_blob_content_encoding: Optional. Set the blob's content encoding. x_ms_blob_content_language: Optional. Set the blob's content language. x_ms_blob_content_md5: Optional. Set the blob's MD5 hash. x_ms_blob_cache_control: Optional. Sets the blob's cache control. x_ms_meta_name_values: A dict containing name, value for metadata. x_ms_lease_id: Required if the blob has an active lease. x_ms_blob_sequence_number: Optional. Set for page blobs only. The sequence number is a user-controlled value that you can use to track requests. The value of the sequence number must be between 0 and 2^63 - 1. The default value is 0. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob, or None if the total size is unknown. 
        '''
        # Tail of put_page_blob_from_bytes: wrap the byte string in a stream
        # and delegate to the file-based uploader, which performs the
        # actual (chunked) upload.
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('blob', blob)
        _validate_type_bytes('blob', blob)
        if index < 0:
            raise TypeError(_ERROR_VALUE_NEGATIVE.format('index'))
        if count is None or count < 0:
            # A missing or negative count means "everything from index on".
            count = len(blob) - index
        stream = BytesIO(blob)
        stream.seek(index)
        self.put_page_blob_from_file(container_name, blob_name, stream, count,
                                     content_encoding, content_language,
                                     content_md5, cache_control,
                                     x_ms_blob_content_type,
                                     x_ms_blob_content_encoding,
                                     x_ms_blob_content_language,
                                     x_ms_blob_content_md5,
                                     x_ms_blob_cache_control,
                                     x_ms_meta_name_values, x_ms_lease_id,
                                     x_ms_blob_sequence_number,
                                     progress_callback)

    def get_blob(self, container_name, blob_name, snapshot=None,
                 x_ms_range=None, x_ms_lease_id=None,
                 x_ms_range_get_content_md5=None):
        '''
        Reads or downloads a blob from the system, including its metadata
        and properties, via a single Get Blob REST call, and returns the
        blob content. See get_blob_to_* for high level functions that handle
        the download of large blobs with automatic chunking and progress
        notifications.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        snapshot:
            Optional. The snapshot parameter is an opaque DateTime value
            that, when present, specifies the blob snapshot to retrieve.
        x_ms_range:
            Optional. Return only the bytes of the blob in the specified
            range.
        x_ms_lease_id: Required if the blob has an active lease.
        x_ms_range_get_content_md5:
            Optional. When this header is set to true and specified together
            with the Range header, the service returns the MD5 hash for the
            range, as long as the range is less than or equal to 4 MB in
            size.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(container_name) + '/' + _str(blob_name) + ''
        request.headers = [
            ('x-ms-range', _str_or_none(x_ms_range)),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
            ('x-ms-range-get-content-md5',
             _str_or_none(x_ms_range_get_content_md5))
        ]
        request.query = [('snapshot', _str_or_none(snapshot))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        # NOTE(review): the explicit second argument (None) is passed only on
        # this call in the visible code -- presumably it selects raw/binary
        # response handling in _perform_request; confirm against that helper.
        response = self._perform_request(request, None)
        return _create_blob_result(response)

    def get_blob_to_path(self, container_name, blob_name, file_path,
                         open_mode='wb', snapshot=None, x_ms_lease_id=None,
                         progress_callback=None):
        '''
        Downloads a blob to a file path, with automatic chunking and
        progress notifications.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        file_path: Path of file to write to.
        open_mode: Mode to use when opening the file.
        snapshot:
            Optional. The snapshot parameter is an opaque DateTime value
            that, when present, specifies the blob snapshot to retrieve.
        x_ms_lease_id: Required if the blob has an active lease.
        progress_callback:
            Callback for progress with signature function(current, total)
            where current is the number of bytes transfered so far, and
            total is the size of the blob.
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('file_path', file_path) _validate_not_none('open_mode', open_mode) with open(file_path, open_mode) as stream: self.get_blob_to_file(container_name, blob_name, stream, snapshot, x_ms_lease_id, progress_callback) def get_blob_to_file(self, container_name, blob_name, stream, snapshot=None, x_ms_lease_id=None, progress_callback=None): ''' Downloads a blob to a file/stream, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of existing blob. stream: Opened file/stream to write to. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('stream', stream) props = self.get_blob_properties(container_name, blob_name) blob_size = int(props['content-length']) if blob_size < self._BLOB_MAX_DATA_SIZE: if progress_callback: progress_callback(0, blob_size) data = self.get_blob(container_name, blob_name, snapshot, x_ms_lease_id=x_ms_lease_id) stream.write(data) if progress_callback: progress_callback(blob_size, blob_size) else: if progress_callback: progress_callback(0, blob_size) index = 0 while index < blob_size: chunk_range = 'bytes={0}-{1}'.format( index, index + self._BLOB_MAX_CHUNK_DATA_SIZE - 1) data = self.get_blob( container_name, blob_name, x_ms_range=chunk_range) length = len(data) index += length if length > 0: stream.write(data) if progress_callback: progress_callback(index, blob_size) if length < self._BLOB_MAX_CHUNK_DATA_SIZE: break else: break def get_blob_to_bytes(self, container_name, blob_name, snapshot=None, x_ms_lease_id=None, progress_callback=None): ''' Downloads a blob as an array of bytes, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of existing blob. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob. 
        '''
        # Tail of get_blob_to_bytes: buffer the chunked download in memory.
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        stream = BytesIO()
        self.get_blob_to_file(container_name, blob_name, stream,
                              snapshot, x_ms_lease_id, progress_callback)
        return stream.getvalue()

    def get_blob_to_text(self, container_name, blob_name,
                         text_encoding='utf-8', snapshot=None,
                         x_ms_lease_id=None, progress_callback=None):
        '''
        Downloads a blob as unicode text, with automatic chunking and
        progress notifications.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        text_encoding: Encoding to use when decoding the blob data.
        snapshot:
            Optional. The snapshot parameter is an opaque DateTime value
            that, when present, specifies the blob snapshot to retrieve.
        x_ms_lease_id: Required if the blob has an active lease.
        progress_callback:
            Callback for progress with signature function(current, total)
            where current is the number of bytes transfered so far, and
            total is the size of the blob.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('text_encoding', text_encoding)
        # Download the raw bytes, then decode with the caller's encoding.
        result = self.get_blob_to_bytes(container_name, blob_name,
                                        snapshot, x_ms_lease_id,
                                        progress_callback)
        return result.decode(text_encoding)

    def get_blob_metadata(self, container_name, blob_name, snapshot=None,
                          x_ms_lease_id=None):
        '''
        Returns all user-defined metadata for the specified blob or
        snapshot.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        snapshot:
            Optional. The snapshot parameter is an opaque DateTime value
            that, when present, specifies the blob snapshot to retrieve.
        x_ms_lease_id: Required if the blob has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=metadata'
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.query = [('snapshot', _str_or_none(snapshot))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)
        # Only the x-ms-meta-* response headers carry user metadata.
        return _parse_response_for_dict_prefix(response, prefixes=['x-ms-meta'])

    def set_blob_metadata(self, container_name, blob_name,
                          x_ms_meta_name_values=None, x_ms_lease_id=None):
        '''
        Sets user-defined metadata for the specified blob as one or more
        name-value pairs.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        x_ms_meta_name_values: Dict containing name and value pairs.
        x_ms_lease_id: Required if the blob has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=metadata'
        request.headers = [
            ('x-ms-meta-name-values', x_ms_meta_name_values),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def lease_blob(self, container_name, blob_name, x_ms_lease_action,
                   x_ms_lease_id=None, x_ms_lease_duration=60,
                   x_ms_lease_break_period=None, x_ms_proposed_lease_id=None):
        '''
        Establishes and manages a one-minute lock on a blob for write
        operations.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        x_ms_lease_action:
            Required. Possible values: acquire|renew|release|break|change
        x_ms_lease_id: Required if the blob has an active lease.
        x_ms_lease_duration:
            Specifies the duration of the lease, in seconds, or negative one
            (-1) for a lease that never expires. A non-infinite lease can be
            between 15 and 60 seconds. A lease duration cannot be changed
            using renew or change. For backwards compatibility, the default
            is 60, and the value is only used on an acquire operation.
        x_ms_lease_break_period:
            Optional. For a break operation, this is the proposed duration
            of seconds that the lease should continue before it is broken,
            between 0 and 60 seconds. This break period is only used if it
            is shorter than the time remaining on the lease. If longer, the
            time remaining on the lease is used. A new lease will not be
            available before the break period has expired, but the lease may
            be held for longer than the break period. If this header does
            not appear with a break operation, a fixed-duration lease breaks
            after the remaining lease period elapses, and an infinite lease
            breaks immediately.
        x_ms_proposed_lease_id:
            Optional for acquire, required for change. Proposed lease ID, in
            a GUID string format.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('x_ms_lease_action', x_ms_lease_action)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=lease'
        request.headers = [
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
            ('x-ms-lease-action', _str_or_none(x_ms_lease_action)),
            # The duration header is only meaningful on acquire, so it is
            # suppressed for every other lease action.
            ('x-ms-lease-duration', _str_or_none(x_ms_lease_duration
                if x_ms_lease_action == 'acquire' else None)),
            ('x-ms-lease-break-period', _str_or_none(x_ms_lease_break_period)),
            ('x-ms-proposed-lease-id', _str_or_none(x_ms_proposed_lease_id)),
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)
        # Surface only the lease-related response headers to the caller.
        return _parse_response_for_dict_filter(
            response, filter=['x-ms-lease-id', 'x-ms-lease-time'])

    def snapshot_blob(self, container_name, blob_name,
                      x_ms_meta_name_values=None, if_modified_since=None,
                      if_unmodified_since=None, if_match=None,
                      if_none_match=None, x_ms_lease_id=None):
        '''
        Creates a read-only snapshot of a blob.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        x_ms_meta_name_values:
            Optional. Dict containing name and value pairs.
        if_modified_since: Optional. Datetime string.
        if_unmodified_since: DateTime string.
        if_match:
            Optional. Snapshot the blob only if its ETag value matches the
            value specified.
        if_none_match: Optional. An ETag value
        x_ms_lease_id: Required if the blob has an active lease.
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=snapshot' request.headers = [ ('x-ms-meta-name-values', x_ms_meta_name_values), ('If-Modified-Since', _str_or_none(if_modified_since)), ('If-Unmodified-Since', _str_or_none(if_unmodified_since)), ('If-Match', _str_or_none(if_match)), ('If-None-Match', _str_or_none(if_none_match)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict_filter( response, filter=['x-ms-snapshot', 'etag', 'last-modified']) def copy_blob(self, container_name, blob_name, x_ms_copy_source, x_ms_meta_name_values=None, x_ms_source_if_modified_since=None, x_ms_source_if_unmodified_since=None, x_ms_source_if_match=None, x_ms_source_if_none_match=None, if_modified_since=None, if_unmodified_since=None, if_match=None, if_none_match=None, x_ms_lease_id=None, x_ms_source_lease_id=None): ''' Copies a blob to a destination within the storage account. container_name: Name of existing container. blob_name: Name of existing blob. x_ms_copy_source: URL up to 2 KB in length that specifies a blob. A source blob in the same account can be private, but a blob in another account must be public or accept credentials included in this URL, such as a Shared Access Signature. Examples: https://myaccount.blob.core.windows.net/mycontainer/myblob https://myaccount.blob.core.windows.net/mycontainer/myblob?snapshot=<DateTime> x_ms_meta_name_values: Optional. Dict containing name and value pairs. x_ms_source_if_modified_since: Optional. An ETag value. 
Specify this conditional header to copy the source blob only if its ETag matches the value specified. x_ms_source_if_unmodified_since: Optional. An ETag value. Specify this conditional header to copy the blob only if its ETag does not match the value specified. x_ms_source_if_match: Optional. A DateTime value. Specify this conditional header to copy the blob only if the source blob has been modified since the specified date/time. x_ms_source_if_none_match: Optional. An ETag value. Specify this conditional header to copy the source blob only if its ETag matches the value specified. if_modified_since: Optional. Datetime string. if_unmodified_since: DateTime string. if_match: Optional. Snapshot the blob only if its ETag value matches the value specified. if_none_match: Optional. An ETag value x_ms_lease_id: Required if the blob has an active lease. x_ms_source_lease_id: Optional. Specify this to perform the Copy Blob operation only if the lease ID given matches the active lease ID of the source blob. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('x_ms_copy_source', x_ms_copy_source) if x_ms_copy_source.startswith('/'): # Backwards compatibility for earlier versions of the SDK where # the copy source can be in the following formats: # - Blob in named container: # /accountName/containerName/blobName # - Snapshot in named container: # /accountName/containerName/blobName?snapshot=<DateTime> # - Blob in root container: # /accountName/blobName # - Snapshot in root container: # /accountName/blobName?snapshot=<DateTime> account, _, source =\ x_ms_copy_source.partition('/')[2].partition('/') x_ms_copy_source = self.protocol + '://' + \ account + self.host_base + '/' + source request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(container_name) + '/' + _str(blob_name) + '' request.headers = [ ('x-ms-copy-source', _str_or_none(x_ms_copy_source)), ('x-ms-meta-name-values', x_ms_meta_name_values), ('x-ms-source-if-modified-since', _str_or_none(x_ms_source_if_modified_since)), ('x-ms-source-if-unmodified-since', _str_or_none(x_ms_source_if_unmodified_since)), ('x-ms-source-if-match', _str_or_none(x_ms_source_if_match)), ('x-ms-source-if-none-match', _str_or_none(x_ms_source_if_none_match)), ('If-Modified-Since', _str_or_none(if_modified_since)), ('If-Unmodified-Since', _str_or_none(if_unmodified_since)), ('If-Match', _str_or_none(if_match)), ('If-None-Match', _str_or_none(if_none_match)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-source-lease-id', _str_or_none(x_ms_source_lease_id)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict(response) def abort_copy_blob(self, container_name, blob_name, x_ms_copy_id, 
x_ms_lease_id=None): ''' Aborts a pending copy_blob operation, and leaves a destination blob with zero length and full metadata. container_name: Name of destination container. blob_name: Name of destination blob. x_ms_copy_id: Copy identifier provided in the x-ms-copy-id of the original copy_blob operation. x_ms_lease_id: Required if the destination blob has an active infinite lease. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('x_ms_copy_id', x_ms_copy_id) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(container_name) + '/' + \ _str(blob_name) + '?comp=copy©id=' + \ _str(x_ms_copy_id) request.headers = [ ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-copy-action', 'abort'), ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def delete_blob(self, container_name, blob_name, snapshot=None, x_ms_lease_id=None): ''' Marks the specified blob or snapshot for deletion. The blob is later deleted during garbage collection. To mark a specific snapshot for deletion provide the date/time of the snapshot via the snapshot parameter. container_name: Name of existing container. blob_name: Name of existing blob. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to delete. x_ms_lease_id: Required if the blob has an active lease. 
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(container_name) + '/' + _str(blob_name) + ''
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.query = [('snapshot', _str_or_none(snapshot))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def put_block(self, container_name, blob_name, block, blockid,
                  content_md5=None, x_ms_lease_id=None):
        '''
        Creates a new block to be committed as part of a blob.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        block: Content of the block.
        blockid:
            Required. A value that identifies the block. The string must be
            less than or equal to 64 bytes in size.
        content_md5:
            Optional. An MD5 hash of the block content. This hash is used to
            verify the integrity of the blob during transport. When this
            header is specified, the storage service checks the hash that
            has arrived with the one that was sent.
        x_ms_lease_id: Required if the blob has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('block', block)
        _validate_not_none('blockid', blockid)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=block'
        request.headers = [
            ('Content-MD5', _str_or_none(content_md5)),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id))
        ]
        # The service expects the block id base64-encoded in the query.
        request.query = [('blockid', _encode_base64(_str_or_none(blockid)))]
        request.body = _get_request_body_bytes_only('block', block)
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def put_block_list(self, container_name, blob_name, block_list,
                       content_md5=None, x_ms_blob_cache_control=None,
                       x_ms_blob_content_type=None,
                       x_ms_blob_content_encoding=None,
                       x_ms_blob_content_language=None,
                       x_ms_blob_content_md5=None, x_ms_meta_name_values=None,
                       x_ms_lease_id=None):
        '''
        Writes a blob by specifying the list of block IDs that make up the
        blob. In order to be written as part of a blob, a block must have
        been successfully written to the server in a prior Put Block
        (REST API) operation.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        block_list: A str list containing the block ids.
        content_md5:
            Optional. An MD5 hash of the block content. This hash is used to
            verify the integrity of the blob during transport. When this
            header is specified, the storage service checks the hash that
            has arrived with the one that was sent.
        x_ms_blob_cache_control:
            Optional. Sets the blob's cache control. If specified, this
            property is stored with the blob and returned with a read
            request.
        x_ms_blob_content_type:
            Optional. Sets the blob's content type. If specified, this
            property is stored with the blob and returned with a read
            request.
        x_ms_blob_content_encoding:
            Optional. Sets the blob's content encoding. If specified, this
            property is stored with the blob and returned with a read
            request.
        x_ms_blob_content_language:
            Optional. Set the blob's content language. If specified, this
            property is stored with the blob and returned with a read
            request.
        x_ms_blob_content_md5:
            Optional. An MD5 hash of the blob content. Note that this hash
            is not validated, as the hashes for the individual blocks were
            validated when each was uploaded.
        x_ms_meta_name_values:
            Optional. Dict containing name and value pairs.
        x_ms_lease_id: Required if the blob has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('block_list', block_list)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=blocklist'
        request.headers = [
            ('Content-MD5', _str_or_none(content_md5)),
            ('x-ms-blob-cache-control', _str_or_none(x_ms_blob_cache_control)),
            ('x-ms-blob-content-type', _str_or_none(x_ms_blob_content_type)),
            ('x-ms-blob-content-encoding',
             _str_or_none(x_ms_blob_content_encoding)),
            ('x-ms-blob-content-language',
             _str_or_none(x_ms_blob_content_language)),
            ('x-ms-blob-content-md5', _str_or_none(x_ms_blob_content_md5)),
            ('x-ms-meta-name-values', x_ms_meta_name_values),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id))
        ]
        # The committed block list is sent as an XML document in the body.
        request.body = _get_request_body(
            _convert_block_list_to_xml(block_list))
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def get_block_list(self, container_name, blob_name, snapshot=None,
                       blocklisttype=None, x_ms_lease_id=None):
        '''
        Retrieves the list of blocks that have been uploaded as part of a
        block blob.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        snapshot:
            Optional. Datetime to determine the time to retrieve the blocks.
        blocklisttype:
            Specifies whether to return the list of committed blocks, the
            list of uncommitted blocks, or both lists together. Valid
            values are: committed, uncommitted, or all.
        x_ms_lease_id: Required if the blob has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=blocklist'
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.query = [
            ('snapshot', _str_or_none(snapshot)),
            ('blocklisttype', _str_or_none(blocklisttype))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)
        return _convert_response_to_block_list(response)

    def put_page(self, container_name, blob_name, page, x_ms_range,
                 x_ms_page_write, timeout=None, content_md5=None,
                 x_ms_lease_id=None, x_ms_if_sequence_number_lte=None,
                 x_ms_if_sequence_number_lt=None,
                 x_ms_if_sequence_number_eq=None,
                 if_modified_since=None, if_unmodified_since=None,
                 if_match=None, if_none_match=None):
        '''
        Writes a range of pages to a page blob.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        page: Content of the page.
        x_ms_range:
            Required. Specifies the range of bytes to be written as a page.
            Both the start and end of the range must be specified. Must be
            in format: bytes=startByte-endByte. Given that pages must be
            aligned with 512-byte boundaries, the start offset must be a
            modulus of 512 and the end offset must be a modulus of 512-1.
            Examples of valid byte ranges are 0-511, 512-1023, etc.
        x_ms_page_write: Required.
            You may specify one of the following options:
                update (lower case):
                    Writes the bytes specified by the request body into the
                    specified range. The Range and Content-Length headers
                    must match to perform the update.
                clear (lower case):
                    Clears the specified range and releases the space used
                    in storage for that range. To clear a range, set the
                    Content-Length header to zero, and the Range header to a
                    value that indicates the range to clear, up to maximum
                    blob size.
        timeout: the timeout parameter is expressed in seconds.
        content_md5:
            Optional. An MD5 hash of the page content. This hash is used to
            verify the integrity of the page during transport. When this
            header is specified, the storage service compares the hash of
            the content that has arrived with the header value that was
            sent. If the two hashes do not match, the operation will fail
            with error code 400 (Bad Request).
        x_ms_lease_id: Required if the blob has an active lease.
        x_ms_if_sequence_number_lte:
            Optional. If the blob's sequence number is less than or equal to
            the specified value, the request proceeds; otherwise it fails.
        x_ms_if_sequence_number_lt:
            Optional. If the blob's sequence number is less than the
            specified value, the request proceeds; otherwise it fails.
        x_ms_if_sequence_number_eq:
            Optional. If the blob's sequence number is equal to the
            specified value, the request proceeds; otherwise it fails.
        if_modified_since:
            Optional. A DateTime value. Specify this conditional header to
            write the page only if the blob has been modified since the
            specified date/time. If the blob has not been modified, the
            Blob service fails.
        if_unmodified_since:
            Optional. A DateTime value. Specify this conditional header to
            write the page only if the blob has not been modified since the
            specified date/time. If the blob has been modified, the Blob
            service fails.
        if_match:
            Optional. An ETag value. Specify an ETag value for this
            conditional header to write the page only if the blob's ETag
            value matches the value specified. If the values do not match,
            the Blob service fails.
        if_none_match:
            Optional. An ETag value. Specify an ETag value for this
            conditional header to write the page only if the blob's ETag
            value does not match the value specified. If the values are
            identical, the Blob service fails.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('page', page)
        _validate_not_none('x_ms_range', x_ms_range)
        _validate_not_none('x_ms_page_write', x_ms_page_write)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=page'
        request.headers = [
            ('x-ms-range', _str_or_none(x_ms_range)),
            ('Content-MD5', _str_or_none(content_md5)),
            ('x-ms-page-write', _str_or_none(x_ms_page_write)),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
            # REST header name is '...-le' even though the parameter is
            # spelled 'lte'.
            ('x-ms-if-sequence-number-le',
             _str_or_none(x_ms_if_sequence_number_lte)),
            ('x-ms-if-sequence-number-lt',
             _str_or_none(x_ms_if_sequence_number_lt)),
            ('x-ms-if-sequence-number-eq',
             _str_or_none(x_ms_if_sequence_number_eq)),
            ('If-Modified-Since', _str_or_none(if_modified_since)),
            ('If-Unmodified-Since', _str_or_none(if_unmodified_since)),
            ('If-Match', _str_or_none(if_match)),
            ('If-None-Match', _str_or_none(if_none_match))
        ]
        request.query = [('timeout', _int_or_none(timeout))]
        request.body = _get_request_body_bytes_only('page', page)
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def get_page_ranges(self, container_name, blob_name, snapshot=None,
                        range=None, x_ms_range=None, x_ms_lease_id=None):
        '''
        Retrieves the page ranges for a blob.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        snapshot:
            Optional. The snapshot parameter is an opaque DateTime value
            that, when present, specifies the blob snapshot to retrieve
            information from.
        range:
            Optional. Specifies the range of bytes over which to list
            ranges, inclusively. If omitted, then all ranges for the blob
            are returned.
        x_ms_range:
            Optional. Specifies the range of bytes to be written as a page.
            Both the start and end of the range must be specified. Must be
            in format: bytes=startByte-endByte. Given that pages must be
            aligned with 512-byte boundaries, the start offset must be a
            modulus of 512 and the end offset must be a modulus of 512-1.
            Examples of valid byte ranges are 0-511, 512-1023, etc.
        x_ms_lease_id: Required if the blob has an active lease.
        '''
        # NOTE: the 'range' parameter shadows the builtin, but it is part of
        # the method's public signature and cannot be renamed.
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=pagelist'
        request.headers = [
            ('Range', _str_or_none(range)),
            ('x-ms-range', _str_or_none(x_ms_range)),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id))
        ]
        request.query = [('snapshot', _str_or_none(snapshot))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)
        return _parse_simple_list(response, PageList, PageRange, "page_ranges")



================================================
FILE: DSC/azure/storage/cloudstorageaccount.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- from azure.storage.blobservice import BlobService from azure.storage.tableservice import TableService from azure.storage.queueservice import QueueService class CloudStorageAccount(object): """ Provides a factory for creating the blob, queue, and table services with a common account name and account key. Users can either use the factory or can construct the appropriate service directly. """ def __init__(self, account_name=None, account_key=None): self.account_name = account_name self.account_key = account_key def create_blob_service(self): return BlobService(self.account_name, self.account_key) def create_table_service(self): return TableService(self.account_name, self.account_key) def create_queue_service(self): return QueueService(self.account_name, self.account_key) ================================================ FILE: DSC/azure/storage/queueservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- from azure import ( WindowsAzureConflictError, WindowsAzureError, DEV_QUEUE_HOST, QUEUE_SERVICE_HOST_BASE, xml_escape, _convert_class_to_xml, _dont_fail_not_exist, _dont_fail_on_exist, _get_request_body, _int_or_none, _parse_enum_results_list, _parse_response, _parse_response_for_dict_filter, _parse_response_for_dict_prefix, _str, _str_or_none, _update_request_uri_query_local_storage, _validate_not_none, _ERROR_CONFLICT, ) from azure.http import ( HTTPRequest, HTTP_RESPONSE_NO_CONTENT, ) from azure.storage import ( Queue, QueueEnumResults, QueueMessagesList, StorageServiceProperties, _update_storage_queue_header, ) from azure.storage.storageclient import _StorageClient class QueueService(_StorageClient): ''' This is the main class managing queue resources. ''' def __init__(self, account_name=None, account_key=None, protocol='https', host_base=QUEUE_SERVICE_HOST_BASE, dev_host=DEV_QUEUE_HOST): ''' account_name: your storage account name, required for all operations. account_key: your storage account key, required for all operations. protocol: Optional. Protocol. Defaults to http. host_base: Optional. Live host base url. Defaults to Azure url. Override this for on-premise. dev_host: Optional. Dev host url. Defaults to localhost. ''' super(QueueService, self).__init__( account_name, account_key, protocol, host_base, dev_host) def get_queue_service_properties(self, timeout=None): ''' Gets the properties of a storage account's Queue Service, including Windows Azure Storage Analytics. timeout: Optional. The timeout parameter is expressed in seconds. 
''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/?restype=service&comp=properties' request.query = [('timeout', _int_or_none(timeout))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response(response, StorageServiceProperties) def list_queues(self, prefix=None, marker=None, maxresults=None, include=None): ''' Lists all of the queues in a given storage account. prefix: Filters the results to return only queues with names that begin with the specified prefix. marker: A string value that identifies the portion of the list to be returned with the next list operation. The operation returns a NextMarker element within the response body if the list returned was not complete. This value may then be used as a query parameter in a subsequent call to request the next portion of the list of queues. The marker value is opaque to the client. maxresults: Specifies the maximum number of queues to return. If maxresults is not specified, the server will return up to 5,000 items. include: Optional. Include this parameter to specify that the container's metadata be returned as part of the response body. 
''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/?comp=list' request.query = [ ('prefix', _str_or_none(prefix)), ('marker', _str_or_none(marker)), ('maxresults', _int_or_none(maxresults)), ('include', _str_or_none(include)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_enum_results_list( response, QueueEnumResults, "Queues", Queue) def create_queue(self, queue_name, x_ms_meta_name_values=None, fail_on_exist=False): ''' Creates a queue under the given account. queue_name: name of the queue. x_ms_meta_name_values: Optional. A dict containing name-value pairs to associate with the queue as metadata. fail_on_exist: Specify whether throw exception when queue exists. ''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(queue_name) + '' request.headers = [('x-ms-meta-name-values', x_ms_meta_name_values)] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) if not fail_on_exist: try: response = self._perform_request(request) if response.status == HTTP_RESPONSE_NO_CONTENT: return False return True except WindowsAzureError as ex: _dont_fail_on_exist(ex) return False else: response = self._perform_request(request) if response.status == HTTP_RESPONSE_NO_CONTENT: raise WindowsAzureConflictError( _ERROR_CONFLICT.format(response.message)) return True def delete_queue(self, queue_name, fail_not_exist=False): ''' Permanently deletes the specified queue. queue_name: Name of the queue. fail_not_exist: Specify whether throw exception when queue doesn't exist. 
''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(queue_name) + '' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) if not fail_not_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_not_exist(ex) return False else: self._perform_request(request) return True def get_queue_metadata(self, queue_name): ''' Retrieves user-defined metadata and queue properties on the specified queue. Metadata is associated with the queue as name-values pairs. queue_name: Name of the queue. ''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(queue_name) + '?comp=metadata' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict_prefix( response, prefixes=['x-ms-meta', 'x-ms-approximate-messages-count']) def set_queue_metadata(self, queue_name, x_ms_meta_name_values=None): ''' Sets user-defined metadata on the specified queue. Metadata is associated with the queue as name-value pairs. queue_name: Name of the queue. x_ms_meta_name_values: Optional. A dict containing name-value pairs to associate with the queue as metadata. 
''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(queue_name) + '?comp=metadata' request.headers = [('x-ms-meta-name-values', x_ms_meta_name_values)] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) self._perform_request(request) def put_message(self, queue_name, message_text, visibilitytimeout=None, messagettl=None): ''' Adds a new message to the back of the message queue. A visibility timeout can also be specified to make the message invisible until the visibility timeout expires. A message must be in a format that can be included in an XML request with UTF-8 encoding. The encoded message can be up to 64KB in size for versions 2011-08-18 and newer, or 8KB in size for previous versions. queue_name: Name of the queue. message_text: Message content. visibilitytimeout: Optional. If not specified, the default value is 0. Specifies the new visibility timeout value, in seconds, relative to server time. The new value must be larger than or equal to 0, and cannot be larger than 7 days. The visibility timeout of a message cannot be set to a value later than the expiry time. visibilitytimeout should be set to a value smaller than the time-to-live value. messagettl: Optional. Specifies the time-to-live interval for the message, in seconds. The maximum time-to-live allowed is 7 days. If this parameter is omitted, the default time-to-live is 7 days. 
''' _validate_not_none('queue_name', queue_name) _validate_not_none('message_text', message_text) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/' + _str(queue_name) + '/messages' request.query = [ ('visibilitytimeout', _str_or_none(visibilitytimeout)), ('messagettl', _str_or_none(messagettl)) ] request.body = _get_request_body( '<?xml version="1.0" encoding="utf-8"?> \ <QueueMessage> \ <MessageText>' + xml_escape(_str(message_text)) + '</MessageText> \ </QueueMessage>') request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) self._perform_request(request) def get_messages(self, queue_name, numofmessages=None, visibilitytimeout=None): ''' Retrieves one or more messages from the front of the queue. queue_name: Name of the queue. numofmessages: Optional. A nonzero integer value that specifies the number of messages to retrieve from the queue, up to a maximum of 32. If fewer are visible, the visible messages are returned. By default, a single message is retrieved from the queue with this operation. visibilitytimeout: Specifies the new visibility timeout value, in seconds, relative to server time. The new value must be larger than or equal to 1 second, and cannot be larger than 7 days, or larger than 2 hours on REST protocol versions prior to version 2011-08-18. The visibility timeout of a message can be set to a value later than the expiry time. 
''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(queue_name) + '/messages' request.query = [ ('numofmessages', _str_or_none(numofmessages)), ('visibilitytimeout', _str_or_none(visibilitytimeout)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response(response, QueueMessagesList) def peek_messages(self, queue_name, numofmessages=None): ''' Retrieves one or more messages from the front of the queue, but does not alter the visibility of the message. queue_name: Name of the queue. numofmessages: Optional. A nonzero integer value that specifies the number of messages to peek from the queue, up to a maximum of 32. By default, a single message is peeked from the queue with this operation. ''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(queue_name) + '/messages?peekonly=true' request.query = [('numofmessages', _str_or_none(numofmessages))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response(response, QueueMessagesList) def delete_message(self, queue_name, message_id, popreceipt): ''' Deletes the specified message. queue_name: Name of the queue. message_id: Message to delete. popreceipt: Required. A valid pop receipt value returned from an earlier call to the Get Messages or Update Message operation. 
''' _validate_not_none('queue_name', queue_name) _validate_not_none('message_id', message_id) _validate_not_none('popreceipt', popreceipt) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + \ _str(queue_name) + '/messages/' + _str(message_id) + '' request.query = [('popreceipt', _str_or_none(popreceipt))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) self._perform_request(request) def clear_messages(self, queue_name): ''' Deletes all messages from the specified queue. queue_name: Name of the queue. ''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(queue_name) + '/messages' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) self._perform_request(request) def update_message(self, queue_name, message_id, message_text, popreceipt, visibilitytimeout): ''' Updates the visibility timeout of a message. You can also use this operation to update the contents of a message. queue_name: Name of the queue. message_id: Message to update. message_text: Content of message. popreceipt: Required. A valid pop receipt value returned from an earlier call to the Get Messages or Update Message operation. visibilitytimeout: Required. Specifies the new visibility timeout value, in seconds, relative to server time. The new value must be larger than or equal to 0, and cannot be larger than 7 days. The visibility timeout of a message cannot be set to a value later than the expiry time. A message can be updated until it has been deleted or has expired. 
''' _validate_not_none('queue_name', queue_name) _validate_not_none('message_id', message_id) _validate_not_none('message_text', message_text) _validate_not_none('popreceipt', popreceipt) _validate_not_none('visibilitytimeout', visibilitytimeout) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(queue_name) + '/messages/' + _str(message_id) + '' request.query = [ ('popreceipt', _str_or_none(popreceipt)), ('visibilitytimeout', _str_or_none(visibilitytimeout)) ] request.body = _get_request_body( '<?xml version="1.0" encoding="utf-8"?> \ <QueueMessage> \ <MessageText>' + xml_escape(_str(message_text)) + '</MessageText> \ </QueueMessage>') request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict_filter( response, filter=['x-ms-popreceipt', 'x-ms-time-next-visible']) def set_queue_service_properties(self, storage_service_properties, timeout=None): ''' Sets the properties of a storage account's Queue service, including Windows Azure Storage Analytics. storage_service_properties: StorageServiceProperties object. timeout: Optional. The timeout parameter is expressed in seconds. 
''' _validate_not_none('storage_service_properties', storage_service_properties) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/?restype=service&comp=properties' request.query = [('timeout', _int_or_none(timeout))] request.body = _get_request_body( _convert_class_to_xml(storage_service_properties)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) self._perform_request(request) ================================================ FILE: DSC/azure/storage/sharedaccesssignature.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#-------------------------------------------------------------------------- from azure import _sign_string, url_quote from azure.storage import X_MS_VERSION #------------------------------------------------------------------------- # Constants for the share access signature SIGNED_START = 'st' SIGNED_EXPIRY = 'se' SIGNED_RESOURCE = 'sr' SIGNED_PERMISSION = 'sp' SIGNED_IDENTIFIER = 'si' SIGNED_SIGNATURE = 'sig' SIGNED_VERSION = 'sv' RESOURCE_BLOB = 'b' RESOURCE_CONTAINER = 'c' SIGNED_RESOURCE_TYPE = 'resource' SHARED_ACCESS_PERMISSION = 'permission' #-------------------------------------------------------------------------- class WebResource(object): ''' Class that stands for the resource to get the share access signature path: the resource path. properties: dict of name and values. Contains 2 item: resource type and permission request_url: the url of the webresource include all the queries. ''' def __init__(self, path=None, request_url=None, properties=None): self.path = path self.properties = properties or {} self.request_url = request_url class Permission(object): ''' Permission class. Contains the path and query_string for the path. path: the resource path query_string: dict of name, values. Contains SIGNED_START, SIGNED_EXPIRY SIGNED_RESOURCE, SIGNED_PERMISSION, SIGNED_IDENTIFIER, SIGNED_SIGNATURE name values. ''' def __init__(self, path=None, query_string=None): self.path = path self.query_string = query_string class SharedAccessPolicy(object): ''' SharedAccessPolicy class. ''' def __init__(self, access_policy, signed_identifier=None): self.id = signed_identifier self.access_policy = access_policy class SharedAccessSignature(object): ''' The main class used to do the signing and generating the signature. account_name: the storage account name used to generate shared access signature account_key: the access key to genenerate share access signature permission_set: the permission cache used to signed the request url. 
''' def __init__(self, account_name, account_key, permission_set=None): self.account_name = account_name self.account_key = account_key self.permission_set = permission_set def generate_signed_query_string(self, path, resource_type, shared_access_policy, version=X_MS_VERSION): ''' Generates the query string for path, resource type and shared access policy. path: the resource resource_type: could be blob or container shared_access_policy: shared access policy version: x-ms-version for storage service, or None to get a signed query string compatible with pre 2012-02-12 clients, where the version is not included in the query string. ''' query_string = {} if shared_access_policy.access_policy.start: query_string[ SIGNED_START] = shared_access_policy.access_policy.start if version: query_string[SIGNED_VERSION] = version query_string[SIGNED_EXPIRY] = shared_access_policy.access_policy.expiry query_string[SIGNED_RESOURCE] = resource_type query_string[ SIGNED_PERMISSION] = shared_access_policy.access_policy.permission if shared_access_policy.id: query_string[SIGNED_IDENTIFIER] = shared_access_policy.id query_string[SIGNED_SIGNATURE] = self._generate_signature( path, shared_access_policy, version) return query_string def sign_request(self, web_resource): ''' sign request to generate request_url with sharedaccesssignature info for web_resource.''' if self.permission_set: for shared_access_signature in self.permission_set: if self._permission_matches_request( shared_access_signature, web_resource, web_resource.properties[ SIGNED_RESOURCE_TYPE], web_resource.properties[SHARED_ACCESS_PERMISSION]): if web_resource.request_url.find('?') == -1: web_resource.request_url += '?' else: web_resource.request_url += '&' web_resource.request_url += self._convert_query_string( shared_access_signature.query_string) break return web_resource def _convert_query_string(self, query_string): ''' Converts query string to str. 
The order of name, values is very important and can't be wrong.''' convert_str = '' if SIGNED_START in query_string: convert_str += SIGNED_START + '=' + \ url_quote(query_string[SIGNED_START]) + '&' convert_str += SIGNED_EXPIRY + '=' + \ url_quote(query_string[SIGNED_EXPIRY]) + '&' convert_str += SIGNED_PERMISSION + '=' + \ query_string[SIGNED_PERMISSION] + '&' convert_str += SIGNED_RESOURCE + '=' + \ query_string[SIGNED_RESOURCE] + '&' if SIGNED_IDENTIFIER in query_string: convert_str += SIGNED_IDENTIFIER + '=' + \ query_string[SIGNED_IDENTIFIER] + '&' if SIGNED_VERSION in query_string: convert_str += SIGNED_VERSION + '=' + \ query_string[SIGNED_VERSION] + '&' convert_str += SIGNED_SIGNATURE + '=' + \ url_quote(query_string[SIGNED_SIGNATURE]) + '&' return convert_str def _generate_signature(self, path, shared_access_policy, version): ''' Generates signature for a given path and shared access policy. ''' def get_value_to_append(value, no_new_line=False): return_value = '' if value: return_value = value if not no_new_line: return_value += '\n' return return_value if path[0] != '/': path = '/' + path canonicalized_resource = '/' + self.account_name + path # Form the string to sign from shared_access_policy and canonicalized # resource. The order of values is important. 
string_to_sign = \ (get_value_to_append(shared_access_policy.access_policy.permission) + get_value_to_append(shared_access_policy.access_policy.start) + get_value_to_append(shared_access_policy.access_policy.expiry) + get_value_to_append(canonicalized_resource)) if version: string_to_sign += get_value_to_append(shared_access_policy.id) string_to_sign += get_value_to_append(version, True) else: string_to_sign += get_value_to_append(shared_access_policy.id, True) return self._sign(string_to_sign) def _permission_matches_request(self, shared_access_signature, web_resource, resource_type, required_permission): ''' Check whether requested permission matches given shared_access_signature, web_resource and resource type. ''' required_resource_type = resource_type if required_resource_type == RESOURCE_BLOB: required_resource_type += RESOURCE_CONTAINER for name, value in shared_access_signature.query_string.items(): if name == SIGNED_RESOURCE and \ required_resource_type.find(value) == -1: return False elif name == SIGNED_PERMISSION and \ required_permission.find(value) == -1: return False return web_resource.path.find(shared_access_signature.path) != -1 def _sign(self, string_to_sign): ''' use HMAC-SHA256 to sign the string and convert it as base64 encoded string. ''' return _sign_string(self.account_key, string_to_sign) ================================================ FILE: DSC/azure/storage/storageclient.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- import os import sys from azure import ( WindowsAzureError, DEV_ACCOUNT_NAME, DEV_ACCOUNT_KEY, _ERROR_STORAGE_MISSING_INFO, ) from azure.http import HTTPError from azure.http.httpclient import _HTTPClient from azure.storage import _storage_error_handler #-------------------------------------------------------------------------- # constants for azure app setting environment variables AZURE_STORAGE_ACCOUNT = 'AZURE_STORAGE_ACCOUNT' AZURE_STORAGE_ACCESS_KEY = 'AZURE_STORAGE_ACCESS_KEY' EMULATED = 'EMULATED' #-------------------------------------------------------------------------- class _StorageClient(object): ''' This is the base class for BlobManager, TableManager and QueueManager. ''' def __init__(self, account_name=None, account_key=None, protocol='https', host_base='', dev_host=''): ''' account_name: your storage account name, required for all operations. account_key: your storage account key, required for all operations. protocol: Optional. Protocol. Defaults to http. host_base: Optional. Live host base url. Defaults to Azure url. Override this for on-premise. dev_host: Optional. Dev host url. Defaults to localhost. ''' self.account_name = account_name self.account_key = account_key self.requestid = None self.protocol = protocol self.host_base = host_base self.dev_host = dev_host # the app is not run in azure emulator or use default development # storage account and key if app is run in emulator. self.use_local_storage = False # check whether it is run in emulator. 
if EMULATED in os.environ: self.is_emulated = os.environ[EMULATED].lower() != 'false' else: self.is_emulated = False # get account_name and account key. If they are not set when # constructing, get the account and key from environment variables if # the app is not run in azure emulator or use default development # storage account and key if app is run in emulator. if not self.account_name or not self.account_key: if self.is_emulated: self.account_name = DEV_ACCOUNT_NAME self.account_key = DEV_ACCOUNT_KEY self.protocol = 'http' self.use_local_storage = True else: self.account_name = os.environ.get(AZURE_STORAGE_ACCOUNT) self.account_key = os.environ.get(AZURE_STORAGE_ACCESS_KEY) if not self.account_name or not self.account_key: raise WindowsAzureError(_ERROR_STORAGE_MISSING_INFO) self._httpclient = _HTTPClient( service_instance=self, account_key=self.account_key, account_name=self.account_name, protocol=self.protocol) self._batchclient = None self._filter = self._perform_request_worker def with_filter(self, filter): ''' Returns a new service which will process requests with the specified filter. Filtering operations can include logging, automatic retrying, etc... The filter is a lambda which receives the HTTPRequest and another lambda. The filter can perform any pre-processing on the request, pass it off to the next lambda, and then perform any post-processing on the response. ''' res = type(self)(self.account_name, self.account_key, self.protocol) old_filter = self._filter def new_filter(request): return filter(request, old_filter) res._filter = new_filter return res def set_proxy(self, host, port, user=None, password=None): ''' Sets the proxy server host and port for the HTTP CONNECT Tunnelling. host: Address of the proxy. Ex: '192.168.0.100' port: Port of the proxy. Ex: 6000 user: User for proxy authorization. password: Password for proxy authorization. 
''' self._httpclient.set_proxy(host, port, user, password) def _get_host(self): if self.use_local_storage: return self.dev_host else: return self.account_name + self.host_base def _perform_request_worker(self, request): return self._httpclient.perform_request(request) def _perform_request(self, request, text_encoding='utf-8'): ''' Sends the request and return response. Catches HTTPError and hand it to error handler ''' try: if self._batchclient is not None: return self._batchclient.insert_request_to_batch(request) else: resp = self._filter(request) if sys.version_info >= (3,) and isinstance(resp, bytes) and \ text_encoding: resp = resp.decode(text_encoding) except HTTPError as ex: _storage_error_handler(ex) return resp ================================================ FILE: DSC/azure/storage/tableservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#-------------------------------------------------------------------------- from azure import ( WindowsAzureError, TABLE_SERVICE_HOST_BASE, DEV_TABLE_HOST, _convert_class_to_xml, _convert_response_to_feeds, _dont_fail_not_exist, _dont_fail_on_exist, _get_request_body, _int_or_none, _parse_response, _parse_response_for_dict, _parse_response_for_dict_filter, _str, _str_or_none, _update_request_uri_query_local_storage, _validate_not_none, ) from azure.http import HTTPRequest from azure.http.batchclient import _BatchClient from azure.storage import ( StorageServiceProperties, _convert_entity_to_xml, _convert_response_to_entity, _convert_table_to_xml, _convert_xml_to_entity, _convert_xml_to_table, _sign_storage_table_request, _update_storage_table_header, ) from azure.storage.storageclient import _StorageClient class TableService(_StorageClient): ''' This is the main class managing Table resources. ''' def __init__(self, account_name=None, account_key=None, protocol='https', host_base=TABLE_SERVICE_HOST_BASE, dev_host=DEV_TABLE_HOST): ''' account_name: your storage account name, required for all operations. account_key: your storage account key, required for all operations. protocol: Optional. Protocol. Defaults to http. host_base: Optional. Live host base url. Defaults to Azure url. Override this for on-premise. dev_host: Optional. Dev host url. Defaults to localhost. 
''' super(TableService, self).__init__( account_name, account_key, protocol, host_base, dev_host) def begin_batch(self): if self._batchclient is None: self._batchclient = _BatchClient( service_instance=self, account_key=self.account_key, account_name=self.account_name) return self._batchclient.begin_batch() def commit_batch(self): try: ret = self._batchclient.commit_batch() finally: self._batchclient = None return ret def cancel_batch(self): self._batchclient = None def get_table_service_properties(self): ''' Gets the properties of a storage account's Table service, including Windows Azure Storage Analytics. ''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/?restype=service&comp=properties' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response(response, StorageServiceProperties) def set_table_service_properties(self, storage_service_properties): ''' Sets the properties of a storage account's Table Service, including Windows Azure Storage Analytics. storage_service_properties: StorageServiceProperties object. ''' _validate_not_none('storage_service_properties', storage_service_properties) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/?restype=service&comp=properties' request.body = _get_request_body( _convert_class_to_xml(storage_service_properties)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict(response) def query_tables(self, table_name=None, top=None, next_table_name=None): ''' Returns a list of tables under the specified account. table_name: Optional. The specific table to query. top: Optional. 
Maximum number of tables to return. next_table_name: Optional. When top is used, the next table name is stored in result.x_ms_continuation['NextTableName'] ''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() if table_name is not None: uri_part_table_name = "('" + table_name + "')" else: uri_part_table_name = "" request.path = '/Tables' + uri_part_table_name + '' request.query = [ ('$top', _int_or_none(top)), ('NextTableName', _str_or_none(next_table_name)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _convert_response_to_feeds(response, _convert_xml_to_table) def create_table(self, table, fail_on_exist=False): ''' Creates a new table in the storage account. table: Name of the table to create. Table name may contain only alphanumeric characters and cannot begin with a numeric character. It is case-insensitive and must be from 3 to 63 characters long. fail_on_exist: Specify whether throw exception when table exists. ''' _validate_not_none('table', table) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/Tables' request.body = _get_request_body(_convert_table_to_xml(table)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) if not fail_on_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_on_exist(ex) return False else: self._perform_request(request) return True def delete_table(self, table_name, fail_not_exist=False): ''' table_name: Name of the table to delete. fail_not_exist: Specify whether throw exception when table doesn't exist. 
''' _validate_not_none('table_name', table_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/Tables(\'' + _str(table_name) + '\')' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) if not fail_not_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_not_exist(ex) return False else: self._perform_request(request) return True def get_entity(self, table_name, partition_key, row_key, select=''): ''' Get an entity in a table; includes the $select options. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. select: Property names to select. ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('select', select) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(table_name) + \ '(PartitionKey=\'' + _str(partition_key) + \ '\',RowKey=\'' + \ _str(row_key) + '\')?$select=' + \ _str(select) + '' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _convert_response_to_entity(response) def query_entities(self, table_name, filter=None, select=None, top=None, next_partition_key=None, next_row_key=None): ''' Get entities in a table; includes the $filter and $select options. table_name: Table to query. filter: Optional. Filter as described at http://msdn.microsoft.com/en-us/library/windowsazure/dd894031.aspx select: Optional. Property names to select from the entities. top: Optional. Maximum number of entities to return. next_partition_key: Optional. 
When top is used, the next partition key is stored in result.x_ms_continuation['NextPartitionKey'] next_row_key: Optional. When top is used, the next partition key is stored in result.x_ms_continuation['NextRowKey'] ''' _validate_not_none('table_name', table_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(table_name) + '()' request.query = [ ('$filter', _str_or_none(filter)), ('$select', _str_or_none(select)), ('$top', _int_or_none(top)), ('NextPartitionKey', _str_or_none(next_partition_key)), ('NextRowKey', _str_or_none(next_row_key)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _convert_response_to_feeds(response, _convert_xml_to_entity) def insert_entity(self, table_name, entity, content_type='application/atom+xml'): ''' Inserts a new entity into a table. table_name: Table name. entity: Required. The entity object to insert. Could be a dict format or entity object. content_type: Required. Must be set to application/atom+xml ''' _validate_not_none('table_name', table_name) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/' + _str(table_name) + '' request.headers = [('Content-Type', _str_or_none(content_type))] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _convert_response_to_entity(response) def update_entity(self, table_name, partition_key, row_key, entity, content_type='application/atom+xml', if_match='*'): ''' Updates an existing entity in a table. 
The Update Entity operation replaces the entire entity and can be used to remove properties. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. entity: Required. The entity object to insert. Could be a dict format or entity object. content_type: Required. Must be set to application/atom+xml if_match: Optional. Specifies the condition for which the merge should be performed. To force an unconditional merge, set to the wildcard character (*). ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [ ('Content-Type', _str_or_none(content_type)), ('If-Match', _str_or_none(if_match)) ] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict_filter(response, filter=['etag']) def merge_entity(self, table_name, partition_key, row_key, entity, content_type='application/atom+xml', if_match='*'): ''' Updates an existing entity by updating the entity's properties. This operation does not replace the existing entity as the Update Entity operation does. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. entity: Required. The entity object to insert. Can be a dict format or entity object. content_type: Required. Must be set to application/atom+xml if_match: Optional. Specifies the condition for which the merge should be performed. 
To force an unconditional merge, set to the wildcard character (*). ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'MERGE' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [ ('Content-Type', _str_or_none(content_type)), ('If-Match', _str_or_none(if_match)) ] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict_filter(response, filter=['etag']) def delete_entity(self, table_name, partition_key, row_key, content_type='application/atom+xml', if_match='*'): ''' Deletes an existing entity in a table. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. content_type: Required. Must be set to application/atom+xml if_match: Optional. Specifies the condition for which the delete should be performed. To force an unconditional delete, set to the wildcard character (*). 
''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('content_type', content_type) _validate_not_none('if_match', if_match) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [ ('Content-Type', _str_or_none(content_type)), ('If-Match', _str_or_none(if_match)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) self._perform_request(request) def insert_or_replace_entity(self, table_name, partition_key, row_key, entity, content_type='application/atom+xml'): ''' Replaces an existing entity or inserts a new entity if it does not exist in the table. Because this operation can insert or update an entity, it is also known as an "upsert" operation. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. entity: Required. The entity object to insert. Could be a dict format or entity object. content_type: Required. 
Must be set to application/atom+xml ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [('Content-Type', _str_or_none(content_type))] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict_filter(response, filter=['etag']) def insert_or_merge_entity(self, table_name, partition_key, row_key, entity, content_type='application/atom+xml'): ''' Merges an existing entity or inserts a new entity if it does not exist in the table. Because this operation can insert or update an entity, it is also known as an "upsert" operation. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. entity: Required. The entity object to insert. Could be a dict format or entity object. content_type: Required. 
Must be set to application/atom+xml ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'MERGE' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [('Content-Type', _str_or_none(content_type))] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict_filter(response, filter=['etag']) def _perform_request_worker(self, request): auth = _sign_storage_table_request(request, self.account_name, self.account_key) request.headers.append(('Authorization', auth)) return self._httpclient.perform_request(request) ================================================ FILE: DSC/curlhttpclient.py ================================================ #!/usr/bin/env python2 # # Copyright (C) Microsoft Corporation, All rights reserved. 
"""Curl CLI wrapper.""" import base64 import random import subprocess import time import traceback import os import sys import subprocessfactory from httpclient import * json = serializerfactory.get_serializer(sys.version_info) CURL_ALIAS = "curl" CURL_HTTP_CODE_SPECIAL_VAR = "%{http_code}" OPTION_LOCATION = "--location" OPTION_SILENT = "--silent" OPTION_CERT = "--cert" OPTION_KEY = "--key" OPTION_WRITE_OUT = "--write-out" OPTION_HEADER = "--header" OPTION_REQUEST = "--request" OPTION_INSECURE = "--insecure" OPTION_DATA = "--data" OPTION_PROXY = "--proxy" OPTION_CONNECT_TIMEOUT = "--connect-timeout" OPTION_MAX_TIME = "--max-time" OPTION_RETRY = "--retry" OPTION_RETRY_DELAY = "--retry-delay" OPTION_RETRY_MAX_TIME = "--retry-max-time" # maximum time in seconds that you allow the whole operation to take VALUE_MAX_TIME = "30" # this only limits the connection phase, it has no impact once it has connected VALUE_CONNECT_TIMEOUT = "15" # if a transient error is returned when curl tries to perform a transfer, it will retry this number of times # before giving up VALUE_RETRY = "3" # make curl sleep this amount of time before each retry when a transfer has failed with a transient VALUE_RETRY_DELAY = "3" # retries will be done as usual as long as the timer hasn't reached this given limit VALUE_RETRY_MAX_TIME = "60" # curl status delimiter STATUS_CODE_DELIMITER = "\n\nstatus_code:" # curl success exit code EXIT_SUCCESS = 0 class CurlHttpClient(HttpClient): """Curl CLI wrapper. Inherits from HttpClient. Targets : [2.4.0 - 2.7.9[ Implements the following method common to all classes inheriting HttpClient. get (url, headers) post (url, headers, data) Curl documentation : CLI : https://curl.haxx.se/docs/manpage.html Error code : https://curl.haxx.se/libcurl/c/libcurl-errors.html """ @staticmethod def parse_raw_output(output): """Parses stdout from Curl to extract response_body and status_code. Args: output : string, raw stdout from curl subprocess. 
The format of the raw output should be of the following format (example request to www.microsoft.com): <html><head><title>Microsoft Corporation

Your current User-Agent string appears to be from an automated process, if his is incorrect, please click this link:United States English Microsoft Homepage

status_code:200 Returns: A RequestResponse """ start_index = output.index(STATUS_CODE_DELIMITER) response_body = output[:start_index] status_code = output[start_index:].strip("\n").split(":")[1] return RequestResponse(status_code, response_body) def get_base_cmd(self): """Creates the base cmd array to invoke the Curl CLI. Adds the following arguments for all request: --location : Retry the request if the requested page has moved to a different location --silent : Silent or quiet mode Adds the following optional arguments --cert : Tells curl to use the specified client certificate file when getting a file with HTTPS --key : Private key file name Returns: An array containing all required arguments to invoke curl, example: ["curl", "--location", "--silent", "--cert", "my_cert_file.crt", "--key", "my_key_file.key"] """ # basic options cmd = [CURL_ALIAS, OPTION_LOCATION, OPTION_SILENT] # retry and timeout options cmd += [OPTION_CONNECT_TIMEOUT, VALUE_CONNECT_TIMEOUT, OPTION_MAX_TIME, VALUE_MAX_TIME, OPTION_RETRY, VALUE_RETRY, OPTION_RETRY_DELAY, VALUE_RETRY_DELAY, OPTION_RETRY_MAX_TIME, VALUE_RETRY_MAX_TIME] if self.cert_path is not None: cmd.extend([OPTION_CERT, self.cert_path, OPTION_KEY, self.key_path]) if self.proxy_configuration is not None: cmd.extend([OPTION_PROXY, self.proxy_configuration]) return cmd def build_request_cmd(self, url, headers, method=None, data_file_path=None): """Formats the final cmd array to invoke Curl. The final cmd is created from the based command and additional optional parameters. Args: url : string , the URL. headers : dictionary, contains the required headers. method : string , specifies the http method to use. data_file_path : string , data file path. Adds the following arguments to the base cmd when required: --write-out : Makes curl display information on stdout after a completed transfer (i.e status_code). --header : Extra headers to include in the request when sending the request. 
--request : Specifies a custom request method to use for the request. --insecure : Explicitly allows curl to perform "insecure" SSL connections and transfers. Returns: An array containing the base cmd concatenated with any required extra argument, example: ["curl", "--location", "--silent", "--cert", "my_cert_file.crt", "--key", "my_key_file.key", "--insecure", "https://www.microsoft.com"] """ cmd = self.get_base_cmd() cmd.append(OPTION_WRITE_OUT) cmd.append(STATUS_CODE_DELIMITER + CURL_HTTP_CODE_SPECIAL_VAR + "\n") if headers is not None: for key, value in headers.items(): cmd.append(OPTION_HEADER) cmd.append(key + ": " + value) if method is not None: cmd.append(OPTION_REQUEST) cmd.append(method) if data_file_path is not None: cmd.append(OPTION_DATA) cmd.append("@" + data_file_path) if self.insecure: cmd.append(OPTION_INSECURE) cmd.append('--verbose') cmd.append(url) return cmd def issue_request(self, url, headers, method, data): data_file_path = None headers = self.merge_headers(self.default_headers, headers) # if a body is included, write it to a temporary file (prevent body from leaking in ps/top) if method != self.GET and data is not None: serialized_data = self.json.dumps(data) # write data to disk data_file_name = base64.standard_b64encode(str(time.time()) + str(random.randint(0, sys.maxsize)) + str(random.randint(0, sys.maxsize)) + str(random.randint(0, sys.maxsize)) + str(random.randint(0, sys.maxsize))) data_file_path = os.path.join("/tmp", data_file_name) f = open(data_file_path, "wb") f.write(serialized_data) f.close() # insert Content-Type header headers.update({self.CONTENT_TYPE_HEADER_KEY: self.APP_JSON_HEADER_VALUE}) # ** nesting of try statement is required since try/except/finally isn't supported prior to 2.5 ** try: try: cmd = self.build_request_cmd(url, headers, method=method, data_file_path=data_file_path) env = os.environ.copy() p = subprocessfactory.create_subprocess(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE) out, err = 
p.communicate() if p.returncode != EXIT_SUCCESS: raise Exception("Http request failed due to curl error. [returncode=" + str(p.returncode) + "]" + "[stderr=" + str(err) + "]") return self.parse_raw_output(out) except Exception as e: raise Exception("Unknown exception while issuing request. [exception=" + str(e) + "]" + "[stacktrace=" + str(traceback.format_exc()) + "]") finally: if data_file_path is not None: os.remove(data_file_path) def get(self, url, headers=None, data=None): """Issues a GET request to the provided url using the provided headers. Args: url : string , the URl. headers : dictionary, contains the headers key value pair (defaults to None). data : dictionary, contains the non-serialized request body (defaults to None). Returns: A RequestResponse """ return self.issue_request(url, headers, self.GET, data) def post(self, url, headers=None, data=None): """Issues a POST request to the provided url using the provided headers. Args: url : string , the URl. headers : dictionary, contains the headers key value pair (defaults to None). data : dictionary, contains the non-serialized request body (defaults to None). Returns: A RequestResponse """ return self.issue_request(url, headers, self.POST, data) def put(self, url, headers=None, data=None): """Issues a PUT request to the provided url using the provided headers. Args: url : string , the URl. headers : dictionary, contains the headers key value pair (defaults to None). data : dictionary, contains the non-serialized request body (defaults to None). Returns: A RequestResponse """ return self.issue_request(url, headers, self.PUT, data) def delete(self, url, headers=None, data=None): """Issues a DELETE request to the provided url using the provided headers. Args: url : string , the URl. headers : dictionary, contains the headers key value pair (defaults to None). data : dictionary, contains the non-serialized request body (defaults to None). 
Returns: A RequestResponse """ return self.issue_request(url, headers, self.DELETE, data) ================================================ FILE: DSC/dsc.py ================================================ #!/usr/bin/env python # # DSC extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import os.path import re import subprocess import sys import traceback try: from urllib.parse import urlparse, urlencode from urllib.request import urlopen, Request from urllib.error import HTTPError except ImportError: from urlparse import urlparse from urllib import urlencode from urllib2 import urlopen, Request, HTTPError import time import platform import json import datetime import serializerfactory import httpclient import urllib2httpclient import urllib3httpclient import httpclientfactory from azure.storage import BlobService from Utils.WAAgentUtil import waagent import Utils.HandlerUtil as Util # Define global variables ExtensionName = 'Microsoft.OSTCExtensions.DSCForLinux' ExtensionShortName = 'DSCForLinux' DownloadDirectory = 'download' omi_package_prefix = 'packages/omi-1.7.3-0.ssl_' dsc_package_prefix = 'packages/dsc-1.2.4-0.ssl_' omi_major_version = 1 omi_minor_version = 7 omi_build = 3 omi_release = 0 dsc_major_version = 1 dsc_minor_version = 2 dsc_build = 4 dsc_release = 0 package_pattern = '(\d+).(\d+).(\d+).(\d+)' nodeid_path = '/etc/opt/omi/conf/dsc/agentid' date_time_format = "%Y-%m-%dT%H:%M:%SZ" extension_handler_version = 
"3.0.0.6" python_command = 'python3' if sys.version_info >= (3,0) else 'python' dsc_script_path = '/opt/microsoft/dsc/Scripts/python3' if sys.version_info >= (3,0) else '/opt/microsoft/dsc/Scripts' space_string = " " # Error codes UnsupportedDistro = 51 #excludes from SLA DPKGLockedErrorCode = 51 #excludes from SLA # DSC-specific Operation class Operation: Download = "Download" ApplyMof = "ApplyMof" ApplyMetaMof = "ApplyMetaMof" InstallModule = "InstallModule" RemoveModule = "RemoveModule" Register = "Register" Enable = "Enable" class DistroCategory: debian = 1 redhat = 2 suse = 3 class Mode: push = "push" pull = "pull" install = "install" remove = "remove" register = "register" def main(): waagent.LoggerInit('/var/log/waagent.log', '/dev/stdout') waagent.Log("%s started to handle." % (ExtensionShortName)) global hutil hutil = Util.HandlerUtility(waagent.Log, waagent.Error) hutil.try_parse_context() global public_settings public_settings = hutil.get_public_settings() if not public_settings: waagent.AddExtensionEvent(name=ExtensionShortName, op='MainInProgress', isSuccess=True, message="Public settings are NOT provided.") public_settings = {} global protected_settings protected_settings = hutil.get_protected_settings() if not protected_settings: waagent.AddExtensionEvent(name=ExtensionShortName, op='MainInProgress', isSuccess=True, message="protected settings are NOT provided.") protected_settings = {} global distro_category vm_supported, vm_dist, vm_ver = check_supported_OS() distro_category = get_distro_category(vm_dist.lower(), vm_ver.lower()) for a in sys.argv[1:]: if re.match("^([-/]*)(disable)", a): disable() elif re.match("^([-/]*)(uninstall)", a): uninstall() elif re.match("^([-/]*)(install)", a): install() elif re.match("^([-/]*)(enable)", a): enable() elif re.match("^([-/]*)(update)", a): update() def get_distro_category(distro_name,distro_version): if distro_name.startswith('ubuntu') or (distro_name.startswith('debian')): return DistroCategory.debian elif 
distro_name.startswith('centos') or distro_name.startswith('redhat') or distro_name.startswith('oracle') or distro_name.startswith('red hat'): return DistroCategory.redhat elif distro_name.startswith('suse') or distro_name.startswith('sles'): return DistroCategory.suse waagent.AddExtensionEvent(name=ExtensionShortName, op='InstallInProgress', isSuccess=True, message="Unsupported distro :" + distro_name + "; distro_version: " + distro_version) hutil.do_exit(UnsupportedDistro, 'Install', 'error', str(UnsupportedDistro), distro_name + 'is not supported.') def check_supported_OS(): """ Checks if the VM this extension is running on is supported by DSC Returns for platform.linux_distribution() vary widely in format, such as '7.3.1611' returned for a VM with CentOS 7, so the first provided digits must match. All other distros not supported will get error code 51 """ supported_dists = {'redhat' : ['7', '8'], # CentOS 'centos' : ['7', '8'], # CentOS 'red hat' : ['7', '8'], # Redhat 'debian' : ['8', '9', '10'], # Debian 'ubuntu' : ['14.04', '16.04', '18.04', '20.04'], # Ubuntu 'oracle' : ['7'], # Oracle 'suse' : ['12', '15'], #SLES 'sles' : ['12', '15'] } vm_dist, vm_ver, vm_supported = '', '', False try: vm_dist, vm_ver, vm_id = platform.linux_distribution() except AttributeError: try: vm_dist, vm_ver, vm_id = platform.dist() except: waagent.Log("Falling back to /etc/os-release distribution parsing") # Fallback if either of the above fail; on some (especially newer) # distros, linux_distribution() and dist() are unreliable or deprecated if not vm_dist and not vm_ver: try: with open('/etc/os-release', 'r') as fp: for line in fp: if line.startswith('ID='): vm_dist = line.split('=')[1] vm_dist = vm_dist.split('-')[0] vm_dist = vm_dist.replace('\"', '').replace('\n', '') elif line.startswith('VERSION_ID='): vm_ver = line.split('=')[1] vm_ver = vm_ver.replace('\"', '').replace('\n', '') except: waagent.Log('Indeterminate operating system') return vm_supported, 'Indeterminate 
operating system', '' # Find this VM distribution in the supported list for supported_dist in supported_dists.keys(): if vm_dist.lower().startswith(supported_dist): # Check if this VM distribution version is supported vm_ver_split = vm_ver.split('.') for supported_ver in supported_dists[supported_dist]: supported_ver_split = supported_ver.split('.') # If vm_ver is at least as precise (at least as many digits) as # supported_ver and matches all the supported_ver digits, then # this VM is supported vm_ver_match = True for idx, supported_ver_num in enumerate(supported_ver_split): try: supported_ver_num = int(supported_ver_num) vm_ver_num = int(vm_ver_split[idx]) except IndexError: vm_ver_match = False break if vm_ver_num is not supported_ver_num: vm_ver_match = False break if vm_ver_match: vm_supported = True break if not vm_supported: waagent.AddExtensionEvent(name=ExtensionShortName, op='InstallInProgress', isSuccess=True, message="Unsupported OS :" + vm_dist + "; distro_version: " + vm_ver) hutil.do_exit(UnsupportedDistro, 'Install', 'error', str(UnsupportedDistro), vm_dist + "; distro_version: " + vm_ver + ' is not supported.') return vm_supported, vm_dist, vm_ver def install(): hutil.do_parse_context('Install') try: waagent.AddExtensionEvent(name=ExtensionShortName, op='InstallInProgress', isSuccess=True, message="Installing DSCForLinux extension") remove_old_dsc_packages() install_dsc_packages() waagent.AddExtensionEvent(name=ExtensionShortName, op='InstallInProgress', isSuccess=True, message="successfully installed DSCForLinux extension") hutil.do_exit(0, 'Install', 'success', '0', 'Install Succeeded.') except Exception as e: waagent.AddExtensionEvent(name=ExtensionShortName, op='InstallInProgress', isSuccess=True, message="failed to install DSC extension with error: {0} and stacktrace: {1}".format( str(e), traceback.format_exc())) hutil.error( "Failed to install DSC extension with error: %s, stack trace: %s" % (str(e), traceback.format_exc())) hutil.do_exit(1, 
'Install', 'error', '1', 'Install Failed.') def enable(): hutil.do_parse_context('Enable') hutil.exit_if_enabled() try: start_omiservice() mode = get_config('Mode') if mode == '': mode = get_config('ExtensionAction') waagent.AddExtensionEvent(name=ExtensionShortName, op='EnableInProgress', isSuccess=True, message="Enabling the DSC extension - mode/ExtensionAction: " + mode) if mode == '': mode = Mode.push else: mode = mode.lower() if not hasattr(Mode, mode): waagent.AddExtensionEvent(name=ExtensionShortName, op=Operation.Enable, isSuccess=True, message="(03001)Argument error, invalid ExtensionAction/mode.") hutil.do_exit(51, 'Enable', 'error', '51', 'Enable failed, unknown ExtensionAction/mode: ' + mode) if mode == Mode.remove: remove_module() elif mode == Mode.register: registration_key = get_config('RegistrationKey') registation_url = get_config('RegistrationUrl') # Optional node_configuration_name = get_config('NodeConfigurationName') refresh_freq = get_config('RefreshFrequencyMins') configuration_mode_freq = get_config('ConfigurationModeFrequencyMins') configuration_mode = get_config('ConfigurationMode') exit_code, err_msg = register_automation(registration_key, registation_url, node_configuration_name, refresh_freq, configuration_mode_freq, configuration_mode.lower()) if exit_code != 0: hutil.do_exit(exit_code, 'Enable', 'error', str(exit_code), err_msg) extension_status_event = "ExtensionRegistration" response = send_heart_beat_msg_to_agent_service(extension_status_event) status_file_path, agent_id, vm_uuid = get_status_message_details() update_statusfile(status_file_path, agent_id, vm_uuid, response) sys.exit(0) else: file_path = download_file() if mode == Mode.pull: current_config = apply_dsc_meta_configuration(file_path) elif mode == Mode.push: current_config = apply_dsc_configuration(file_path) else: install_module(file_path) if mode == Mode.push or mode == Mode.pull: if check_dsc_configuration(current_config): if mode == Mode.push: 
waagent.AddExtensionEvent(name=ExtensionShortName, op=Operation.ApplyMof, isSuccess=True, message="(03104)Succeeded to apply MOF configuration through Push Mode") else: waagent.AddExtensionEvent(name=ExtensionShortName, op=Operation.ApplyMetaMof, isSuccess=True, message="(03106)Succeeded to apply meta MOF configuration through Pull Mode") extension_status_event = "ExtensionRegistration" response = send_heart_beat_msg_to_agent_service(extension_status_event) status_file_path, agent_id, vm_uuid = get_status_message_details() update_statusfile(status_file_path, agent_id, vm_uuid, response) sys.exit(0) else: if mode == Mode.push: waagent.AddExtensionEvent(name=ExtensionShortName, op=Operation.ApplyMof, isSuccess=False, message="(03105)Failed to apply MOF configuration through Push Mode") else: waagent.AddExtensionEvent(name=ExtensionShortName, op=Operation.ApplyMetaMof, isSuccess=False, message="(03107)Failed to apply meta MOF configuration through Pull Mode") hutil.do_exit(1, 'Enable', 'error', '1', 'Enable failed. 
def send_heart_beat_msg_to_agent_service(status_event_type):
    """POST a heartbeat/extended-properties message to the Automation DSC
    agent service when the LCM is in pull mode.

    status_event_type -- event label, e.g. "ExtensionRegistration".
    Returns the HTTP response object, or None when not in pull mode or on error.
    Retries (up to 6 loop iterations) with a 10s pause on HTTP 5xx responses.
    """
    response = None
    try:
        retry_count = 0
        canRetry = True
        while retry_count <= 5 and canRetry:
            waagent.AddExtensionEvent(name=ExtensionShortName,
                                      op='HeartBeatInProgress',
                                      isSuccess=True,
                                      message="In send_heart_beat_msg_to_agent_service method")
            code, output, stderr = run_cmd(
                python_command + space_string + dsc_script_path + "/GetDscLocalConfigurationManager.py")
            if code == 0 and "RefreshMode=Pull" in output:
                waagent.AddExtensionEvent(name=ExtensionShortName,
                                          op='HeartBeatInProgress',
                                          isSuccess=True,
                                          message="sends heartbeat message in pullmode")
                # The registration server URL is scraped from the LCM settings dump.
                m = re.search("ServerURL=([^\n]+)", output)
                if not m:
                    return
                registration_url = m.group(1)
                agent_id = get_nodeid(nodeid_path)
                node_extended_properties_url = registration_url + "/Nodes(AgentId='" + agent_id + "')/ExtendedProperties"
                waagent.AddExtensionEvent(name=ExtensionShortName,
                                          op='HeartBeatInProgress',
                                          isSuccess=True,
                                          message="Url is " + node_extended_properties_url)
                headers = {'Content-Type': "application/json; charset=utf-8",
                           'Accept': "application/json",
                           "ProtocolVersion": "2.0"}
                data = construct_node_extension_properties(output, status_event_type)
                # Client cert/key pair provisioned under /etc/opt/omi/ssl by registration.
                http_client_factory = httpclientfactory.HttpClientFactory("/etc/opt/omi/ssl/oaas.crt",
                                                                          "/etc/opt/omi/ssl/oaas.key")
                http_client = http_client_factory.create_http_client(sys.version_info)
                response = http_client.post(node_extended_properties_url, headers=headers, data=data)
                waagent.AddExtensionEvent(name=ExtensionShortName,
                                          op='HeartBeatInProgress',
                                          isSuccess=True,
                                          message="response code is " + str(response.status_code))
                if response.status_code >= 500 and response.status_code < 600:
                    # Transient server-side error: pause and retry.
                    canRetry = True
                    time.sleep(10)
                else:
                    canRetry = False
            retry_count += 1
    except Exception as e:
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op='HeartBeatInProgress',
                                  isSuccess=True,
                                  message="Failed to send heartbeat message to DSC agent service: {0}, stacktrace: {1} ".format(
                                      str(e), traceback.format_exc()))
        hutil.error('Failed to send heartbeat message to DSC agent service: %s, stack trace: %s' % (
            str(e), traceback.format_exc()))
    return response
def get_lcm_config_setting(setting_name, lcmconfig):
    """Extract the value of `setting_name=value` from an LCM settings dump;
    returns '' when the setting is absent."""
    valuegroup = re.search(setting_name + "=([^\n]+)", lcmconfig)
    if not valuegroup:
        return ""
    value = valuegroup.group(1)
    return value


def construct_node_extension_properties(lcmconfig, status_event_type):
    """Build the ExtendedProperties JSON payload sent with a heartbeat.

    lcmconfig -- raw GetDscLocalConfigurationManager.py output to scrape.
    status_event_type -- event label stored under "ExtensionStatusEvent".
    Returns a dict ready for JSON serialization.
    """
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op='HeartBeatInProgress',
                              isSuccess=True,
                              message="Getting properties")
    OMSCLOUD_ID = get_omscloudid()
    vm_dist, vm_ver, vm_id = '', '', ''
    try:
        # Removed in Python 3.8; AttributeError triggers the fallbacks below.
        vm_dist, vm_ver, vm_id = platform.linux_distribution()
    except AttributeError:
        try:
            vm_dist, vm_ver, vm_id = platform.dist()
        except AttributeError:
            waagent.Log("Falling back to /etc/os-release distribution parsing")
    # Fallback if either of the above fail; on some (especially newer)
    # distros, linux_distribution() and dist() are unreliable or deprecated
    if not vm_dist and not vm_ver:
        try:
            with open('/etc/os-release', 'r') as fp:
                for line in fp:
                    if line.startswith('ID='):
                        vm_dist = line.split('=')[1]
                        vm_dist = vm_dist.split('-')[0]
                        vm_dist = vm_dist.replace('\"', '').replace('\n', '')
                    elif line.startswith('VERSION_ID='):
                        vm_ver = line.split('=')[1]
                        vm_ver = vm_ver.replace('\"', '').replace('\n', '')
        except:
            waagent.Log('Indeterminate operating system')
            vm_dist, vm_ver, vm_id = "Indeterminate operating system", "", ""
    # Exactly one branch runs: ''.split('.') always yields >= 1 element.
    # NOTE(review): minor_version is int 0 in the first branch but a string
    # in the second -- consumers appear to tolerate both; confirm.
    if len(vm_ver.split('.')) == 1:
        major_version = vm_ver.split('.')[0]
        minor_version = 0
    if len(vm_ver.split('.')) >= 2:
        major_version = vm_ver.split('.')[0]
        minor_version = vm_ver.split('.')[1]
    VMUUID = get_vmuuid()
    node_config_names = get_lcm_config_setting('ConfigurationNames', lcmconfig)
    configuration_mode = get_lcm_config_setting("ConfigurationMode", lcmconfig)
    configuration_mode_frequency = get_lcm_config_setting("ConfigurationModeFrequencyMins", lcmconfig)
    refresh_frequency_mins = get_lcm_config_setting("RefreshFrequencyMins", lcmconfig)
    reboot_node = get_lcm_config_setting("RebootNodeIfNeeded", lcmconfig)
    action_after_reboot = get_lcm_config_setting("ActionAfterReboot", lcmconfig)
    allow_module_overwrite = get_lcm_config_setting("AllowModuleOverwrite", lcmconfig)
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op='HeartBeatInProgress',
                              isSuccess=True,
                              message="Constructing properties data")
    properties_data = {
        "OMSCloudId": OMSCLOUD_ID,
        "TimeStamp": time.strftime(date_time_format, time.gmtime()),
        "VMResourceId": "",
        "ExtensionStatusEvent": status_event_type,
        "ExtensionInformation": {
            "Name": "Microsoft.OSTCExtensions.DSCForLinux",
            "Version": extension_handler_version
        },
        "OSProfile": {
            "Name": vm_dist,
            "Type": "Linux",
            "MinorVersion": minor_version,
            "MajorVersion": major_version,
            "VMUUID": VMUUID
        },
        "RegistrationMetaData": {
            "NodeConfigurationName": node_config_names,
            "ConfigurationMode": configuration_mode,
            "ConfigurationModeFrequencyMins": configuration_mode_frequency,
            "RefreshFrequencyMins": refresh_frequency_mins,
            "RebootNodeIfNeeded": reboot_node,
            "ActionAfterReboot": action_after_reboot,
            "AllowModuleOverwrite": allow_module_overwrite
        }
    }
    return properties_data
def uninstall():
    """'uninstall' entry point: tell the agent service we are going away,
    then report success; any failure exits with code 1."""
    hutil.do_parse_context('Uninstall')
    try:
        send_heart_beat_msg_to_agent_service("ExtensionUninstall")
        hutil.do_exit(0, 'Uninstall', 'success', '0', 'Uninstall Succeeded')
    except Exception as e:
        failure_details = (str(e), traceback.format_exc())
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op='UninstallInProgress',
                                  isSuccess=False,
                                  message='Failed to uninstall the extension with error: %s, stack trace: %s' % failure_details)
        hutil.error(
            'Failed to uninstall the extension with error: %s, stack trace: %s' % failure_details)
        hutil.do_exit(1, 'Uninstall', 'error', '1', 'Uninstall failed: {0}'.format(e))
def disable():
    """'disable' entry point -- nothing to stop; report success."""
    hutil.do_parse_context('Disable')
    hutil.do_exit(0, 'Disable', 'success', '0', 'Disable Succeeded')


def update():
    """'update' entry point: notify the agent service of the upgrade."""
    hutil.do_parse_context('Update')
    try:
        extension_status_event = "ExtensionUpgrade"
        send_heart_beat_msg_to_agent_service(extension_status_event)
        hutil.do_exit(0, 'Update', 'success', '0', 'Update Succeeded')
    except Exception as e:
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op='UpdateInProgress',
                                  isSuccess=False,
                                  message='Failed to update the extension with error: %s, stack trace: %s' % (
                                      str(e), traceback.format_exc()))
        hutil.error('Failed to update the extension with error: %s, stack trace: %s' % (str(e), traceback.format_exc()))
        hutil.do_exit(1, 'Update', 'error', '1', 'Update failed: {0}'.format(e))


def run_cmd(cmd):
    """Run *cmd* through the shell; return (exit_code, stdout, stderr) as text.

    Fixed: the original called proc.wait() before communicate(), which can
    deadlock when a child fills the stdout/stderr pipe buffer.
    communicate() drains both pipes while waiting for exit.
    """
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
    stdout, stderr = proc.communicate()
    exit_code = proc.returncode
    # Python 3 returns bytes; decode with a total 8-bit codec so no output is lost.
    stdout = stdout.decode("ISO-8859-1") if isinstance(stdout, bytes) else stdout
    stderr = stderr.decode("ISO-8859-1") if isinstance(stderr, bytes) else stderr
    return exit_code, stdout, stderr


def run_dpkg_cmd_with_retry(cmd):
    """
    Attempts to run the cmd - if it fails, checks to see if dpkg is locked by
    another process, if so, it will sleep for 5 seconds and then try running
    the command again. If dpkg is still locked, then it will return the
    DPKGLockedErrorCode which won't count against our SLA numbers.
    """
    exit_code, output, stderr = run_cmd(cmd)
    if exit_code != 0:
        dpkg_locked = is_dpkg_locked(exit_code, stderr)
        if dpkg_locked:
            # Try one more time:
            time.sleep(5)
            exit_code, output, stderr = run_cmd(cmd)
            dpkg_locked = is_dpkg_locked(exit_code, stderr)
            if dpkg_locked:
                exit_code = DPKGLockedErrorCode
    return exit_code, output, stderr
""" exit_code, output, stderr = run_cmd(cmd) if not exit_code == 0: dpkg_locked = is_dpkg_locked(exit_code, stderr) if dpkg_locked: # Try one more time: time.sleep(5) exit_code, output, stderr = run_cmd(cmd) dpkg_locked = is_dpkg_locked(exit_code, stderr) if dpkg_locked: exit_code = DPKGLockedErrorCode return exit_code, output, stderr def get_config(key): if key in public_settings: value = public_settings.get(key) if value: return str(value).strip() if key in protected_settings: value = protected_settings.get(key) if value: return str(value).strip() return '' def remove_old_dsc_packages(): waagent.AddExtensionEvent(name=ExtensionShortName, op='InstallInProgress', isSuccess=True, message="Deleting DSC and omi packages") if distro_category == DistroCategory.debian: deb_remove_incomptible_dsc_package() # remove the package installed by Linux DSC 1.0, in later versions the package name is changed to 'omi' deb_remove_old_oms_package('omiserver', '1.0.8.2') elif distro_category == DistroCategory.redhat or distro_category == DistroCategory.suse: rpm_remove_incomptible_dsc_package() # remove the package installed by Linux DSC 1.0, in later versions the package name is changed to 'omi' rpm_remove_old_oms_package('omiserver', '1.0.8-2') def deb_remove_incomptible_dsc_package(): version = deb_get_pkg_version('dsc') if version is not None and is_incomptible_dsc_package(version): deb_uninstall_package('dsc') def is_incomptible_dsc_package(package_version): version = re.match(package_pattern, package_version) # uninstall DSC package if the version is 1.0.x because upgrading from 1.0 to 1.1 is broken if version is not None and (int(version.group(1)) == 1 and int(version.group(2)) == 0): return True return False def is_old_oms_server(package_name): if package_name == 'omiserver': return True return False def deb_remove_old_oms_package(package_name, version): system_pkg_version = deb_get_pkg_version(package_name) if system_pkg_version is not None and 
def deb_get_pkg_version(package_name):
    """Return the installed dpkg version string for package_name, or None
    (implicit) when the package is absent or dpkg fails."""
    # First query checks that the package and its Version: field exist ...
    code, output, stderr = run_dpkg_cmd_with_retry('dpkg -s ' + package_name + ' | grep Version:')
    if code == 0:
        # ... second one extracts just the version column.
        code, output, stderr = run_dpkg_cmd_with_retry("dpkg -s " + package_name + " | grep Version: | awk '{print $2}'")
        if code == 0:
            return output


def rpm_remove_incomptible_dsc_package():
    """Uninstall the rpm 'dsc' package when an incompatible 1.0.x build is present."""
    code, version, stderr = run_cmd('rpm -q --queryformat "%{VERSION}.%{RELEASE}" dsc')
    if code == 0 and is_incomptible_dsc_package(version):
        rpm_uninstall_package('dsc')


def rpm_remove_old_oms_package(package_name, version):
    """Uninstall a legacy OMS server rpm when present."""
    if rpm_check_old_oms_package(package_name, version):
        rpm_uninstall_package(package_name)


def rpm_check_old_oms_package(package_name, version):
    """True when the legacy 'omiserver' rpm is installed.

    NOTE(review): `version` is accepted but never compared -- confirm intent.
    """
    code, output, stderr = run_cmd('rpm -q ' + package_name)
    if code == 0 and is_old_oms_server(package_name):
        return True
    return False


def install_dsc_packages():
    """Install the bundled omi and dsc packages matching the local OpenSSL
    version and distro package format (deb vs rpm)."""
    openssl_version = get_openssl_version()
    omi_package_path = omi_package_prefix + openssl_version
    dsc_package_path = dsc_package_prefix + openssl_version
    # '.s' selects the compiler-mitigated OMI build where supported.
    compiler_mitigated_omi_flag = get_compiler_mitigated_omi_flag()
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op='InstallInProgress',
                              isSuccess=True,
                              message="Installing omipackage version: " + omi_package_path
                                      + "; dsc package version: " + dsc_package_path)
    if distro_category == DistroCategory.debian:
        deb_install_pkg(omi_package_path + '.ulinux' + compiler_mitigated_omi_flag + '.x64.deb', 'omi',
                        omi_major_version, omi_minor_version, omi_build, omi_release,
                        ' --force-confold --force-confdef --refuse-downgrade ')
        deb_install_pkg(dsc_package_path + '.x64.deb', 'dsc',
                        dsc_major_version, dsc_minor_version, dsc_build, dsc_release, '')
    elif distro_category == DistroCategory.redhat or distro_category == DistroCategory.suse:
        rpm_install_pkg(omi_package_path + '.ulinux' + compiler_mitigated_omi_flag + '.x64.rpm', 'omi',
                        omi_major_version, omi_minor_version, omi_build, omi_release)
        rpm_install_pkg(dsc_package_path + '.x64.rpm', 'dsc',
                        dsc_major_version, dsc_minor_version, dsc_build, dsc_release)
def get_compiler_mitigated_omi_flag():
    """Return '.s' when the compiler-mitigated OMI build applies, else ''."""
    vm_supported, vm_dist, vm_ver = check_supported_OS()
    if is_compiler_mitigated_omi_supported(vm_dist.lower(), vm_ver.lower()):
        return '.s'
    return ''


def is_compiler_mitigated_omi_supported(dist_name, dist_version):
    """Whether the compiler-mitigated OMI package can be used on this distro."""
    # Compiler-mitigated OMI is not supported in the following
    #   SLES 11
    # To be enhanced if there are future distros not supporting compiler-mitigated OMI package
    if dist_name.startswith('sles') and dist_version.startswith('11'):
        return False
    return True


def compare_pkg_version(system_package_version, major_version, minor_version, build, release):
    """Return 1 when the installed version is >= the requested one, else 0.

    The original four-way boolean chain is exactly lexicographic tuple
    comparison, so express it that way.
    """
    parsed = re.match(package_pattern, system_package_version)
    if parsed is None:
        return 0
    installed = tuple(int(parsed.group(i)) for i in range(1, 5))
    if installed >= (major_version, minor_version, build, release):
        return 1
    return 0


def rpm_check_pkg_exists(package_name, major_version, minor_version, build, release):
    """Return 1 when package_name is installed at >= the given version, 0 when
    older; returns None (implicit) when rpm says it is not installed."""
    code, output, stderr = run_cmd('rpm -q --queryformat "%{VERSION}.%{RELEASE}" ' + package_name)
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op='InstallInProgress',
                              isSuccess=True,
                              message="package name: " + package_name + "; existing package version:" + output)
    hutil.log("package name: " + package_name + "; existing package version:" + output)
    if code == 0:
        return compare_pkg_version(output, major_version, minor_version, build, release)


def rpm_install_pkg(package_path, package_name, major_version, minor_version, build, release):
    """Install an rpm unless a >= version is already present; raise on failure."""
    if rpm_check_pkg_exists(package_name, major_version, minor_version, build, release) == 1:
        # package is already installed
        return
    else:
        code, output, stderr = run_cmd('rpm -Uvh ' + package_path)
        if code == 0:
            hutil.log(package_name + ' is installed successfully')
        else:
            # Fixed: this is a failure event -- report isSuccess=False
            # (was True, inconsistent with deb_install_pkg).
            waagent.AddExtensionEvent(name=ExtensionShortName,
                                      op='InstallInProgress',
                                      isSuccess=False,
                                      message="Failed to install RPM package :" + package_path)
            raise Exception('Failed to install package {0}: stdout: {1}, stderr: {2}'.format(package_name, output, stderr))
def deb_install_pkg(package_path, package_name, major_version, minor_version, build, release, install_options):
    """Install a .deb unless a >= version is already present.

    install_options -- extra dpkg flags (e.g. --force-confold); may be ''.
    Exits the process with DPKGLockedErrorCode when dpkg stays locked;
    raises Exception on any other dpkg failure.
    """
    version = deb_get_pkg_version(package_name)
    if version is not None and compare_pkg_version(version, major_version, minor_version, build, release) == 1:
        # package is already installed
        hutil.log(package_name + ' version ' + version + ' is already installed')
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op='InstallInProgress',
                                  isSuccess=True,
                                  message="dsc package with version: " + version + "is already installed.")
        return
    else:
        cmd = 'dpkg -i ' + install_options + ' ' + package_path
        code, output, stderr = run_dpkg_cmd_with_retry(cmd)
        if code == 0:
            hutil.log(package_name + ' version ' + str(major_version) + '.' + str(minor_version) + '.'
                      + str(build) + '.' + str(release) + ' is installed successfully')
        elif code == DPKGLockedErrorCode:
            # Package manager held by another process even after the retry.
            hutil.do_exit(DPKGLockedErrorCode, 'Install', 'error', str(DPKGLockedErrorCode),
                          'Install failed because the package manager on the VM is currently locked. Please try installing again.')
        else:
            waagent.AddExtensionEvent(name=ExtensionShortName,
                                      op='InstallInProgress',
                                      isSuccess=False,
                                      message="Failed to install debian package :" + package_path)
            raise Exception('Failed to install package {0}: stdout: {1}, stderr: {2}'.format(package_name, output, stderr))
def install_package(package):
    """Install *package* with the package manager for the detected distro."""
    if distro_category == DistroCategory.debian:
        apt_package_install(package)
    elif distro_category == DistroCategory.redhat:
        yum_package_install(package)
    elif distro_category == DistroCategory.suse:
        zypper_package_install(package)


def _run_package_install(cmd, package, manager):
    """Run one package-manager install command; log success or raise.

    Shared by the zypper/yum/apt wrappers below, which previously duplicated
    this logic three times. Fixed: the failure telemetry event now reports
    isSuccess=False (it was True in all three copies).
    """
    hutil.log(cmd)
    code, output, stderr = run_cmd(cmd)
    if code == 0:
        hutil.log('Package ' + package + ' is installed successfully')
    else:
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op='InstallInProgress',
                                  isSuccess=False,
                                  message="Failed to install " + manager + " package :" + package)
        raise Exception('Failed to install package {0}: stdout: {1}, stderr: {2}'.format(package, output, stderr))


def zypper_package_install(package):
    """Install *package* via zypper (SUSE)."""
    _run_package_install('zypper --non-interactive in ' + package, package, 'zypper')


def yum_package_install(package):
    """Install *package* via yum (RHEL/CentOS/Oracle)."""
    _run_package_install('yum install -y ' + package, package, 'yum')


def apt_package_install(package):
    """Install *package* via apt-get (Debian/Ubuntu)."""
    _run_package_install('apt-get install -y --force-yes ' + package, package, 'apt')
def get_openssl_version():
    """Return '100' or '110' for OpenSSL 1.0.x / 1.1.x; exit with code 51 for
    any other version."""
    cmd_result = waagent.RunGetOutput("openssl version")
    cmd_result = cmd_result.decode() if isinstance(cmd_result, bytes) else cmd_result
    # RunGetOutput yields (exit_code, output); output looks like "OpenSSL 1.1.1f ...".
    openssl_version = cmd_result[1].split()[1]
    # Fixed: escape the dots -- the old patterns '^1.0.*' / '^1.1.*' also
    # matched unrelated strings such as '1x0'.
    if re.match(r'^1\.0', openssl_version):
        return '100'
    elif re.match(r'^1\.1', openssl_version):
        return '110'
    else:
        error_msg = 'This system does not have a supported version of OpenSSL installed. Supported version: 1.0.*, 1.1.*'
        hutil.error(error_msg)
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op='InstallInProgress',
                                  isSuccess=True,
                                  message="System doesn't have supported OpenSSL version:" + openssl_version)
        hutil.do_exit(51, 'Install', 'error', '51', openssl_version + 'is not supported.')


def start_omiservice():
    """Start omid via service_control, then verify with 'service omid status';
    raise when the service does not come up."""
    run_cmd('/opt/omi/bin/service_control start')
    code, output, stderr = run_cmd('service omid status')
    if code == 0:
        hutil.log('Service omid is started')
    else:
        raise Exception('Failed to start service omid, status: stdout: {0}, stderr: {1}'.format(output, stderr))


def download_file():
    """Download the configured FileUri and return its local path.

    Uses Azure blob storage when StorageAccountName/Key are set, a plain
    public URI otherwise. Exits with code 51 when FileUri is missing.
    """
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op="EnableInProgress",
                              isSuccess=True,
                              message="Downloading file")
    download_dir = prepare_download_dir(hutil.get_seq_no())
    storage_account_name = get_config('StorageAccountName')
    storage_account_key = get_config('StorageAccountKey')
    file_uri = get_config('FileUri')
    if not file_uri:
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op=Operation.Download,
                                  isSuccess=False,
                                  message="(03000)Argument error, invalid file location")
        hutil.do_exit(51, 'Enable', 'error', '51', '(03000)Argument error, invalid file location')
    if storage_account_name and storage_account_key:
        hutil.log('Downloading file from azure storage...')
        path = download_azure_blob(storage_account_name, storage_account_key, file_uri, download_dir)
        return path
    else:
        hutil.log('Downloading file from external link...')
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op="EnableInProgress",
                                  isSuccess=True,
                                  message="Downloading file from external link...")
        path = download_external_file(file_uri, download_dir)
        return path
op="EnableInProgress", isSuccess=True, message="Downloading file from external link...") path = download_external_file(file_uri, download_dir) return path def download_azure_blob(account_name, account_key, file_uri, download_dir): waagent.AddExtensionEvent(name=ExtensionShortName, op="EnableInProgress", isSuccess=True, message="Downloading from azure blob") try: (blob_name, container_name) = parse_blob_uri(file_uri) host_base = get_host_base_from_uri(file_uri) blob_parent_path = os.path.join(download_dir, os.path.dirname(blob_name)) if not os.path.exists(blob_parent_path): os.makedirs(blob_parent_path) download_path = os.path.join(download_dir, blob_name) blob_service = BlobService(account_name, account_key, host_base=host_base) except Exception as e: waagent.AddExtensionEvent(name=ExtensionShortName, op='DownloadInProgress', isSuccess=True, message='Enable failed with the azure storage error : {0}, stack trace: {1}'.format( str(e), traceback.format_exc())) hutil.error('Failed to enable the extension with error: %s, stack trace: %s' % (str(e), traceback.format_exc())) hutil.do_exit(1, 'Enable', 'error', '1', 'Enable failed: {0}'.format(e)) max_retry = 3 for retry in range(1, max_retry + 1): try: blob_service.get_blob_to_path(container_name, blob_name, download_path) except Exception: hutil.error('Failed to download Azure blob, retry = ' + str(retry) + ', max_retry = ' + str(max_retry)) if retry != max_retry: hutil.log('Sleep 10 seconds') time.sleep(10) else: waagent.AddExtensionEvent(name=ExtensionShortName, op=Operation.Download, isSuccess=False, message="(03303)Failed to download file from Azure Storage") raise Exception('Failed to download azure blob: ' + blob_name) waagent.AddExtensionEvent(name=ExtensionShortName, op=Operation.Download, isSuccess=True, message="(03301)Succeeded to download file from Azure Storage") return download_path def parse_blob_uri(blob_uri): path = get_path_from_uri(blob_uri).strip('/') first_sep = path.find('/') if first_sep == -1: 
def get_path_from_uri(uri):
    """Return the path component of *uri* (e.g. '/container/blob')."""
    uri = urlparse(uri)
    return uri.path


def get_host_base_from_uri(blob_uri):
    """Return the host suffix starting at the first dot, e.g.
    '.blob.core.windows.net'; None when the URI has no netloc."""
    uri = urlparse(blob_uri)
    netloc = uri.netloc
    if netloc is None:
        return None
    return netloc[netloc.find('.'):]


def download_external_file(file_uri, download_dir):
    """Download a public URI into download_dir (3 attempts, 10s apart);
    return the local file path, raise after the last failure."""
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op="EnableInProgress",
                              isSuccess=True,
                              message="Downloading from external file")
    path = get_path_from_uri(file_uri)
    file_name = path.split('/')[-1]
    file_path = os.path.join(download_dir, file_name)
    max_retry = 3
    for retry in range(1, max_retry + 1):
        try:
            download_and_save_file(file_uri, file_path)
            waagent.AddExtensionEvent(name=ExtensionShortName,
                                      op=Operation.Download,
                                      isSuccess=True,
                                      message="(03302)Succeeded to download file from public URI")
            return file_path
        except Exception as e:
            hutil.error('Failed to download public file, retry = ' + str(retry) + ', max_retry = ' + str(max_retry))
            if retry != max_retry:
                hutil.log('Sleep 10 seconds')
                time.sleep(10)
            else:
                waagent.AddExtensionEvent(name=ExtensionShortName,
                                          op=Operation.Download,
                                          isSuccess=False,
                                          message='(03304)Failed to download file from public URI, error : %s, stack trace: %s' % (
                                              str(e), traceback.format_exc()))
                raise Exception('Failed to download public file: ' + file_name)


def download_and_save_file(uri, file_path):
    """Stream *uri* to *file_path* in 1 KiB chunks.

    Fixed: both the URL handle and the destination file were leaked; they are
    now closed deterministically even when a read/write raises.
    """
    buf_size = 1024
    src = urlopen(uri)
    try:
        with open(file_path, 'wb') as dest:
            buf = src.read(buf_size)
            while buf:
                dest.write(buf)
                buf = src.read(buf_size)
    finally:
        src.close()


def prepare_download_dir(seq_no):
    """Create (if needed) and return <cwd>/<DownloadDirectory>/<seq_no>."""
    main_download_dir = os.path.join(os.getcwd(), DownloadDirectory)
    if not os.path.exists(main_download_dir):
        os.makedirs(main_download_dir)
    cur_download_dir = os.path.join(main_download_dir, seq_no)
    if not os.path.exists(cur_download_dir):
        os.makedirs(cur_download_dir)
    return cur_download_dir
def apply_dsc_configuration(config_file_path):
    """Push mode: apply the MOF at config_file_path via StartDscConfiguration.py
    and return the resulting GetDscConfiguration.py dump; raise on failure."""
    cmd = dsc_script_path + '/StartDscConfiguration.py -configurationmof ' + config_file_path
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op='EnableInProgress',
                              isSuccess=True,
                              message='running the cmd: ' + cmd)
    code, output, stderr = run_cmd(cmd)
    if code == 0:
        code, output, stderr = run_cmd(dsc_script_path + '/GetDscConfiguration.py')
        return output
    else:
        error_msg = 'Failed to apply MOF configuration: stdout: {0}, stderr: {1}'.format(output, stderr)
        # Fixed: report the failure event with isSuccess=False (was True),
        # matching apply_dsc_meta_configuration below.
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op=Operation.ApplyMof,
                                  isSuccess=False,
                                  message=error_msg)
        hutil.error(error_msg)
        raise Exception(error_msg)


def apply_dsc_meta_configuration(config_file_path):
    """Pull mode: apply the meta-MOF via SetDscLocalConfigurationManager.py and
    return the resulting LCM settings dump; raise on failure."""
    cmd = dsc_script_path + '/SetDscLocalConfigurationManager.py -configurationmof ' + config_file_path
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op='EnableInProgress',
                              isSuccess=True,
                              message='running the cmd: ' + cmd)
    code, output, stderr = run_cmd(cmd)
    if code == 0:
        code, output, stderr = run_cmd(dsc_script_path + '/GetDscLocalConfigurationManager.py')
        return output
    else:
        error_msg = 'Failed to apply Meta MOF configuration: stdout: {0}, stderr: {1}'.format(output, stderr)
        hutil.error(error_msg)
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op=Operation.ApplyMetaMof,
                                  isSuccess=False,
                                  message="(03107)" + error_msg)
        raise Exception(error_msg)


def get_statusfile_path():
    """Return the path of the handler status file for the current sequence
    number, read from HandlerEnvironment.json in the working directory."""
    seq_no = hutil.get_seq_no()
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op="EnableInProgress",
                              isSuccess=True,
                              message="sequence number is :" + seq_no)
    status_file = None
    handlerEnvironment = None
    handler_env_path = os.path.join(os.getcwd(), 'HandlerEnvironment.json')
    try:
        with open(handler_env_path, 'r') as handler_env_file:
            handler_env_txt = handler_env_file.read()
            handler_env = json.loads(handler_env_txt)
            # The agent may wrap the environment in a one-element list.
            if type(handler_env) == list:
                handler_env = handler_env[0]
            handlerEnvironment = handler_env
    except Exception as e:
        # Fixed: e.message does not exist on Python 3 exceptions.
        hutil.error(str(e))
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op="EnableInProgress",
                                  isSuccess=True,
                                  message='exception in retrieving status_dir error : %s, stack trace: %s' % (
                                      str(e), traceback.format_exc()))
    # NOTE(review): when the read above failed, handlerEnvironment is still
    # None and this subscript raises -- callers catch it; confirm intent.
    status_dir = handlerEnvironment['handlerEnvironment']['statusFolder']
    status_file = status_dir + '/' + seq_no + '.status'
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op="EnableInProgress",
                              isSuccess=True,
                              message="status file path: " + status_file)
    return status_file
def get_status_message_details():
    """Return (status_file_path, agent_id, vm_uuid); status_file_path stays
    None when either identifier is unavailable."""
    agent_id = get_nodeid(nodeid_path)
    vm_uuid = get_vmuuid()
    status_file_path = None
    if vm_uuid is not None and agent_id is not None:
        status_file_path = get_statusfile_path()
    return status_file_path, agent_id, vm_uuid


def update_statusfile(status_filepath, node_id, vmuuid, response):
    """Rewrite the handler .status file, embedding agent/VM/account metadata
    from the agent-service *response* as a substatus entry.

    response -- HTTP response whose deserialized_data carries AccountName,
    ResourceGroupName and SubscriptionId.
    NOTE(review): a None response (heartbeat failure) raises AttributeError
    here, which callers catch -- confirm before adding a guard.
    Returns None when no status file path was supplied.
    """
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op="EnableInProgress",
                              isSuccess=True,
                              message="updating the status file " +
                                      '[statusfile={0}][vmuuid={1}][node_id={2}]'.format(
                                          status_filepath, vmuuid, node_id))
    if status_filepath is None:
        error_msg = "Unable to locate a status file"
        hutil.error(error_msg)
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op="EnableInProgress",
                                  isSuccess=False,
                                  message=error_msg)
        return None
    status_data = None
    if os.path.exists(status_filepath):
        # Fixed: the file handle was opened without ever being closed on the
        # error path; a context manager closes it deterministically. The
        # parsed content is unused but the load validates the existing JSON.
        with open(status_filepath) as status_fp:
            status_data = json.load(status_fp)
    accountName = response.deserialized_data["AccountName"]
    rgName = response.deserialized_data["ResourceGroupName"]
    subId = response.deserialized_data["SubscriptionId"]
    metadatastatus = [{"status": "success",
                       "code": "0",
                       "name": "metadata",
                       "formattedMessage": {"lang": "en-US",
                                            "message": "AgentID=" + node_id + ";VMUUID=" + vmuuid
                                                       + ";AutomationAccountName=" + accountName
                                                       + ";ResourceGroupName=" + rgName
                                                       + ";Subscription=" + subId}}]
    with open(status_filepath, "w") as fp:
        status_file_content = [{"status": {"status": "success",
                                           "formattedMessage": {"lang": "en-US",
                                                                "message": "Enable Succeeded"},
                                           "operation": "Enable",
                                           "code": "0",
                                           "name": "Microsoft.OSTCExtensions.DSCForLinux",
                                           "substatus": metadatastatus},
                                "version": "1.0",
                                "timestampUTC": time.strftime(date_time_format, time.gmtime())}]
        json.dump(status_file_content, fp)
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op="EnableInProgress",
                              isSuccess=True,
                              message="successfully written nodeid and vmuuid")
    waagent.AddExtensionEvent(name=ExtensionName,
                              op="Enable",
                              isSuccess=True,
                              message="successfully executed enable functionality")
"formattedMessage": {"lang": "en-US", "message": "Enable Succeeded"}, "operation": "Enable", "code": "0", "name": "Microsoft.OSTCExtensions.DSCForLinux", "substatus": metadatastatus }, "version": "1.0", "timestampUTC": time.strftime(date_time_format, time.gmtime()) }] json.dump(status_file_content, fp) waagent.AddExtensionEvent(name=ExtensionShortName, op="EnableInProgress", isSuccess=True, message="successfully written nodeid and vmuuid") waagent.AddExtensionEvent(name=ExtensionName, op="Enable", isSuccess=True, message="successfully executed enable functionality") def get_nodeid(file_path): id = None try: if os.path.exists(file_path): with open(file_path) as f: id = f.readline().strip() except Exception as e: error_msg = 'get_nodeid() failed: Unable to open id file {0}'.format(file_path) hutil.error(error_msg) waagent.AddExtensionEvent(name=ExtensionShortName, op="EnableInProgress", isSuccess=False, message=error_msg) return None if not id: error_msg = 'get_nodeid() failed: Empty content in id file {0}'.format(file_path) hutil.error(error_msg) waagent.AddExtensionEvent(name=ExtensionShortName, op="EnableInProgress", isSuccess=False, message=error_msg) return None return id def get_vmuuid(): UUID = None code, output, stderr = run_cmd("sudo dmidecode | grep UUID | sed -e 's/UUID: //'") if code == 0: UUID = output.strip() return UUID def get_omscloudid(): OMSCLOUD_ID = None code, output, stderr = run_cmd("sudo dmidecode | grep 'Tag: 77' | sed -e 's/Asset Tag: //'") if code == 0: OMSCLOUD_ID = output.strip() return OMSCLOUD_ID def check_dsc_configuration(current_config): outputlist = re.split("\n", current_config) for line in outputlist: if re.match(r'ReturnValue=0', line.strip()): return True return False def install_module(file_path): install_package('unzip') cmd = dsc_script_path + '/InstallModule.py ' + file_path code, output, stderr = run_cmd(cmd) waagent.AddExtensionEvent(name=ExtensionShortName, op="InstallModuleInProgress", isSuccess=True, message="Running 
the cmd: " + cmd) if not code == 0: error_msg = 'Failed to install DSC Module ' + file_path + ' stdout: {0}, stderr: {1}'.format(output, stderr) hutil.error(error_msg) waagent.AddExtensionEvent(name=ExtensionShortName, op=Operation.InstallModule, isSuccess=False, message="(03100)" + error_msg) raise Exception(error_msg) waagent.AddExtensionEvent(name=ExtensionShortName, op=Operation.InstallModule, isSuccess=True, message="(03101)Succeeded to install DSC Module") def remove_module(): module_name = get_config('ResourceName') cmd = dsc_script_path + '/RemoveModule.py ' + module_name code, output, stderr = run_cmd(cmd) waagent.AddExtensionEvent(name=ExtensionShortName, op="RemoveModuleInProgress", isSuccess=True, message="Running the cmd: " + cmd) if not code == 0: error_msg = 'Failed to remove DSC Module ' + module_name + ' stdout: {0}, stderr: {1}'.format(output, stderr) hutil.error(error_msg) waagent.AddExtensionEvent(name=ExtensionShortName, op=Operation.RemoveModule, isSuccess=False, message="(03102)" + error_msg) raise Exception(error_msg) waagent.AddExtensionEvent(name=ExtensionShortName, op=Operation.RemoveModule, isSuccess=True, message="(03103)Succeeded to remove DSC Module") def uninstall_package(package_name): waagent.AddExtensionEvent(name=ExtensionShortName, op='InstallInProgress', isSuccess=True, message="uninstalling the package" + package_name) if distro_category == DistroCategory.debian: deb_uninstall_package(package_name) elif distro_category == DistroCategory.redhat or distro_category == DistroCategory.suse: rpm_uninstall_package(package_name) def deb_uninstall_package(package_name): cmd = 'dpkg -P ' + package_name code, output, stderr = run_dpkg_cmd_with_retry(cmd) if code == 0: hutil.log('Package ' + package_name + ' was removed successfully') elif code == DPKGLockedErrorCode: hutil.do_exit(DPKGLockedErrorCode, 'Install', 'error', str(DPKGLockedErrorCode), 'Operation failed because the package manager on the VM is currently locked. 
def rpm_uninstall_package(package_name):
    """Remove an rpm package; raise on failure."""
    cmd = 'rpm -e ' + package_name
    code, output, stderr = run_cmd(cmd)
    if code == 0:
        hutil.log('Package ' + package_name + ' was removed successfully')
    else:
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op='InstallInProgress',
                                  isSuccess=True,
                                  message="failed to remove the package" + package_name)
        raise Exception('Failed to remove package ' + package_name)


def is_dpkg_locked(exit_code, output):
    """
    If dpkg is locked, the output will contain a message similar to 'dpkg
    status database is locked by another process'
    """
    # Fixed: 'exit_code is not 0' compared object identity instead of value
    # (and emits a SyntaxWarning on Python 3.8+); value inequality is meant.
    if exit_code != 0:
        dpkg_locked_search = r'^.*dpkg.+lock.*$'
        dpkg_locked_re = re.compile(dpkg_locked_search, re.M)
        if dpkg_locked_re.search(output):
            return True
    return False


def register_automation(registration_key, registation_url, node_configuration_name, refresh_freq,
                        configuration_mode_freq, configuration_mode):
    """Register this node with Azure Automation DSC via Register.py.

    Returns (0, '') on success, (51, msg) for argument errors, and
    (1, msg) when the registration script fails; never raises for those.
    """
    if (registration_key == '' or registation_url == ''):
        err_msg = "Either the Registration Key or Registration URL is NOT provided"
        hutil.error(err_msg)
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op='RegisterInProgress',
                                  isSuccess=True,
                                  message=err_msg)
        return 51, err_msg
    if configuration_mode != '' and not (
            configuration_mode == 'applyandmonitor' or configuration_mode == 'applyandautocorrect'
            or configuration_mode == 'applyonly'):
        err_msg = "ConfigurationMode: " + configuration_mode + " is not valid."
        hutil.error(err_msg + "It should be one of the values : (ApplyAndMonitor | ApplyAndAutoCorrect | ApplyOnly)")
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op='RegisterInProgress',
                                  isSuccess=True,
                                  message=err_msg)
        return 51, err_msg
    cmd = dsc_script_path + '/Register.py' + ' --RegistrationKey ' + registration_key \
          + ' --ServerURL ' + registation_url
    optional_parameters = ""
    if node_configuration_name != '':
        optional_parameters += ' --ConfigurationName ' + node_configuration_name
    if refresh_freq != '':
        optional_parameters += ' --RefreshFrequencyMins ' + refresh_freq
    if configuration_mode_freq != '':
        optional_parameters += ' --ConfigurationModeFrequencyMins ' + configuration_mode_freq
    if configuration_mode != '':
        optional_parameters += ' --ConfigurationMode ' + configuration_mode
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op="RegisterInProgress",
                              isSuccess=True,
                              message="Registration URL " + registation_url
                                      + "Optional parameters to Registration" + optional_parameters)
    code, output, stderr = run_cmd(cmd + optional_parameters)
    if code != 0:
        error_msg = '(03109)Failed to register with Azure Automation DSC: stdout: {0}, stderr: {1}'.format(output, stderr)
        hutil.error(error_msg)
        waagent.AddExtensionEvent(name=ExtensionShortName,
                                  op=Operation.Register,
                                  isSuccess=False,
                                  message=error_msg)
        return 1, error_msg
    waagent.AddExtensionEvent(name=ExtensionShortName,
                              op=Operation.Register,
                              isSuccess=True,
                              message="(03108)Succeeded to register with Azure Automation DSC")
    return 0, ''
hutil.error(err_msg + "It should be one of the values : (ApplyAndMonitor | ApplyAndAutoCorrect | ApplyOnly)") waagent.AddExtensionEvent(name=ExtensionShortName, op='RegisterInProgress', isSuccess=True, message=err_msg) return 51, err_msg cmd = dsc_script_path + '/Register.py' + ' --RegistrationKey ' + registration_key \ + ' --ServerURL ' + registation_url optional_parameters = "" if node_configuration_name != '': optional_parameters += ' --ConfigurationName ' + node_configuration_name if refresh_freq != '': optional_parameters += ' --RefreshFrequencyMins ' + refresh_freq if configuration_mode_freq != '': optional_parameters += ' --ConfigurationModeFrequencyMins ' + configuration_mode_freq if configuration_mode != '': optional_parameters += ' --ConfigurationMode ' + configuration_mode waagent.AddExtensionEvent(name=ExtensionShortName, op="RegisterInProgress", isSuccess=True, message="Registration URL " + registation_url + "Optional parameters to Registration" + optional_parameters) code, output, stderr = run_cmd(cmd + optional_parameters) if not code == 0: error_msg = '(03109)Failed to register with Azure Automation DSC: stdout: {0}, stderr: {1}'.format(output, stderr) hutil.error(error_msg) waagent.AddExtensionEvent(name=ExtensionShortName, op=Operation.Register, isSuccess=False, message=error_msg) return 1, error_msg waagent.AddExtensionEvent(name=ExtensionShortName, op=Operation.Register, isSuccess=True, message="(03108)Succeeded to register with Azure Automation DSC") return 0, '' if __name__ == '__main__': main() ================================================ FILE: DSC/extension_shim.sh ================================================ #!/usr/bin/env bash # Keeping the default command COMMAND="" PYTHON="" # We are writing logs to error stream in extension_shim.sh as the logs written to output stream are being overriden by HandlerUtil.py. 
# This has been done as part of OMIGOD hotfix

# Default variables for OMI Package Upgrade
REQUIRED_OMI_VERSION="1.7.3.0"
INSTALLED_OMI_VERSION=""
UPGRADED_OMI_VERSION=""
OPENSSL_VERSION=""
OMI_PACKAGE_PREFIX='packages/omi-1.7.3-0.ssl_'
OMI_PACKAGE_PATH=""
OMI_SERVICE_STATE=""

USAGE="$(basename "$0") [-h] [-i|--install] [-u|--uninstall] [-d|--disable] [-e|--enable] [-p|--update]
Program to find the installed python on the box and invoke a Python extension script.
where:
    -h|--help show this help text
    -i|--install install the extension
    -u|--uninstall uninstall the extension
    -d|--disable disable the extension
    -e|--enable enable the extension
    -p|--update update the extension
    -c|--command command to run
example:
# Install usage
$ bash extension_shim.sh -i python ./vmaccess.py -install
# Custom executable python file
$ bash extension_shim.sh -c ""hello.py"" -i python hello.py -install
# Custom executable python file with arguments
$ bash extension_shim.sh -c ""hello.py --install"" python hello.py --install
"

# Locate a python interpreter (python2 preferred over python3) and store its
# name into the variable whose *name* is passed as $1 (indirect assignment).
function find_python(){
    local python_exec_command=$1

    # Check if there is python2 defined.
    if command -v python2 >/dev/null 2>&1 ; then
        eval ${python_exec_command}="python2"
    else
        # Python2 was not found. Searching for Python3 now.
        if command -v python3 >/dev/null 2>&1 ; then
            eval ${python_exec_command}="python3"
        fi
    fi
}

# Map the local openssl release to the suffix used in the bundled OMI
# package file names (100 / 110 / 098).
function get_openssl_version(){
    openssl=`openssl version | awk '{print $2}'`
    if [[ ${openssl} =~ ^1.0.* ]]; then
        OPENSSL_VERSION="100"
    else
        if [[ ${openssl} =~ ^1.1.* ]]; then
            OPENSSL_VERSION="110"
        else
            if [[ ${openssl} =~ ^0.9.8* ]]; then
                OPENSSL_VERSION="098"
            fi
        fi
    fi
}

function start_omiservice(){
    echo "Attempting to start OMI service" >&2
    RESULT=`/opt/omi/bin/service_control start >/dev/null 2>&1`
    RESULT=`service omid status >/dev/null 2>&1`
    if [ $? -eq 0 ]; then
        echo "OMI service succesfully started." >&2
    else
        echo "OMI service could not be started." >&2
    fi
}

function stop_omiservice(){
    echo "Attempting to stop OMI service" >&2
    RESULT=`/opt/omi/bin/service_control stop >/dev/null 2>&1`
    RESULT=`service omid status >/dev/null 2>&1`
    # LSB: status exit code 3 means "service is not running".
    if [ $? -eq 3 ]; then
        echo "OMI service succesfully stopped." >&2
    else
        echo "OMI service could not be stopped." >&2
    fi
}

# Compare two dotted version strings component-wise.
# Returns 0 when equal, 1 when $1 > $2, 2 when $1 < $2.
function compare_versions(){
    if [[ $1 == $2 ]]
    then
        return 0
    fi
    local IFS=.
    local i v1=($1) v2=($2)
    for ((i=0; i<${#v1[@]}; i++))
    do
        if ((${v1[i]} > ${v2[i]}))
        then
            return 1
        fi
        if ((${v1[i]} < ${v2[i]}))
        then
            return 2
        fi
    done
    return 0
}

# Emit ".s" when the compiler-mitigated OMI build is usable on this distro,
# or an empty string otherwise.
function get_compiler_mitigated_flag() {
    OS_NAME=`grep '^NAME' /etc/os-release | tr -d 'NAME=' | tr -d '"' | tr '[:upper:]' '[:lower:]'`
    echo "OS: ${OS_NAME}" >&2
    OS_VERSION=`grep '^VERSION_ID' /etc/os-release | tr -d 'VERSION_ID=' | tr -d '"' | tr '[:upper:]' '[:lower:]'`
    echo "OS VERSION: ${OS_VERSION}" >&2

    # Compiler-mitigated OMI is not supported in the following
    # SLES 11
    # To be enhanced if there are future distros not supporting compiler-mitigated OMI package
    FLAG=""
    if [[ $OS_NAME == sles* && $OS_VERSION == 11* ]]
    then
        FLAG=""
    else
        FLAG=".s"
    fi
    echo $FLAG
}

# Upgrade an already-installed OMI (rpm- or dpkg-managed) to
# REQUIRED_OMI_VERSION when the bundled package is present, restoring the
# previous omid service state afterwards. No-op when OMI is absent or new enough.
function ensure_required_omi_version_exists(){
    # Populate SSL Version
    get_openssl_version
    echo "Checking if OMI is installed. Required OMI version: ${REQUIRED_OMI_VERSION};" >&2
    COMPILER_MITIGATED_VERSION_FLAG=$( get_compiler_mitigated_flag )
    echo "OMI compiler-mitigated flag: (${COMPILER_MITIGATED_VERSION_FLAG})" >&2

    # Check if RPM exists
    if command -v rpm >/dev/null 2>&1 ; then
        echo "Package Manager Type: RPM" >&2
        INSTALLED_OMI_VERSION=`rpm -q --queryformat "%{VERSION}.%{RELEASE}" omi 2>&1`
        if [ -z "$INSTALLED_OMI_VERSION" -o "$INSTALLED_OMI_VERSION" = "package omi is not installed" ]; then
            echo "OMI is not installed on the machine." >&2
        else
            RESULT=`service omid status >/dev/null 2>&1`
            OMI_SERVICE_STATE=$?
            echo "OMI is already installed. Installed OMI version: ${INSTALLED_OMI_VERSION}; OMI Service State: ${OMI_SERVICE_STATE};" >&2
            # Add current running status
            compare_versions ${INSTALLED_OMI_VERSION} ${REQUIRED_OMI_VERSION}
            if [ $? -eq 2 ]; then
                OMI_PACKAGE_PATH="${OMI_PACKAGE_PREFIX}${OPENSSL_VERSION}.ulinux${COMPILER_MITIGATED_VERSION_FLAG}.x64.rpm"
                echo "Installed OMI version is lower than the Required OMI version. Trying to upgrade." >&2
                if [ -f ${OMI_PACKAGE_PATH} ]; then
                    echo "The OMI package exists at ${OMI_PACKAGE_PATH}. Using this to upgrade." >&2
                    stop_omiservice
                    RESULT=`rpm -Uvh ${OMI_PACKAGE_PATH} >/dev/null 2>&1`
                    if [ $? -eq 0 ]; then
                        UPGRADED_OMI_VERSION=`rpm -q --queryformat "%{VERSION}.%{RELEASE}" omi 2>&1`
                        echo "Succesfully upgraded the OMI. Installed: ${INSTALLED_OMI_VERSION}; Required: ${REQUIRED_OMI_VERSION}; Upgraded: ${UPGRADED_OMI_VERSION};" >&2
                    else
                        echo "Failed to upgrade the OMI. Installed: ${INSTALLED_OMI_VERSION}; Required: ${REQUIRED_OMI_VERSION};" >&2
                    fi
                    # Start OMI only if previous state was running
                    if [ $OMI_SERVICE_STATE -eq 0 ]; then
                        start_omiservice
                    fi
                else
                    echo "The OMI package does not exists at ${OMI_PACKAGE_PATH}. Skipping upgrade." >&2
                fi
            else
                echo "Installed OMI version is equal to or greater than the Required OMI version. No action needed." >&2
            fi
        fi
        INSTALLED_OMI_VERSION=`rpm -q --queryformat "%{VERSION}.%{RELEASE}" omi 2>&1`
        RESULT=`service omid status >/dev/null 2>&1`
        OMI_SERVICE_STATE=$?
        echo "OMI upgrade is complete. Installed OMI version: ${INSTALLED_OMI_VERSION}; OMI Service State: ${OMI_SERVICE_STATE};" >&2
    else
        # Check if DPKG exists
        if command -v dpkg >/dev/null 2>&1 ; then
            echo "Package Manager Type: DPKG" >&2
            INSTALLED_OMI_VERSION=`dpkg -s omi 2>&1 | grep Version: | awk '{print $2}'`
            if [ -z "$INSTALLED_OMI_VERSION" -o "$INSTALLED_OMI_VERSION" = "package omi is not installed" ]; then
                echo "OMI is not installed on the machine." >&2
            else
                RESULT=`service omid status >/dev/null 2>&1`
                OMI_SERVICE_STATE=$?
                echo "OMI is already installed. Installed OMI version: ${INSTALLED_OMI_VERSION}; OMI Service State: ${OMI_SERVICE_STATE};" >&2
                compare_versions ${INSTALLED_OMI_VERSION} ${REQUIRED_OMI_VERSION}
                if [ $? -eq 2 ]; then
                    OMI_PACKAGE_PATH="${OMI_PACKAGE_PREFIX}${OPENSSL_VERSION}.ulinux${COMPILER_MITIGATED_VERSION_FLAG}.x64.deb"
                    echo "Installed OMI version is lower than the Required OMI version. Trying to upgrade." >&2
                    if [ -f ${OMI_PACKAGE_PATH} ]; then
                        echo "The OMI package exists at ${OMI_PACKAGE_PATH}. Using this to upgrade." >&2
                        stop_omiservice
                        RESULT=`dpkg -i --force-confold --force-confdef --refuse-downgrade ${OMI_PACKAGE_PATH} >/dev/null 2>&1`
                        if [ $? -eq 0 ]; then
                            UPGRADED_OMI_VERSION=`dpkg -s omi 2>&1 | grep Version: | awk '{print $2}'`
                            echo "Succesfully upgraded the OMI. Installed: ${INSTALLED_OMI_VERSION}; Required: ${REQUIRED_OMI_VERSION}; Upgraded: ${UPGRADED_OMI_VERSION};" >&2
                        else
                            echo "Failed to upgrade the OMI. Installed: ${INSTALLED_OMI_VERSION}; Required: ${REQUIRED_OMI_VERSION};" >&2
                        fi
                        # Start OMI only if previous state was running
                        if [ $OMI_SERVICE_STATE -eq 0 ]; then
                            start_omiservice
                        fi
                    else
                        echo "The OMI package does not exists at ${OMI_PACKAGE_PATH}. Skipping upgrade." >&2
                    fi
                else
                    echo "Installed OMI version is equal to or greater than the Required OMI version. No action needed." >&2
                fi
            fi
            INSTALLED_OMI_VERSION=`dpkg -s omi 2>&1 | grep Version: | awk '{print $2}'`
            RESULT=`service omid status >/dev/null 2>&1`
            OMI_SERVICE_STATE=$?
            echo "OMI upgrade is complete. Installed OMI version: ${INSTALLED_OMI_VERSION}; OMI Service State: ${OMI_SERVICE_STATE};" >&2
        fi
    fi
}

# Transform long options to short ones for getopts support (getopts doesn't support long args)
for arg in "$@"; do
  shift
  case "$arg" in
    "--help") set -- "$@" "-h" ;;
    "--install") set -- "$@" "-i" ;;
    "--update") set -- "$@" "-p" ;;
    "--enable") set -- "$@" "-e" ;;
    "--disable") set -- "$@" "-d" ;;
    "--uninstall") set -- "$@" "-u" ;;
    *) set -- "$@" "$arg"
  esac
done

# $arg is the last argument processed above; it is empty only when the
# script was invoked with no arguments at all.
if [ -z "$arg" ]
then
    echo "$USAGE" >&2
    exit 1
fi

# Get the arguments
while getopts "iudephc:?" o; do
    case "${o}" in
        h|\?)
            echo "$USAGE"
            exit 0
            ;;
        i)
            operation="-install"
            ;;
        u)
            operation="-uninstall"
            ;;
        d)
            operation="-disable"
            ;;
        e)
            operation="-enable"
            ;;
        p)
            operation="-update"
            ;;
        c)
            COMMAND="$OPTARG"
            ;;
        *)
            echo "$USAGE" >&2
            exit 1
            ;;
    esac
done
shift $((OPTIND-1))

# Ensure OMI package if exists is of required version.
ensure_required_omi_version_exists

# If find_python is not able to find a python installed, $PYTHON will be null.
find_python PYTHON
if [ -z "$PYTHON" ]; then
   echo "No Python interpreter found on the box" >&2
   exit 51 # Not Supported
else
   # FIX: this was `` `${PYTHON} --version` `` — the backticks captured the
   # version text and then tried to execute it as a command. Run it directly.
   ${PYTHON} --version
fi

${PYTHON} ${COMMAND} ${operation}
# DONE

================================================
FILE: DSC/httpclient.py
================================================
#!/usr/bin/env python2
#
# Copyright (C) Microsoft Corporation, All rights reserved.
"""HttpClient base class.""" import os import sys import serializerfactory class HttpClient: """Base class to provide common attributes and functionality to all HttpClient implementation.""" ACCEPT_HEADER_KEY = "Accept" CONTENT_TYPE_HEADER_KEY = "Content-Type" CONNECTION_HEADER_KEY = "Connection" USER_AGENT_HEADER_KEY = "User-Agent" APP_JSON_HEADER_VALUE = "application/json" KEEP_ALIVE_HEADER_VALUE = "keep-alive" GET = "GET" POST = "POST" PUT = "PUT" DELETE = "DELETE" def __init__(self, cert_path, key_path, insecure=False, proxy_configuration=None): self.cert_path = cert_path self.key_path = key_path self.insecure = insecure self.proxy_configuration = proxy_configuration # validate presence of cert/key in case they were removed after process creation if (cert_path is not None and not os.path.isfile(self.cert_path)) or \ (key_path is not None and not os.path.isfile(self.key_path)): print(cert_path) raise Exception("Invalid certificate or key file path.") self.default_headers = {self.ACCEPT_HEADER_KEY: self.APP_JSON_HEADER_VALUE, self.CONNECTION_HEADER_KEY: self.KEEP_ALIVE_HEADER_VALUE } self.json = serializerfactory.get_serializer(sys.version_info) @staticmethod def merge_headers(client_headers, request_headers): """Merges client_headers and request_headers into a single dictionary. If a request_header key is also present in the client_headers, the request_header value will override the client_header one. Args: client_headers : dictionary, the default client's headers. request_headers : dictionary, request specific headers. Returns: A dictionary containing a set of both the client_headers and the request_headers """ if request_headers is not None: client_headers.update(request_headers.copy()) else: request_headers = client_headers.copy() return request_headers def get(self, url, headers=None): """Issues a GET request to the provided url using the provided headers. Args: url : string , the URl. 
headers : dictionary, contains the headers key value pair (defaults to None). Returns: A RequestResponse """ pass def post(self, url, headers=None, data=None): """Issues a POST request to the provided url using the provided headers. Args: url : string , the URl. headers : dictionary, contains the headers key value pair (defaults to None). data : dictionary, contains the non-serialized request body (defaults to None). Returns: A RequestResponse """ pass def put(self, url, headers=None, data=None): """Issues a PUT request to the provided url using the provided headers. Args: url : string , the URl. headers : dictionary, contains the headers key value pair (defaults to None). data : dictionary, contains the non-serialized request body (defaults to None). Returns: A RequestResponse """ pass def delete(self, url, headers=None, data=None): """Issues a DELETE request to the provided url using the provided headers. Args: url : string , the URl. headers : dictionary, contains the headers key value pair (defaults to None). data : dictionary, contains the non-serialized request body (defaults to None). Returns: A RequestResponse """ pass class RequestResponse: """Encapsulates all request response for http clients. Will also deserialize the response when the raw response data is deserializable. """ def __init__(self, status_code, raw_response_data=None): self.status_code = int(status_code) self.raw_data = raw_response_data self.json = serializerfactory.get_serializer(sys.version_info) if raw_response_data is not None: try: self.deserialized_data = self.json.loads(self.raw_data) except ValueError: self.deserialized_data = None ================================================ FILE: DSC/httpclientfactory.py ================================================ #!/usr/bin/env python2 # # Copyright (C) Microsoft Corporation, All rights reserved. 
import os from curlhttpclient import CurlHttpClient PY_MAJOR_VERSION = 0 PY_MINOR_VERSION = 1 PY_MICRO_VERSION = 2 class HttpClientFactory: """Factory which returns the appropriate HttpClient based on the provided python version. Targets : [2.4.0 - 2.7.9[ : CurlHttpclient [2.7.9 - 2.7.9+ : Urllib2Httpclient 3.0+ : Urllib3Httpclient This is due to the lack of built-in strict certificate verification prior to 2.7.9. The ssl module was also unavailable for [2.4.0 - 2.6.0[. """ def __init__(self, cert, key, insecure=False): self.cert = cert self.key = key self.insecure = insecure self.proxy_configuration = None def create_http_client(self, version_info): """Create a new instance of the appropriate HttpClient. Args: version_info : array, the build-in python version_info array. insecure : bool, when set to True, httpclient wil bypass certificate verification. Returns: An instance of CurlHttpClient if the installed Python version is below 2.7.9 An instance of Urllib2 if the installed Python version is or is above 2.7.9 """ if version_info[PY_MAJOR_VERSION] == 3: from urllib3httpclient import Urllib3HttpClient return Urllib3HttpClient(self.cert, self.key, self.insecure, self.proxy_configuration) elif version_info[PY_MAJOR_VERSION] == 2 and version_info[PY_MINOR_VERSION] < 7: from urllib2httpclient import Urllib2HttpClient return CurlHttpClient(self.cert, self.key, self.insecure, self.proxy_configuration) elif version_info[PY_MAJOR_VERSION] == 2 and version_info[PY_MINOR_VERSION] <= 7 and version_info[ PY_MICRO_VERSION] < 9: from urllib2httpclient import Urllib2HttpClient return CurlHttpClient(self.cert, self.key, self.insecure, self.proxy_configuration) else: from urllib2httpclient import Urllib2HttpClient return Urllib2HttpClient(self.cert, self.key, self.insecure, self.proxy_configuration) ================================================ FILE: DSC/manifest.xml ================================================ Microsoft.OSTCExtensions DSCForLinux 2.71.1.0 VmRole Microsoft 
Azure DSC Extension for Linux Virtual Machines true https://github.com/Azure/azure-linux-extensions/blob/master/LICENSE-2_0.txt http://www.microsoft.com/privacystatement/en-us/OnlineServices/Default.aspx https://github.com/Azure/azure-linux-extensions true Linux Microsoft ================================================ FILE: DSC/serializerfactory.py ================================================ #!/usr/bin/env python2 # # Copyright (C) Microsoft Corporation, All rights reserved. """Serializer factory.""" PY_MAJOR_VERSION = 0 PY_MINOR_VERSION = 1 def get_serializer(version_info): """Returns the appropriate serializer module based on version_info. Python 2.6 and 2.6+ have the json module built-in, other version (2.6-) have to rely on the ancestral implementation (simplejson) which is included under the worker package. An instance of simplejson module if the installed Python version is below 2.6 An instance of json module if the installed Python version is or is above 2.6 Args: version_info : array, the build-in python version_info Returns: Json module """ if version_info[PY_MAJOR_VERSION] == 2 and version_info[PY_MINOR_VERSION] < 6: import simplejson as json else: import json return json ================================================ FILE: DSC/subprocessfactory.py ================================================ #!/usr/bin/env python2 # # Copyright (C) Microsoft Corporation, All rights reserved. 
"""Process factory which returns a process enforcing the preexec_fn.""" try: import ctypes # See : http://man7.org/linux/man-pages/man2/prctl.2.html # See : http://lxr.free-electrons.com/source/include/uapi/linux/prctl.h libc = ctypes.CDLL("libc.so.6") PR_SET_PDEATHSIG = 1 def set_process_death_signal(death_signal): libc.prctl(PR_SET_PDEATHSIG, death_signal) except ImportError: # TODO(dalbe): Trace pass except: # For test to run on windows # TODO(dalbe): Trace pass import os import signal import subprocess import sys CTYPES_MODULE_NAME = "ctypes" def create_subprocess(cmd, env=None, stdout=None, stderr=None, cwd=None): """Creates a process forcing and sets the SIGTERM signal handler using Ctypes (when available). Else creates a process based on the pipe_output argument. Ctypes are only available in 2.5+ so processes create in python 2.4 won't die if their parent process dies. Args: cmd : string , the cmd to execute. env : dictonary(string) , the process level environment variable. stdout : boolean , sets the stdout to subprocess.PIPE when True, else stdout is left untouched. stderr : boolean , sets the stderr to subprocess.PIPE when True, else stdout is left untouched. Returns: The process object. """ if CTYPES_MODULE_NAME not in sys.modules or os.name.lower() == "nt": return subprocess.Popen(cmd, env=env, stdout=stdout, stderr=stderr, cwd=cwd) else: return subprocess.Popen(cmd, env=env, stdout=stdout, stderr=stderr, cwd=cwd, preexec_fn=set_process_death_signal(signal.SIGTERM)) ================================================ FILE: DSC/test/MockUtil.py ================================================ #!/usr/bin/env python # #CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. class MockUtil: def __init__(self, test): self.test = test def get_log_dir(self): return "/tmp" def log(self, msg): print(msg) def error(self, msg): print(msg) def get_seq_no(self): return "0" def do_parse_context(self, operation): return "0" def do_status_report(self, operation, status, status_code, message): self.test.assertNotEqual(None, message) def do_exit(self,exit_code,operation,status,code,message): self.test.assertNotEqual(None, message) ================================================ FILE: DSC/test/env.py ================================================ #!/usr/bin/env python # #CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import sys
import os

#append installer directory to sys.path
root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root)

# When a handler manifest is present, switch into the deployed extension
# directory under /var/lib/waagent so relative test fixtures resolve.
manifestFile = os.path.join(root, 'HandlerManifest.json')
if os.path.exists(manifestFile):
    import json
    jsonData = open(manifestFile)
    manifest = json.load(jsonData)
    jsonData.close()
    # Extension directory name is "<name>-<version>" from the manifest.
    extName="{0}-{1}".format(manifest[0]["name"], manifest[0]["version"])
    print("Start test: %s" % extName)
    extDir=os.path.join("/var/lib/waagent", extName)
    if os.path.isdir(extDir):
        os.chdir(extDir)
        print("Switching to dir: %s" % os.getcwd())

================================================
FILE: DSC/test/mof/azureautomation.df.meta.mof
================================================
instance of MSFT_WebDownloadManager as $MSFT_WebDownloadManager1ref
{
ResourceID = "[ConfigurationRepositoryWeb]AzureAutomationDSC";
SourceInfo = "C:\\OaaS-RegistrationMetaConfig2.ps1::20::9::ConfigurationRepositoryWeb";
RegistrationKey = "TsyfxalOa7P4lNWIqAVrEWhdiRNGx+2A2WYZEE1wR+lXH5snJojB9pONu79iWZVeviC/sPylSGZQlVsmCUPGOQ==";
ServerURL = "https://oaasagentsvcdf.test.azure-automation.net/accounts/a654020d-4757-41cd-bbf2-528ef2fefacb";
};
instance of MSFT_WebResourceManager as $MSFT_WebResourceManager1ref
{
SourceInfo = "C:\\OaaS-RegistrationMetaConfig2.ps1::27::9::ResourceRepositoryWeb";
ServerURL = "https://oaasagentsvcdf.test.azure-automation.net/accounts/a654020d-4757-41cd-bbf2-528ef2fefacb";
ResourceID = "[ResourceRepositoryWeb]AzureAutomationDSC";
RegistrationKey = "TsyfxalOa7P4lNWIqAVrEWhdiRNGx+2A2WYZEE1wR+lXH5snJojB9pONu79iWZVeviC/sPylSGZQlVsmCUPGOQ==";
};
instance of MSFT_WebReportManager as $MSFT_WebReportManager1ref
{
SourceInfo = "C:\\OaaS-RegistrationMetaConfig2.ps1::34::9::ReportServerWeb";
ServerURL = "https://oaasagentsvcdf.test.azure-automation.net/accounts/a654020d-4757-41cd-bbf2-528ef2fefacb";
ResourceID = "[ReportServerWeb]AzureAutomationDSC";
RegistrationKey =
"TsyfxalOa7P4lNWIqAVrEWhdiRNGx+2A2WYZEE1wR+lXH5snJojB9pONu79iWZVeviC/sPylSGZQlVsmCUPGOQ=="; }; instance of MSFT_DSCMetaConfiguration as $MSFT_DSCMetaConfiguration1ref { RefreshMode = "Pull"; AllowModuleOverwrite = False; ActionAfterReboot = "ContinueConfiguration"; RefreshFrequencyMins = 30; RebootNodeIfNeeded = False; ConfigurationModeFrequencyMins = 15; ConfigurationMode = "ApplyAndMonitor"; ResourceModuleManagers = { $MSFT_WebResourceManager1ref }; ReportManagers = { $MSFT_WebReportManager1ref }; ConfigurationDownloadManagers = { $MSFT_WebDownloadManager1ref }; }; instance of OMI_ConfigurationDocument { Version="2.0.0"; MinimumCompatibleVersion = "2.0.0"; CompatibleVersionAdditionalProperties= { "MSFT_DSCMetaConfiguration:StatusRetentionTimeInDays" }; Author="azureautomation"; GenerationDate="04/17/2015 11:41:09"; GenerationHost="azureautomation-01"; Name="RegistrationMetaConfig"; }; ================================================ FILE: DSC/test/status/0.status ================================================ [{ "version": 1.0, "timestampUTC": "", "status" : { "name": "", "operation": "", "configurationAppliedTime": "", "status": "", "code": 0, "message": { "id": "id of the localized resource", "params": [ "MyParam0", "MyParam1" ] }, "formattedMessage": { "lang": "Lang[-locale]", "message": "formatted user message" }, "substatus": [{ "name": "", "status": "", "code": 0, "Message": { "id": "id of the localized resource", "params": [ "MyParam0", "MyParam1" ] }, "FormattedMessage": { "Lang": "Lang[-locale]", "Message": "formatted user message" } }] } }] ================================================ FILE: DSC/test/test_apply_meta_mof.py ================================================ #!/usr/bin/env python # # DSC Extension For Linux # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest import env import dsc import os import platform from Utils.WAAgentUtil import waagent from MockUtil import MockUtil waagent.LoggerInit('/tmp/test.log','/dev/null') class TestApplyMof(unittest.TestCase): def test_apply_mof(self): dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) dsc.install_dsc_packages() dsc.start_omiservice() config = dsc.apply_dsc_meta_configuration('mof/dscnode.nxFile.meta.mof') self.assertTrue('ReturnValue=0' in config) if __name__ == '__main__': unittest.main() ================================================ FILE: DSC/test/test_apply_mof.py ================================================ #!/usr/bin/env python # # DSC Extension For Linux # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import unittest
import env
import dsc
import os
import platform
from Utils.WAAgentUtil import waagent
from MockUtil import MockUtil

waagent.LoggerInit('/tmp/test.log','/dev/null')

class TestApplyMof(unittest.TestCase):
    """Integration test: push-mode meta config plus a regular configuration.

    The applied nxFile configuration is expected to create /tmp/dsctest.
    """
    def test_apply_mof(self):
        dsc.distro_category = dsc.get_distro_category()
        dsc.hutil = MockUtil(self)
        dsc.install_dsc_packages()
        dsc.start_omiservice()
        # Switch the node to push mode, then apply the actual configuration.
        config = dsc.apply_dsc_meta_configuration('mof/dscnode.nxFile.meta.push.mof')
        dsc.apply_dsc_configuration('mof/localhost.nxFile.mof')
        # The nxFile resource in localhost.nxFile.mof creates this path.
        self.assertTrue(os.path.exists('/tmp/dsctest'))

if __name__ == '__main__':
    unittest.main()

================================================
FILE: DSC/test/test_compare_pkg_version.py
================================================
#!/usr/bin/env python
#
# DSC Extension For Linux
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import env
import dsc
import os
import platform
from Utils.WAAgentUtil import waagent
from MockUtil import MockUtil

waagent.LoggerInit('/tmp/test.log','/dev/null')

class Dummy(object):
    # Bare attribute holder; the tests attach a .log callable to it.
    pass

class CompareRPMPackageVersions(unittest.TestCase):
    """Unit tests for dsc.compare_pkg_version.

    compare_pkg_version appears to return 1 when the installed version string
    is >= the given major/minor/patch/build and 0 otherwise.
    """
    def test_with_equal_version(self):
        dsc.distro_category = dsc.get_distro_category()
        dsc.hutil = Dummy()
        dsc.hutil.log = waagent.Log
        output = dsc.compare_pkg_version('1.1.1.294', 1, 1, 1, 294)
        self.assertEqual(1, output)

    def test_with_higher_version(self):
        dsc.distro_category = dsc.get_distro_category()
        dsc.hutil = Dummy()
        dsc.hutil.log = waagent.Log
        output = dsc.compare_pkg_version('1.2.0.35', 1, 1, 1, 294)
        self.assertEqual(1, output)

    def test_with_lower_version(self):
        dsc.distro_category = dsc.get_distro_category()
        dsc.hutil = Dummy()
        dsc.hutil.log = waagent.Log
        output = dsc.compare_pkg_version('1.0.4.35', 1, 1, 1, 294)
        self.assertEqual(0, output)

if __name__ == '__main__':
    unittest.main()

================================================
FILE: DSC/test/test_download_file.py
================================================
#!/usr/bin/env python
#
# DSC Extension For Linux
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import env
import dsc
import os
from Utils.WAAgentUtil import waagent
from MockUtil import MockUtil

waagent.LoggerInit('/tmp/test.log','/dev/null')

class TestDownloadFile(unittest.TestCase):
    """Integration test: downloading an external mof over HTTPS into /tmp.

    Requires network access; passes as long as no exception is raised.
    """
    def test_download_file(self):
        dsc.hutil = MockUtil(self)
        dsc.download_external_file('https://raw.githubusercontent.com/balukambala/azure-linux-extensions/master/DSC/test/mof/dscnode.nxFile.meta.mof', '/tmp')

if __name__ == '__main__':
    unittest.main()

================================================
FILE: DSC/test/test_node_extension_properties.py
================================================
#!/usr/bin/env python
#
# DSC Extension For Linux
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest import env import dsc import os import json from Utils.WAAgentUtil import waagent from MockUtil import MockUtil waagent.LoggerInit('/tmp/test.log','/dev/null') class TestNodeExtensionProperties(unittest.TestCase): def test_properties_for_pull(self): dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) dsc.install_dsc_packages() dsc.start_omiservice() config = dsc.apply_dsc_meta_configuration('mof/dscnode.nxFile.meta.mof') self.assertTrue('ReturnValue=0' in config) content = dsc.construct_node_extension_properties(config, "upgrade") data = json.dumps(content) self.assertTrue('OMSCloudId' in data, "OMSCLoudID doesn't exist") #self.assertTrue('ExtHandlerVersion' in extensionInformation, "ExtHandlerVersion doesn't exist") #self.assertEqual('Microsoft.OSTCExtensions.DSCForLinux', extensionInformation['ExtHandlerName']) def test_send_request_to_pullserver(self): dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) dsc.install_dsc_packages() dsc.start_omiservice() config = dsc.apply_dsc_meta_configuration('mof/azureautomation.df.meta.mof') self.assertTrue('ReturnValue=0' in config) response = dsc.send_heart_beat_msg_to_agent_service("install") self.assertEqual(response.status_code, 200) def test_push_request_properties(self): dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) dsc.install_dsc_packages() dsc.start_omiservice() config = dsc.apply_dsc_meta_configuration('mof/dscnode.nxFile.meta.push.mof') self.assertTrue('ReturnValue=0' in config) response = dsc.send_heart_beat_msg_to_agent_service("install") self.assertIsNone(response) def test_update_node_properties(self): dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) response = dsc.update() self.assertIsNone(response) if __name__ == '__main__': unittest.main() ================================================ FILE: DSC/test/test_register.py ================================================ #!/usr/bin/env python # 
# DSC Extension For Linux # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest import env import dsc import os from Utils.WAAgentUtil import waagent from MockUtil import MockUtil waagent.LoggerInit('/tmp/test.log','/dev/null') class TestRegister(unittest.TestCase): def test_register_without_registration_info(self): print "Register test case with invalid Registration url and Registration key" dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) dsc.install_dsc_packages() dsc.start_omiservice() exit_code, output = dsc.register_automation('','','','','','') self.assertEqual(exit_code, 51) def test_register_invalid_configuration_mode(self): print "Register test case with invalid configuration mode" dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) dsc.install_dsc_packages() dsc.start_omiservice() exit_code, output = dsc.register_automation('somekey','http://dummy','','','','some') self.assertEqual(exit_code, 51) def test_register(self): print "Register test case with valid parameters" dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) dsc.install_dsc_packages() dsc.start_omiservice() exit_code, output = dsc.register_automation('somekey','http://dummy','test.localhost.mof','15','30','applyandmonitor') self.assertEqual(exit_code, 0) if __name__ == '__main__': unittest.main() ================================================ FILE: DSC/test/test_status_update.py 
================================================ #!/usr/bin/env python # # DSC Extension For Linux # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest import env import dsc import os import json from Utils.WAAgentUtil import waagent from MockUtil import MockUtil waagent.LoggerInit('/tmp/test.log','/dev/null') class TestStatusUpdate(unittest.TestCase): def verify_nodeid_vmuuid(self, status_file): self.assertTrue(os.path.exists(status_file), "file exists") if os.path.exists(status_file): jsonData = open(status_file) status_data = json.load(jsonData)[0] self.assertTrue('status' in status_data, "status doesn't exists") substatusArray = status_data['status']['substatus'] isMetaDataFound = False metasubstatus = None if 'metadata' in substatusArray[0].viewvalues(): metasubstatus = substatusArray[0] self.assertTrue('formattedMessage' in metasubstatus, "formattedMessage doesn't exists") formatedMessage = metasubstatus['formattedMessage'] self.assertTrue('message' in formatedMessage, "message doesn't exists") self.assertTrue('AgentID' in formatedMessage['message'], "AgentID doesn't exists") def test_vmuuid(self): dsc.hutil = MockUtil(self) vmuuid = dsc.get_vmuuid() self.assertTrue(vmuuid is not None, "vm uuid is none") def test_nodeid_with_dsc(self): dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) dsc.install_dsc_packages() dsc.start_omiservice() config = 
dsc.apply_dsc_meta_configuration('mof/dscnode.nxFile.meta.push.mof') nodeid = dsc.get_nodeid('/etc/opt/omi/conf/omsconfig/agentid') self.assertTrue(nodeid is not None, "nodeid is none") def test_nodeid_without_dsc(self): dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) nodeid = dsc.get_nodeid('/etc/opt/omi/conf/omsconfig/agentid1') self.assertTrue(nodeid is None, "nodeid is not none") def test_statusfile_update(self): status_file = 'status/0.status' dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) class cresponse: deserialized_data= { "AccountName" : "test", "ResourceGroupName" : "rgName", "SubscriptionId" : "testsubid" } dsc.update_statusfile(status_file, '123','345', cresponse) self.verify_nodeid_vmuuid(status_file) def test_is_statusfile_update_idempotent(self): status_file = 'status/0.status' dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) class cresponse: deserialized_data= { "AccountName" : "test", "ResourceGroupName" : "rgName", "SubscriptionId" : "testsubid" } dsc.update_statusfile(status_file, '123','345', cresponse) dsc.update_statusfile(status_file, '123','345', cresponse) self.verify_nodeid_vmuuid(status_file) def test_is_statusfile_update_register(self): status_file = 'status/0.status' dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) dsc.install_dsc_packages() dsc.start_omiservice() exit_code, output = dsc.register_automation('somekey','http://dummy','','','','') self.verify_nodeid_vmuuid(status_file) def test_is_statusfile_update_pull(self): status_file = 'status/0.status' dsc.distro_category = dsc.get_distro_category() dsc.hutil = MockUtil(self) dsc.install_dsc_packages() dsc.start_omiservice() config = dsc.apply_dsc_meta_configuration('mof/dscnode.nxFile.meta.mof') self.assertTrue('ReturnValue=0' in config) self.verify_nodeid_vmuuid(status_file) def test_is_statusfile_update_push(self): status_file = 'status/0.status' dsc.distro_category = 
dsc.get_distro_category() dsc.hutil = MockUtil(self) dsc.install_dsc_packages() dsc.start_omiservice() config = dsc.apply_dsc_meta_configuration('mof/dscnode.nxFile.meta.push.mof') dsc.apply_dsc_configuration('mof/localhost.nxFile.mof') self.assertTrue(os.path.exists('/tmp/dsctest')) self.verify_nodeid_vmuuid(status_file) if __name__ == '__main__': unittest.main() ================================================ FILE: DSC/urllib2httpclient.py ================================================ #!/usr/bin/env python2 # # Copyright (C) Microsoft Corporation, All rights reserved. """Urllib2 HttpClient.""" try: from http.client import HTTPSConnection except ImportError: from httplib import HTTPSConnection import socket import time import traceback import sys try: from urllib.parse import urlparse, urlencode from urllib.request import urlopen, Request, HTTPSHandler, build_opener, ProxyHandler from urllib.error import HTTPError except ImportError: from urlparse import urlparse from urllib import urlencode from urllib2 import urlopen, Request, HTTPError, HTTPSHandler, build_opener, ProxyHandler from httpclient import * PY_MAJOR_VERSION = 0 PY_MINOR_VERSION = 1 PY_MICRO_VERSION = 2 SSL_MODULE_NAME = "ssl" # On some system the ssl module might be missing try: import ssl except ImportError: ssl = None class HttpsClientHandler(HTTPSHandler): """Https handler to enable attaching cert/key to request. Also used to disable strict cert verification for testing. 
""" def __init__(self, cert_path, key_path, insecure=False): self.cert_path = cert_path self.key_path = key_path ssl_context = None if insecure and SSL_MODULE_NAME in sys.modules and (sys.version_info[PY_MAJOR_VERSION] == 2 and sys.version_info[PY_MINOR_VERSION] >= 7 and sys.version_info[PY_MICRO_VERSION] >= 9): ssl_context = ssl.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE HTTPSHandler.__init__(self, context=ssl_context) # Context can be None here def https_open(self, req): return self.do_open(self.get_https_connection, req, context=self._context) def get_https_connection(self, host, context=None, timeout=180): """urllib2's AbstractHttpHandler will invoke this method with the host/timeout parameter. See urllib2's AbstractHttpHandler for more details. Args: host : string , the host. context : ssl_context , the ssl context. timeout : int , the timeout value in seconds. Returns: An HttpsConnection """ socket.setdefaulttimeout(180) if self.cert_path is None or self.key_path is None: return HTTPSConnection(host, timeout=timeout, context=context) else: return HTTPSConnection(host, cert_file=self.cert_path, key_file=self.key_path, timeout=timeout, context=context) def request_retry_handler(func): def decorated_func(*args, **kwargs): max_retry_count = 3 for iteration in range(0, max_retry_count, 1): try: ret = func(*args, **kwargs) return ret except Exception as exception: if iteration >= max_retry_count - 1: raise RetryAttemptExceededException(traceback.format_exc()) elif SSL_MODULE_NAME in sys.modules: if type(exception).__name__ == 'SSLError': time.sleep(5 + iteration) continue raise exception return decorated_func class Urllib2HttpClient(HttpClient): """Urllib2 http client. Inherits from HttpClient. Targets: [2.7.9 - 2.7.9+] only due to the lack of strict certificate verification prior to this version. Implements the following method common to all classes inheriting HttpClient. 
get (url, headers) post (url, headers, data) """ def __init__(self, cert_path, key_path, insecure=False, proxy_configuration=None): HttpClient.__init__(self, cert_path, key_path, insecure, proxy_configuration) @request_retry_handler def issue_request(self, url, headers, method=None, data=None): """Issues a GET request to the provided url and using the provided headers. Args: url : string , the url. headers : dictionary, contains the headers key value pair. data : string , contains the serialized request body. Returns: A RequestResponse :param method: """ https_handler = HttpsClientHandler(self.cert_path, self.key_path, self.insecure) opener = build_opener(https_handler) if self.proxy_configuration is not None: proxy_handler = ProxyHandler({'http': self.proxy_configuration, 'https': self.proxy_configuration}) opener.add_handler(proxy_handler) req = Request(url, data=data, headers=headers) req.get_method = lambda: method response = opener.open(req, timeout=30) opener.close() https_handler.close() return response def get(self, url, headers=None): """Issues a GET request to the provided url and using the provided headers. Args: url : string , the url. headers : dictionary, contains the headers key value pair. Returns: An http_response """ headers = self.merge_headers(self.default_headers, headers) try: response = self.issue_request(url, headers=headers, method=self.GET) except HTTPError: exception_type, error = sys.exc_info()[:2] return RequestResponse(error.code) return RequestResponse(response.getcode(), response.read()) def post(self, url, headers=None, data=None): """Issues a POST request to the provided url and using the provided headers. Args: url : string , the url. headers : dictionary, contains the headers key value pair. data : dictionary, contains the non-serialized request body. 
Returns: A RequestResponse """ headers = self.merge_headers(self.default_headers, headers) if data is None: serial_data = "" else: serial_data = self.json.dumps(data) headers.update({self.CONTENT_TYPE_HEADER_KEY: self.APP_JSON_HEADER_VALUE}) try: response = self.issue_request(url, headers=headers, method=self.POST, data=serial_data) except HTTPError: exception_type, error = sys.exc_info()[:2] return RequestResponse(error.code) return RequestResponse(response.getcode(), response.read()) def put(self, url, headers=None, data=None): """Issues a PUT request to the provided url and using the provided headers. Args: url : string , the url. headers : dictionary, contains the headers key value pair. data : dictionary, contains the non-serialized request body. Returns: A RequestResponse """ headers = self.merge_headers(self.default_headers, headers) if data is None: serial_data = "" else: serial_data = self.json.dumps(data) headers.update({self.CONTENT_TYPE_HEADER_KEY: self.APP_JSON_HEADER_VALUE}) try: response = self.issue_request(url, headers=headers, method=self.PUT, data=serial_data) except HTTPError: exception_type, error = sys.exc_info()[:2] return RequestResponse(error.code) return RequestResponse(response.getcode(), response.read()) def delete(self, url, headers=None, data=None): """Issues a DELETE request to the provided url and using the provided headers. Args: url : string , the url. headers : dictionary, contains the headers key value pair. data : dictionary, contains the non-serialized request body. 
Returns: A RequestResponse """ headers = self.merge_headers(self.default_headers, headers) if data is None: serial_data = "" else: serial_data = self.json.dumps(data) headers.update({self.CONTENT_TYPE_HEADER_KEY: self.APP_JSON_HEADER_VALUE}) try: response = self.issue_request(url, headers=headers, method=self.DELETE, data=serial_data) except HTTPError: exception_type, error = sys.exc_info()[:2] return RequestResponse(error.code) return RequestResponse(response.getcode(), response.read()) ================================================ FILE: DSC/urllib3httpclient.py ================================================ #!/usr/bin/env python2 # # Copyright (C) Microsoft Corporation, All rights reserved. """Urllib2 HttpClient.""" try: from http.client import HTTPSConnection except ImportError: from httplib import HTTPSConnection import socket import time import traceback import sys try: from urllib.parse import urlparse, urlencode from urllib.request import urlopen, Request, HTTPSHandler, build_opener, ProxyHandler from urllib.error import HTTPError except ImportError: from urlparse import urlparse from urllib import urlencode from urllib2 import urlopen, Request, HTTPError, HTTPSHandler, build_opener, ProxyHandler from httpclient import * PY_MAJOR_VERSION = 0 PY_MINOR_VERSION = 1 PY_MICRO_VERSION = 2 SSL_MODULE_NAME = "ssl" # On some system the ssl module might be missing try: import ssl except ImportError: ssl = None class HttpsClientHandler(HTTPSHandler): """Https handler to enable attaching cert/key to request. Also used to disable strict cert verification for testing. 
""" def __init__(self, cert_path, key_path, insecure=False): self.cert_path = cert_path self.key_path = key_path ssl_context = None if insecure and SSL_MODULE_NAME in sys.modules and (sys.version_info[PY_MAJOR_VERSION] == 2 and sys.version_info[PY_MINOR_VERSION] >= 7 and sys.version_info[PY_MICRO_VERSION] >= 9): ssl_context = ssl.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE HTTPSHandler.__init__(self, context=ssl_context) # Context can be None here def https_open(self, req): return self.do_open(self.get_https_connection, req, context=self._context) def get_https_connection(self, host, context=None, timeout=180): """urllib2's AbstractHttpHandler will invoke this method with the host/timeout parameter. See urllib2's AbstractHttpHandler for more details. Args: host : string , the host. context : ssl_context , the ssl context. timeout : int , the timeout value in seconds. Returns: An HttpsConnection """ socket.setdefaulttimeout(180) if self.cert_path is None or self.key_path is None: return HTTPSConnection(host, timeout=timeout, context=context) else: return HTTPSConnection(host, cert_file=self.cert_path, key_file=self.key_path, timeout=timeout, context=context) def request_retry_handler(func): def decorated_func(*args, **kwargs): max_retry_count = 3 for iteration in range(0, max_retry_count, 1): try: ret = func(*args, **kwargs) return ret except Exception as exception: if iteration >= max_retry_count - 1: raise RetryAttemptExceededException(traceback.format_exc()) elif SSL_MODULE_NAME in sys.modules: if type(exception).__name__ == 'SSLError': time.sleep(5 + iteration) continue raise exception return decorated_func class Urllib3HttpClient(HttpClient): """Urllib2 http client. Inherits from HttpClient. Targets: [2.7.9 - 2.7.9+] only due to the lack of strict certificate verification prior to this version. Implements the following method common to all classes inheriting HttpClient. 
get (url, headers) post (url, headers, data) """ def __init__(self, cert_path, key_path, insecure=False, proxy_configuration=None): HttpClient.__init__(self, cert_path, key_path, insecure, proxy_configuration) @request_retry_handler def issue_request(self, url, headers, method=None, data=None): """Issues a GET request to the provided url and using the provided headers. Args: url : string , the url. headers : dictionary, contains the headers key value pair. data : string , contains the serialized request body. Returns: A RequestResponse :param method: """ https_handler = HttpsClientHandler(self.cert_path, self.key_path, self.insecure) opener = build_opener(https_handler) if self.proxy_configuration is not None: proxy_handler = ProxyHandler({'http': self.proxy_configuration, 'https': self.proxy_configuration}) opener.add_handler(proxy_handler) if sys.version_info >= (3,0): if data is not None: data = data.encode("utf-8") req = Request(url, data=data, headers=headers) req.get_method = lambda: method response = opener.open(req, timeout=30) opener.close() https_handler.close() return response def get(self, url, headers=None): """Issues a GET request to the provided url and using the provided headers. Args: url : string , the url. headers : dictionary, contains the headers key value pair. Returns: An http_response """ headers = self.merge_headers(self.default_headers, headers) try: response = self.issue_request(url, headers=headers, method=self.GET) except HTTPError: exception_type, error = sys.exc_info()[:2] return RequestResponse(error.code) return RequestResponse(response.getcode(), response.read()) def post(self, url, headers=None, data=None): """Issues a POST request to the provided url and using the provided headers. Args: url : string , the url. headers : dictionary, contains the headers key value pair. data : dictionary, contains the non-serialized request body. 
Returns: A RequestResponse """ headers = self.merge_headers(self.default_headers, headers) if data is None: serial_data = "" else: serial_data = self.json.dumps(data) headers.update({self.CONTENT_TYPE_HEADER_KEY: self.APP_JSON_HEADER_VALUE}) try: response = self.issue_request(url, headers=headers, method=self.POST, data=serial_data) except HTTPError: exception_type, error = sys.exc_info()[:2] return RequestResponse(error.code) return RequestResponse(response.getcode(), response.read().decode('utf-8')) def put(self, url, headers=None, data=None): """Issues a PUT request to the provided url and using the provided headers. Args: url : string , the url. headers : dictionary, contains the headers key value pair. data : dictionary, contains the non-serialized request body. Returns: A RequestResponse """ headers = self.merge_headers(self.default_headers, headers) if data is None: serial_data = "" else: serial_data = self.json.dumps(data) headers.update({self.CONTENT_TYPE_HEADER_KEY: self.APP_JSON_HEADER_VALUE}) try: response = self.issue_request(url, headers=headers, method=self.PUT, data=serial_data) except HTTPError: exception_type, error = sys.exc_info()[:2] return RequestResponse(error.code) return RequestResponse(response.getcode(), response.read().decode('utf-8')) def delete(self, url, headers=None, data=None): """Issues a DELETE request to the provided url and using the provided headers. Args: url : string , the url. headers : dictionary, contains the headers key value pair. data : dictionary, contains the non-serialized request body. 
Returns: A RequestResponse """ headers = self.merge_headers(self.default_headers, headers) if data is None: serial_data = "" else: serial_data = self.json.dumps(data) headers.update({self.CONTENT_TYPE_HEADER_KEY: self.APP_JSON_HEADER_VALUE}) try: response = self.issue_request(url, headers=headers, method=self.DELETE, data=serial_data) except HTTPError: exception_type, error = sys.exc_info()[:2] return RequestResponse(error.code) return RequestResponse(response.getcode(), response.read()) ================================================ FILE: Diagnostic/ChangeLogs ================================================ 2020-11-06: LAD-3.0.131 - Fix issue #1262: Crashing bug caused by task synchronization issue in XJsonBlobRequest 2020-07-01: LAD-3.0.129 - Fix issue #499 : Need a better error message if LAD protectedSettings is missing - Fix issue #944 : Allow installing LAD without storage account sink 2020-01-30: LAD-3.0.127 - Fix issue #996: Remove fluent-gem-plugin from LAD - Fix issue #994: Move LAD's out_mdsd buffer path to own directory - Fix issue #948: Failed to launch mdsd with error: cannot concatenate 'str' and 'int' objects - Fix issue #978: LAD is limited to upload 2 events per second to EventHubs 2019-10-24: LAD-3.0.125 - Reinstall OMI if it is failing to start as recovery action - VM extension config update needs to regenerate the config artifacts 2019-08-14: LAD-3.0.123 - Fix a race condition in install. The dependencies were installed during "enable" step; which is not idempotent. 2019-06-17: LAD-3.0.121 - Add logrotate policy to manage mdsd log files. 2019-01-15: LAD-3.0.119 - Add blobEndpoint for storage accounts; bug fix for National clouds. 
2017-09-05: LAD-3.0.111 - Ensure SAS storage token is supplied - Explicitly reject deprecated use of LAD 2.3's storageAccountKey 2017-08-11: LAD-3.0.109 - Fix waagent-related issue on Debian distros - Add additional unit tests - Replace multiple uses of "local" with "locale" - Improve error reporting when catching an exception - Fix #398, #399, #340 2017-05-16: LAD-3.0.107 - Move resourceId field generation for JSON events from LAD to mdsd 2017-05-10: LAD-3.0.103 - Allow '*' in syslog spec, add more fields in syslog records for EventHubs 2017-05-08: LAD-3.0.101 - New release of LAD 3.0. Refer to README.md 2017-01-13: LAD-2.3.9021 - Fix rsyslogd core dump issue when re-enabling the extension - Take latest mdsd binary that fixes other issues like missing perf counter logs when there's a race condition between mdsd and omiserver. 2016-11-30: LAD-2.3.9017 - Fix scx upgrade issue on RPM-based distros when apache or mysql is installed. 2016-11-11: LAD-2.3.9015 - Correctly fail Enable when mdsd dependency set up fails. - Added /etc/fstab watcher feature (logging to /dev/console so that issues can be found on serial logs) - Add storage account SAS token support (replacing storage account key) - Encrypt storage secret in xmlCfg.xml 2016-10-31: LAD-2.3.9013 - Use semodule -u (upgrade) to reduce unnecessary SELinux policy re-install time - Use the latest scx package version (1.6.2-337) - Issue #265: Don't remove port 1270 from omiserver.conf if omsagent is installed. 2016-10-07: LAD-2.3.9011 - Update OpenSSL library to the latest - Update rsyslog output modules for all versions of rsyslog (5/7/8) to use Unix domain socket. - Update mdsd binary to the latest (1.2.104) with various fixes - Dependencies are now installed at Enable time, to reduce VM deployment time. 
2016-09-16: LAD-2.3.9009 - Underlying monitoring agent binary (mdsd) upgrade with many fixes and improvements - Fixes storage end point bug (affected Mooncake and Blackforest) 2016-07-14: LAD-2.3.9007 - Fixes install issues on some RH-based distros (e.g., OracleLinux 7) due to lack of tar. - Fixes duplicate logging (on /var/log/syslog) issue on fileCfg 2016-06-30: LAD-2.3.9005 - Fixes non-starting monitoring agent issue on systemd-enabled distros (#180) - HandlerUtil unified with other extensions - Telemetry (logging) improvement - Remove possibility of logging some password 2016-06-21: LAD-2.3.9003 - Monitoring agent (mdsd) updates for a memory issue fix, a signal handler fix, and a fix to avoid a spin loop under certain circumstances - doesn't count non-quick crashes (>30 mins) towards retry limit - OMI reconfiguration not to listen to port 1270 - Use systemd on Ubuntu 16.04 as well - Validate mdsd XML config before starting mdsd, fails fast on invalid config (with success) - Small Python 2.6 bug fix (syslog.openlog()) 2016-06-06: LAD-2.3.9001 - Fix issue of syslog messages not collected by default on SLES 11 - Minor config syntactic fixes - Logging fix to show correct extension version - Monitoring agent kill is no longer SIGKILL, but SIGTERM. - Monitoring agent listening port is now dynamic if the specified port (29131) is in use. - Monitoring agent core dump is enabled (dumped on its current working directory) - Newer monitoring agent bits with added features (not available on LAD yet) 2016-05-04: LAD-2.3.9 - mdsd bits are now built as statically as possible, so that a single monolithic executable can be used on as many distros/versions. - OMI install result is checked and tried up to 3 times. If all fail, LAD install fails as well. - OMI is checked periodically for its health and LAD will restart it if OMI is not up. 
2016-03-26: LAD-2.3.7 - mdsd http proxy config through waagent.conf - OpenSUSE 13 support revival - LAD no longer (re)starts apache/mysql invasively (restarts only when they were running) - Bundle libglibmm*.so (no longer downloaded/installed when LAD is installed) - AppInsights configuration changes 2016-03-08: LAD-2.3.6 - mdsd http proxy support (mdsd binary change) - Ubuntu 16.04 glibmm install issue fix - Report success extension event for unsupported distros/versions 2016-02-25: LAD-2.3.5. Reviving SUSE 11 support and consolidating binaries of diff versions of same distro 2016-02-25: LAD-2.3.4. Hotfix for portal perf graphs not showing (xmlCfg parsing bug) 2016-02-15: LAD-2.3.3. No changes on mdsd/LAD code. Just rebuilding to take in the most recent AISDK fixes ================================================ FILE: Diagnostic/DistroSpecific.py ================================================ #!/usr/bin/env python # # Azure Linux extension # Distribution-specific actions # # Linux Azure Diagnostic Extension (Current version is specified in manifest.xml) # Copyright (c) Microsoft Corporation # All rights reserved. # MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit # persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the # Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. import exceptions import time import subprocess import re from Utils.WAAgentUtil import waagent class CommonActions: def __init__(self, logger): self.logger = logger def filterNonAsciiCharacters(self, output_msg): return output_msg.encode('utf-8').decode('ascii','ignore') def log_run_get_output(self, cmd, should_log=True): """ Execute a command in a subshell :param str cmd: The command to be executed :param bool should_log: If true, log command execution :rtype: int, str :return: A tuple of (subshell exit code, contents of stdout) """ if should_log: self.logger("RunCmd " + cmd) error, msg = waagent.RunGetOutput(cmd, chk_err=should_log) if should_log: self.logger("Return " + str(error) + ":" + msg) return int(error), self.filterNonAsciiCharacters(msg) def log_run_ignore_output(self, cmd, should_log=True): """ Execute a command in a subshell :param str cmd: The command to be executed :param bool should_log: True if command execution should be logged. (False preserves privacy of parameters.) 
:rtype: int :return: The subshell exit code """ error, msg = self.log_run_get_output(cmd, should_log) return int(error) def log_run_with_timeout(self, cmd, timeout=3600): """ Execute a command in a subshell, killing the subshell if it runs too long :param str cmd: The command to be executed :param int timeout: The maximum elapsed time, in seconds, to wait for the subshell to return; default 360 :rtype: int, str :return: (1, "Process timeout\n") if timeout, else (subshell exit code, contents of stdout) """ self.logger("Run with timeout: " + cmd) process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True, executable='/bin/bash') time.sleep(1) while process.poll() is None and timeout > 0: time.sleep(1) timeout -= 1 if process.poll() is None: self.logger("Timeout while running:" + cmd) process.kill() return 1, "Process timeout\n" output, error = process.communicate() self.logger("Return " + str(error)) return int(process.returncode), output def log_run_multiple_cmds(self, cmds, with_timeout, timeout=360): """ Execute multiple commands in subshells, with optional timeout protection :param Iterable[str] cmds: An iterable of commands to be executed :param bool with_timeout: True if commands should be run with timeout :param int timeout: The timeout, in seconds; default 360. Ignored if with_timeout is False. 
:rtype: int, str :return: A tuple of (sum of status codes, concatenated stdout from commands) """ errors = 0 output = [] for cmd in cmds: if with_timeout: err, msg = self.log_run_with_timeout(cmd, timeout) else: err, msg = self.log_run_get_output(cmd) errors += err output.append(msg) return errors, ''.join(output) def extract_om_path_and_version(self, results): """ Get information about rsyslogd :param str results: Package information about omprog.so or version :rtype: str, str :return: (Path where rsyslogd output modules are installed, major version of rsyslogd) """ match = re.search(r"(.+)omprog\.so", results) if not match: return None, '' path = match.group(1) match = re.search(r"Version\s*:\s*(\d+)\D", results) if not match: self.logger("rsyslog is present but version could not be determined") return None, '' version = match.group(1) return path, version def install_extra_packages(self, packages, with_timeout=False): """ Ensure an arbitrary set of packages is installed :param list[str] packages: Iterable of package names :param bool with_timeout: true if package installations should be aborted if they take too long :rtype: int :return: """ return 0, '' def install_required_packages(self): """ Install packages required by this distro to meet the common bar required of all distros :rtype: int, str :return: (status, concatenated stdout from all package installs) """ return 0, "no additional packages were needed" def is_package_handler(self, package_manager): """ Checks if the distro's package manager matches the specified tool. 
:param str package_manager: The tool to be checked against the distro's native package manager :rtype: bool :return: True if the distro's native package manager is package_manager """ return False def prepare_for_mdsd_install(self): return 0, '' def extend_environment(self, env): """ Add required environment variables to process environment :param dict[str, str] env: Process environment """ pass def use_systemd(self): """ Determine if the distro uses systemd as its system management tool. :rtype: bool :return: True if the distro uses systemd as its system management tool. """ return False def install_lad_mdsd(self): """ Install the mdsd binary using the bundled .deb/.rpm packages. Should be overridden by each direct subclass for Debian/Redhat. Can't be called for this base class. :rtype: int, str :return: (status, concatenated stdout from the package install) """ assert False, "Can't be called on the base class (CommonActions)!" def remove_lad_mdsd(self): """ Remove the mdsd binary that was installed with the bundled .deb/.rpm packages. Should be overridden by each direct subclass for Debian/Redhat. Can't be called for this base class. :rtype: int, str :return: (status, concatenated stdout from the package remove) """ assert False, "Can't be called on the base class (CommonActions)!" class DebianActions(CommonActions): def __init__(self, logger): CommonActions.__init__(self, logger) def is_package_handler(self, package_manager): return package_manager == "dpkg" def install_extra_packages(self, packages, with_timeout=False): cmd = 'dpkg-query -l PACKAGE |grep ^ii; if [ ! $? 
== 0 ]; then apt-get update; apt-get install -y PACKAGE; fi' return self.log_run_multiple_cmds([cmd.replace("PACKAGE", p) for p in packages], with_timeout) def extend_environment(self, env): env.update({"SSL_CERT_DIR": "/usr/lib/ssl/certs", "SSL_CERT_FILE": "/usr/lib/ssl/cert.pem"}) def install_lad_mdsd(self): return self.log_run_get_output('dpkg -i lad-mdsd-*.deb') def remove_lad_mdsd(self): return self.log_run_get_output('dpkg -P lad-mdsd') class CredativActions(DebianActions): def __init__(self, logger): DebianActions.__init__(self, logger) def install_required_packages(self): # curl not installed by default on Credative Debian Linux, now required by omsagent return self.install_extra_packages(('curl',), True) class Ubuntu1510OrHigherActions(DebianActions): def __init__(self, logger): DebianActions.__init__(self, logger) def install_extra_packages(self, packages, with_timeout=False): count = len(packages) if count == 0: return 0, '' package_list = str.join(' ', packages) cmd = '[ $(dpkg -l PACKAGES |grep ^ii |wc -l) -eq \'COUNT\' ] || apt-get install -y PACKAGES' cmd = cmd.replace('PACKAGES', package_list).replace('COUNT', str(count)) if with_timeout: return self.log_run_with_timeout(cmd) else: return self.log_run_get_output(cmd) def use_systemd(self): return True class RedhatActions(CommonActions): def __init__(self, logger): CommonActions.__init__(self, logger) def install_extra_packages(self, packages, with_timeout=False): install_cmd = 'rpm -q PACKAGE; if [ ! $? == 0 ]; then yum install -y PACKAGE; fi' return self.log_run_multiple_cmds([install_cmd.replace("PACKAGE", p) for p in packages], with_timeout) def install_required_packages(self): # policycoreutils-python missing on Oracle Linux (still needed to manipulate SELinux policy). # tar is really missing on Oracle Linux 7! 
return self.install_extra_packages(('policycoreutils-python', 'tar'), True) def is_package_handler(self, package_manager): return package_manager == "rpm" def extend_environment(self, env): env.update({"SSL_CERT_DIR": "/etc/pki/tls/certs", "SSL_CERT_FILE": "/etc/pki/tls/cert.pem"}) def install_lad_mdsd(self): return self.log_run_get_output('rpm -i --force lad-mdsd-*.rpm') def remove_lad_mdsd(self): return self.log_run_get_output('rpm -e lad-mdsd') class Redhat8Actions(RedhatActions): def __init__(self, logger): RedhatActions.__init__(self, logger) def install_required_packages(self): return self.install_extra_packages(('policycoreutils-python-utils', 'tar'), True) class Suse11Actions(RedhatActions): def __init__(self, logger): RedhatActions.__init__(self, logger) self.certs_file = "/etc/ssl/certs/mdsd-ca-certs.pem" def install_extra_packages(self, packages, with_timeout=False): install_cmd = 'rpm -qi PACKAGE; if [ ! $? == 0 ]; then zypper --non-interactive install PACKAGE;fi' return self.log_run_multiple_cmds([install_cmd.replace("PACKAGE", p) for p in packages], with_timeout) def install_required_packages(self): return 0, "no additional packages were needed" # For SUSE11, we need to create a CA certs file for our statically linked OpenSSL 1.0 libs def prepare_for_mdsd_install(self): commands = ( r'cp /dev/null {0}'.format(self.certs_file), r'chown 0:0 {0}'.format(self.certs_file), r'chmod 0644 {0}'.format(self.certs_file), r"cat /etc/ssl/certs/????????.[0-9a-f] | sed '/^#/d' >> {0}".format(self.certs_file) ) return self.log_run_multiple_cmds(commands, False) def extend_environment(self, env): env.update({"SSL_CERT_FILE": self.certs_file}) class Suse12Actions(RedhatActions): def __init__(self, logger): RedhatActions.__init__(self, logger) def install_extra_packages(self, packages, with_timeout=False): install_cmd = 'rpm -qi PACKAGE; if [ ! $? 
== 0 ]; then zypper --non-interactive install PACKAGE;fi' return self.log_run_multiple_cmds([install_cmd.replace("PACKAGE", p) for p in packages], with_timeout) def install_required_packages(self): return self.install_extra_packages(('libgthread-2_0-0', 'ca-certificates-mozilla', 'rsyslog'), True) def extend_environment(self, env): env.update({"SSL_CERT_DIR": "/var/lib/ca-certificates/openssl", "SSL_CERT_FILE": "/etc/ssl/cert.pem"}) class CentosActions(RedhatActions): def __init__(self, logger): RedhatActions.__init__(self, logger) def install_extra_packages(self, packages, with_timeout=False): install_cmd = 'rpm -qi PACKAGE; if [ ! $? == 0 ]; then yum install -y PACKAGE; fi' return self.log_run_multiple_cmds([install_cmd.replace("PACKAGE", p) for p in packages], with_timeout) def install_required_packages(self): # policycoreutils-python missing on CentOS (still needed to manipulate SELinux policy) return self.install_extra_packages(('policycoreutils-python',), True) class Centos8Actions(RedhatActions): def __init__(self, logger): RedhatActions.__init__(self, logger) def install_required_packages(self): return self.install_extra_packages(('policycoreutils-python-utils', 'tar'), True) DistroMap = { 'debian': CredativActions, # Credative Debian Linux took the 'debian' platform name with the curl deficiency, # when all other Debian-based distros have curl, so is this strange mapping... 
'kali': DebianActions, 'ubuntu': DebianActions, 'ubuntu:16.04': Ubuntu1510OrHigherActions, 'ubuntu:18.04': Ubuntu1510OrHigherActions, 'redhat': RedhatActions, 'redhat:8': Redhat8Actions, 'centos': CentosActions, 'centos:8':Centos8Actions, 'oracle': RedhatActions, 'suse:12': Suse12Actions, 'suse': Suse12Actions, 'sles:15': Suse12Actions, 'opensuse:15':Suse12Actions, 'almalinux':Redhat8Actions } def get_distro_actions(name, version, logger): name_and_version = name + ":" + version if name_and_version in DistroMap: return DistroMap[name_and_version](logger) else: major_version = version.split(".")[0] name_and_major_version = name + ":" + major_version if name_and_major_version in DistroMap: return DistroMap[name_and_major_version](logger) if name in DistroMap: return DistroMap[name](logger) raise exceptions.LookupError('{0} is not a supported distro'.format(name_and_version)) ================================================ FILE: Diagnostic/HandlerManifest.json ================================================ [ { "version": 1.0, "handlerManifest": { "disableCommand": "shim.sh -disable", "enableCommand": "shim.sh -enable", "installCommand": "shim.sh -install", "uninstallCommand": "shim.sh -uninstall", "updateCommand": "shim.sh -update", "rebootAfterInstall": false, "reportHeartbeat": false, "updateMode": "updatewithinstall" } } ] ================================================ FILE: Diagnostic/Makefile ================================================ all: package .PHONY: all .PHONY: clean .PHONY: package LADSOURCES = \ diagnostic.py \ watcherutil.py \ tests \ HandlerManifest.json \ license.txt \ manifest.xml \ run_unittests.sh \ services \ Utils \ UTILSOURCES = \ ../Utils/HandlerUtil.py \ ../Utils/__init__.py \ ../Utils/WAAgentUtil.py \ clean: rm -rf output package: $(LADSOURCES) $(UTILSOURCES) mkdir -p output cp -t output -r $(LADSOURCES) cp -t output/Utils -r $(UTILSOURCES) ================================================ FILE: Diagnostic/Providers/Builtin.py 
================================================ #!/usr/bin/env python # # Azure Linux extension # # Linux Azure Diagnostic Extension (Current version is specified in manifest.xml) # Copyright (c) Microsoft Corporation # All rights reserved. # MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit # persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of # the Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # A provider is responsible for taking a particular syntax of configuration instructions, as found in the JSON config # blob, and using it to enable collection of data as specified in those instructions. # The "Builtin" configuration instructions are agnostic to the collection mechanism used to implement them; it's simply # a list of metrics to be collected on a particular schedule. The metric names are collected into classes for ease # of understanding by the user. The predefined classes and metric names are available without regard to how the # underlying mechanism might name them. 
#
# This specific implementation of the Builtin provider converts the configuration instructions into a set of OMI
# queries to be executed by the mdsd agent. The agent executes the queries written by this provider and uploads
# the results to the appropriate table in the customer's storage account.
#
# A different implementation might use fluentd to collect the data and to upload the results to table storage.
# A different provider (e.g. an OMI provider) would expect configuration instructions bound directly to OMI; that is,
# the PublicConfig JSON delivered to LAD would itself contain actual OMI queries. The implementation of such a provider
# might construct an mdsd configuration file to cause mdsd to run the specified queries and store the data in tables.

import Utils.ProviderUtil as ProvUtil
from collections import defaultdict
import xml.etree.ElementTree as ET
import Utils.XmlUtil as XmlUtil
from xml.sax.saxutils import quoteattr

# These are the built-in metrics this code provides, grouped by class. The builtin countername space is
# case insensitive; this collection of maps converts to the case-sensitive OMI name.
_builtIns = { 'processor': { 'percentidletime': 'PercentIdleTime', 'percentprocessortime': 'PercentProcessorTime', 'percentiowaittime': 'PercentIOWaitTime', 'percentinterrupttime': 'PercentInterruptTime', 'percentusertime': 'PercentUserTime', 'percentnicetime': 'PercentNiceTime', 'percentprivilegedtime': 'PercentPrivilegedTime' }, 'memory': { 'availablememory': 'AvailableMemory', 'percentavailablememory': 'PercentAvailableMemory', 'usedmemory': 'UsedMemory', 'percentusedmemory': 'PercentUsedMemory', 'pagespersec': 'PagesPerSec', 'pagesreadpersec': 'PagesReadPerSec', 'pageswrittenpersec': 'PagesWrittenPerSec', 'availableswap': 'AvailableSwap', 'percentavailableswap': 'PercentAvailableSwap', 'usedswap': 'UsedSwap', 'percentusedswap': 'PercentUsedSwap'}, 'network': { 'bytestransmitted': 'BytesTransmitted', 'bytesreceived': 'BytesReceived', 'bytestotal': 'BytesTotal', 'packetstransmitted': 'PacketsTransmitted', 'packetsreceived': 'PacketsReceived', 'totalrxerrors': 'TotalRxErrors', 'totaltxerrors': 'TotalTxErrors', 'totalcollisions': 'TotalCollisions' }, 'filesystem': { 'freespace': 'FreeMegabytes', 'usedspace': 'UsedMegabytes', 'percentfreespace': 'PercentFreeSpace', 'percentusedspace': 'PercentUsedSpace', 'percentfreeinodes': 'PercentFreeInodes', 'percentusedinodes': 'PercentUsedInodes', 'bytesreadpersecond': 'ReadBytesPerSecond', 'byteswrittenpersecond': 'WriteBytesPerSecond', 'bytespersecond': 'BytesPerSecond', 'readspersecond': 'ReadsPerSecond', 'writespersecond': 'WritesPerSecond', 'transferspersecond': 'TransfersPerSecond' }, 'disk': { 'readspersecond': 'ReadsPerSecond', 'writespersecond': 'WritesPerSecond', 'transferspersecond': 'TransfersPerSecond', 'averagereadtime': 'AverageReadTime', 'averagewritetime': 'AverageWriteTime', 'averagetransfertime': 'AverageTransferTime', 'averagediskqueuelength': 'AverageDiskQueueLength', 'readbytespersecond': 'ReadBytesPerSecond', 'writebytespersecond': 'WriteBytesPerSecond', 'bytespersecond': 'BytesPerSecond' } } 
_omiClassName = { 'processor': 'SCX_ProcessorStatisticalInformation', 'memory': 'SCX_MemoryStatisticalInformation', 'network': 'SCX_EthernetPortStatistics', 'filesystem': 'SCX_FileSystemStatisticalInformation', 'disk': 'SCX_DiskDriveStatisticalInformation' } # Default CQL condition clause (WHERE ...) for relevant counter classes _defaultCqlCondition = { #'network': '...', # No 'Name' or 'IsAggregate' columns from SCX_EthernetPort... cql query. # If there are multiple NICs, this might cause some issue. Beware. # The column/value distinguishing NICs is e.g., 'InstanceID="eth0"'. 'filesystem': 'IsAggregate=TRUE', # For specific file system (e.g., root fs), use 'Name="/"' 'disk': 'IsAggregate=TRUE', # For specific disk (e.g., /dev/sda), use 'Name="sda"' 'processor': 'IsAggregate=TRUE', # For specific processor core, use 'Name="0"' #'memory': 'IsAggregate=TRUE', # No separate instances of memory, so no WHERE condition is needed } # The Azure Metrics infrastructure, along with App Insights, requires that quantities be measured # in one of these units: Percent, Count, Seconds, Milliseconds, Bytes, BytesPerSecond, CountPerSecond # # Some of the OMI metrics are retrieved in some other unit (e.g. "MiB") and need to be scaled # to the expected unit before being passed along the pipeline. The _scaling map holds all OMI counter # names that need to be scaled. If a counterSpecifier isn't in this list, no scaling is needed. 
_scaling = defaultdict(lambda:defaultdict(str), { 'memory' : defaultdict(str, { 'AvailableMemory': 'scaleUp="1048576"', 'UsedMemory': 'scaleUp="1048576"', 'AvailableSwap': 'scaleUp="1048576"', 'UsedSwap': 'scaleUp="1048576"' } ), 'filesystem' : defaultdict(str, {'FreeMegabytes': 'scaleUp="1048576"', 'UsedMegabytes': 'scaleUp="1048576"', }), } ) _metrics = defaultdict(list) _eventNames = {} _defaultSampleRate = 15 def SetDefaultSampleRate(rate): global _defaultSampleRate _defaultSampleRate = rate def default_condition(class_name): return _defaultCqlCondition[class_name] if class_name in _defaultCqlCondition else '' class BuiltinMetric: def __init__(self, counterSpec): """ Construct an instance of the BuiltinMetric class. Values are case-insensitive unless otherwise noted. "type": the provider type. If present, must have value "builtin". If absent, assumed to be "builtin". "class": the name of the class within which this metric is scoped. Must be a key in the _builtIns dict. "counter": the name of the metric, within the class. Must appear in the list of metric names for this class found in the _builtIns dict. In this implementation, the builtin counter name is mapped to the OMI counter name "instanceId": the identifier for the specific instance of the metric, if any. Must be "None" for uninstanced metrics. "counterSpecifier": the name under which this retrieved metric will be stored "sampleRate": a string containing an ISO8601-compliant duration. :param counterSpec: A dict containing the key/value settings that define the metric to be collected. 
""" t = ProvUtil.GetCounterSetting(counterSpec, 'type') if t is None: self._Type = 'builtin' else: self._Type = t.lower() if t != 'builtin': raise ProvUtil.UnexpectedCounterType('Expected type "builtin" but saw type "{0}"'.format(self._Type)) self._CounterClass = ProvUtil.GetCounterSetting(counterSpec, 'class') if self._CounterClass is None: raise ProvUtil.InvalidCounterSpecification('Builtin metric spec missing "class"') self._CounterClass = self._CounterClass.lower() if self._CounterClass not in _builtIns: raise ProvUtil.InvalidCounterSpecification('Unknown Builtin class {0}'.format(self._CounterClass)) builtin_raw_counter_name = ProvUtil.GetCounterSetting(counterSpec, 'counter') if builtin_raw_counter_name is None: raise ProvUtil.InvalidCounterSpecification('Builtin metric spec missing "counter"') builtin_counter_name = builtin_raw_counter_name.lower() if builtin_counter_name not in _builtIns[self._CounterClass]: raise ProvUtil.InvalidCounterSpecification( 'Counter {0} not in builtin class {1}'.format(builtin_raw_counter_name, self._CounterClass)) self._Counter = _builtIns[self._CounterClass][builtin_counter_name] self._Condition = ProvUtil.GetCounterSetting(counterSpec, 'condition') self._Label = ProvUtil.GetCounterSetting(counterSpec, 'counterSpecifier') if self._Label is None: raise ProvUtil.InvalidCounterSpecification( 'No counterSpecifier set for builtin {1} {0}'.format(self._Counter, self._CounterClass)) self._SampleRate = ProvUtil.GetCounterSetting(counterSpec, 'sampleRate') def is_type(self, t): """ Returns True if the metric is of the specified type. :param t: The name of the metric type to be checked :return bool: """ return self._Type == t.lower() def class_name(self): return self._CounterClass def counter_name(self): return self._Counter def condition(self): return self._Condition def label(self): return self._Label def sample_rate(self): """ Determine how often this metric should be retrieved. 
If the metric didn't define a sample period, return the default. :return int: Number of seconds between collecting samples of this metric. """ if self._SampleRate is None: return _defaultSampleRate else: return ProvUtil.IntervalToSeconds(self._SampleRate) def AddMetric(counter_spec): """ Add a metric to the list of metrics to be collected. :param counter_spec: The specification of a builtin metric. :return: the generated local-table name in mdsd into which this metric will be fetched, or None """ global _metrics, _eventNames try: metric = BuiltinMetric(counter_spec) except ProvUtil.ParseException as ex: print "Couldn't create metric: ", ex return None # (class, instanceId, sampleRate) -> [ metric ] # Given a class, instance within that class, and sample rate, we have a list of the requested metrics # matching those constraints. For that set of constraints, we also have a common eventName, the local # table where we store the collected metrics. key = (metric.class_name(), metric.condition(), metric.sample_rate()) if key not in _eventNames: _eventNames[key] = ProvUtil.MakeUniqueEventName('builtin') _metrics[key].append(metric) return _eventNames[key] def UpdateXML(doc): """ Add to the mdsd XML the minimal set of OMI queries which will retrieve the metrics requested via AddMetric(). This provider doesn't need any configuration external to mdsd; if it did, that would be generated here as well. 
:param doc: XML document object to be updated :return: None """ global _metrics, _eventNames, _omiClassName for group in _metrics: (class_name, condition_clause, sample_rate) = group if not condition_clause: condition_clause = default_condition(class_name) columns = [] mappings = [] for metric in _metrics[group]: omi_name = metric.counter_name() scale = _scaling[class_name][omi_name] columns.append(omi_name) mappings.append('{2}'.format(omi_name, scale, metric.label())) column_string = ','.join(columns) if condition_clause: cql_query = quoteattr("SELECT {0} FROM {1} WHERE {2}".format(column_string, _omiClassName[class_name], condition_clause)) else: cql_query = quoteattr("SELECT {0} FROM {1}".format(column_string, _omiClassName[class_name])) query = ''' {mappings} '''.format( qry=cql_query, evname=quoteattr(_eventNames[group]), columns=quoteattr(column_string), rate=sample_rate, mappings='\n '.join(mappings) ) XmlUtil.addElement(doc, 'Events/OMI', ET.fromstring(query)) return ================================================ FILE: Diagnostic/Providers/__init__.py ================================================ # Providers module package ================================================ FILE: Diagnostic/README.md ================================================ # [DEPRECATED] Linux Azure Diagnostic (LAD) Extension > :warning: The Azure Diagnostic extension has been **deprecated** and has no support as of **March 31, 2026.** If you use the Azure Diagnostic extension to collect data, [migrate now to the new Azure Monitor agent](https://learn.microsoft.com/en-us/azure/azure-monitor/agents/azure-monitor-agent-migration-wad-lad). Allow the owner of a Linux-based Azure Virtual Machine to obtain diagnostic data. Current version is 3.0.129. Linux Azure Diagnostic (LAD) extension version 3.0 is released with the following changes: - Fully configurable Azure Portal metrics, including a broader set of metrics to choose from. 
- Syslog message collection is now opt-in (off by default), and customers can selectively pick and choose syslog facilities and minimum severities of interest to them.
- Customers can now use CLI to configure their Azure Linux VMs for Azure Portal VM metrics charting experiences.
- Customers can now send any metrics and logs as Azure EventHubs events (additional Azure EventHubs charges may apply).
- Customers can also store any metrics and logs in Azure Storage JSON blobs (additional Azure Storage charges may apply).

LAD 3.0 is NOT compatible with LAD 2.3. Users of LAD 2.3 must first uninstall that extension before installing LAD 3.0. LAD 3.0 is installed and configured via Azure CLI, Azure PowerShell cmdlets, or Azure Resource Manager templates. The Azure Portal controls installation and configuration of LAD 2.3 only. The Azure Metrics UI can display performance counters collected by either version of LAD.

Please refer to [this document](https://docs.microsoft.com/azure/virtual-machines/linux/diagnostic-extension) for more details on configuring and using LAD 3.0. The tests folder contains [a sample JSON configuration](https://raw.githubusercontent.com/Azure/azure-linux-extensions/master/Diagnostic/tests/lad_2_3_compatible_portal_pub_settings.json) which configures LAD 3.0 to collect exactly the same metrics and logs as the default LAD 2.3 configuration.

## Supported Linux Distributions

List of supported Linux distributions is on https://docs.microsoft.com/en-us/azure/virtual-machines/extensions/diagnostics-linux#supported-linux-distributions

## Debug

- The status of the extension is reported back to Azure so that users can see the status on the Azure Portal
- The operation log of the extension is in the `/var/log/azure/Microsoft.Azure.Diagnostics.LinuxDiagnostic//` directory.
[azure-powershell]: https://azure.microsoft.com/en-us/documentation/articles/powershell-install-configure/ [azure-cli]: https://azure.microsoft.com/en-us/documentation/articles/xplat-cli/ [arm-template]: http://azure.microsoft.com/en-us/documentation/templates/ [arm-overview]: https://azure.microsoft.com/en-us/documentation/articles/resource-group-overview/ ================================================ FILE: Diagnostic/Utils/LadDiagnosticUtil.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Copyright (c) Microsoft Corporation # All rights reserved. # MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above # copyright notice and this permission notice shall be included in all copies or substantial portions of the # Software. THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT # LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT # SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. 
# Get elements from DiagnosticsMonitorConfiguration in LadCfg based on element name def getDiagnosticsMonitorConfigurationElement(ladCfg, elementName): if ladCfg and 'diagnosticMonitorConfiguration' in ladCfg: if elementName in ladCfg['diagnosticMonitorConfiguration']: return ladCfg['diagnosticMonitorConfiguration'][elementName] return None # Get fileCfg form FileLogs in LadCfg def getFileCfgFromLadCfg(ladCfg): fileLogs = getDiagnosticsMonitorConfigurationElement(ladCfg, 'fileLogs') if fileLogs and 'fileLogConfiguration' in fileLogs: return fileLogs['fileLogConfiguration'] return None # Get resource Id from LadCfg def getResourceIdFromLadCfg(ladCfg): metricsConfiguration = getDiagnosticsMonitorConfigurationElement(ladCfg, 'metrics') if metricsConfiguration and 'resourceId' in metricsConfiguration: return metricsConfiguration['resourceId'] return None # Get event volume from LadCfg def getEventVolumeFromLadCfg(ladCfg): return getDiagnosticsMonitorConfigurationElement(ladCfg, 'eventVolume') # Get default sample rate from LadCfg def getDefaultSampleRateFromLadCfg(ladCfg): if ladCfg and 'sampleRateInSeconds' in ladCfg: return ladCfg['sampleRateInSeconds'] return None def getPerformanceCounterCfgFromLadCfg(ladCfg): """ Return the array of metric definitions :param ladCfg: :return: array of metric definitions """ performanceCounters = getDiagnosticsMonitorConfigurationElement(ladCfg, 'performanceCounters') if performanceCounters and 'performanceCounterConfiguration' in performanceCounters: return performanceCounters['performanceCounterConfiguration'] return None def getAggregationPeriodsFromLadCfg(ladCfg): """ Return an array of aggregation periods as specified. 
If nothing appears in the config, default PT1H :param ladCfg: :return: array of ISO 8601 intervals :rtype: List(str) """ results = [] metrics = getDiagnosticsMonitorConfigurationElement(ladCfg, 'metrics') if metrics and 'metricAggregation' in metrics: for item in metrics['metricAggregation']: if 'scheduledTransferPeriod' in item: # assert isinstance(item['scheduledTransferPeriod'], str) results.append(item['scheduledTransferPeriod']) return results def getSinkList(feature_config): """ Returns the list of sink names to which all data should be forwarded, according to this config :param feature_config: The JSON config for a feature (e.g. the struct for "performanceCounters" or "syslogEvents") :return: the list of names; might be an empty list :rtype: [str] """ if feature_config and 'sinks' in feature_config and feature_config['sinks']: return [sink_name.strip() for sink_name in feature_config['sinks'].split(',')] return [] def getFeatureWideSinksFromLadCfg(ladCfg, feature_name): """ Returns the list of sink names to which all data for the given feature should be forwarded :param ladCfg: The ladCfg JSON config :param str feature_name: Name of the feature. Expected to be "performanceCounters" or "syslogEvents" :return: the list of names; might be an empty list :rtype: [str] """ return getSinkList(getDiagnosticsMonitorConfigurationElement(ladCfg, feature_name)) class SinkConfiguration: def __init__(self): self._sinks = {} def insert_from_config(self, json): """ Walk through the sinksConfig JSON object and add all sinks within it. Every accepted sink is guaranteed to have a 'name' and 'type' element. 
:param json: A hash holding the body of a sinksConfig object :return: A string containing warning messages, or an empty string """ msgs = [] if json and 'sink' in json: for sink in json['sink']: if 'name' in sink and 'type' in sink: self._sinks[sink['name']] = sink else: msgs.append('Ignoring invalid sink definition {0}'.format(sink)) return '\n'.join(msgs) def get_sink_by_name(self, sink_name): """ Return the JSON object defining a particular sink. :param sink_name: string name of sink :return: JSON object or None """ if sink_name in self._sinks: return self._sinks[sink_name] return None def get_all_sink_names(self): """ Return a list of all names of defined sinks. :return: list of names """ return self._sinks.keys() def get_sinks_by_type(self, sink_type): """ Return a list of all names of defined sinks. :return: list of names """ return [self._sinks[name] for name in self._sinks if self._sinks[name]['type'] == sink_type] ================================================ FILE: Diagnostic/Utils/ProviderUtil.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Copyright (c) Microsoft Corporation # All rights reserved. # MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit # persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the # Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. import re from collections import defaultdict def GetCounterSetting(counter_spec, name): """ Retrieve a particular setting from a counter specification; if that setting is not present, return None. :param counter_spec: A dict of mappings from the name of a setting to its associated value. :param name: The name of the setting of interest. :return: Either the value of the setting (if present in counterSpec) or None. """ if name in counter_spec: return counter_spec[name] return None def IntervalToSeconds(specified_interval): """ Convert an ISO8601 duration string (e.g. PT5M, PT1H30M, PT30S) to a number of seconds. :param specified_interval: ISO8601 duration string. Must not include units larger than Hours. :return: An integer number of seconds. Raises ValueError if the duration string is syntactically invalid or beyond the supported range. """ interval = specified_interval.upper() if interval[0] != 'P': raise ValueError('"{0}" is not an IS8601 duration string'.format(interval)) if interval[1] != 'T': raise ValueError('IS8601 durations based on days or larger intervals are not supported: "{0}"'.format(interval)) seconds = 0 matches = re.findall(r'(\d+)(S|M|H)', interval[2:].upper()) for qty, unit in matches: qty = int(qty) if unit == 'S': seconds += qty elif unit == 'M': seconds += qty * 60 elif unit == 'H': seconds += qty * 3600 if 0 == seconds: raise ValueError('Could not parse interval specification "{0}"'.format(specified_interval)) return seconds _EventNameUniquifiers = defaultdict(int) def MakeUniqueEventName(prefix): """ Generate a unique event name given a prefix string. :param prefix: The prefix for the unique name. :return: The unique name, with prefix. 
""" _EventNameUniquifiers[prefix] += 1 return '{0}{1:0>6}'.format(prefix, _EventNameUniquifiers[prefix]) class ParseException(Exception): pass class UnexpectedCounterType(ParseException): pass class InvalidCounterSpecification(ParseException): pass ================================================ FILE: Diagnostic/Utils/XmlUtil.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Copyright (c) Microsoft Corporation # All rights reserved. # MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of # the Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS # OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
import xml.etree.ElementTree as ET def setXmlValue(xml,path,property,value,selector=[]): elements = xml.findall(path) for element in elements: if selector and element.get(selector[0])!=selector[1]: continue if not property: element.text = value elif not element.get(property) or len(element.get(property))==0 : element.set(property,value) def getXmlValue(xml,path,property): element = xml.find(path) if element is not None: return element.get(property) def addElement(xml,path,el,selector=[],addOnlyOnce=False): elements = xml.findall(path) for element in elements: if selector and element.get(selector[0])!=selector[1]: continue element.append(el) if addOnlyOnce: return def createElement(schema): return ET.fromstring(schema) def removeElement(tree, parent_path, removed_element_name): parents = tree.findall(parent_path) for parent in parents: element = parent.find(removed_element_name) while element is not None: parent.remove(element) element = parent.find(removed_element_name) ================================================ FILE: Diagnostic/Utils/__init__.py ================================================ # Providers module package ================================================ FILE: Diagnostic/Utils/imds_util.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Linux Azure Diagnostic Extension (Current version is specified in manifest.xml) # Copyright (c) Microsoft Corporation # All rights reserved. 
# MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit # persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the # Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. import datetime import urllib2 import time import traceback def get_imds_data(node, json=True): """ Query IMDS endpoint for instance metadata and return the response as a Json string. :param str node: Instance metadata node we are querying about :param bool json: Indicates whether to query for Json output or not :return: Queried IMDS result in string :rtype: str """ if not node: return None separator = '' if node[0] == '/' else '/' imds_url = 'http://169.254.169.254{0}{1}{2}'.format( separator, node, '?format=json&api-version=latest_internal' if json else '') imds_headers = {'Metadata': 'True'} req = urllib2.Request(url=imds_url, headers=imds_headers) resp = urllib2.urlopen(req) data = resp.read() data_str = data.decode('utf-8') return data_str class ImdsLogger: """ Periodically probes IMDS endpoint and log the result as WALA events. 
""" def __init__(self, ext_name, ext_ver, ext_op_type, ext_event_logger, ext_logger=None, imds_data_getter=get_imds_data, logging_interval_in_minutes=60): """ Constructor :param str ext_name: Extension name (e.g., hutil.get_name()) :param str ext_ver: Extension version (e.g., hutil.get_version()) :param str ext_op_type: Extension operation type (e.g., HeartBeat) :param ext_event_logger: Extension event logger (e.g., waagent.AddExtensionEvent) :param ext_logger: Extension message logger (e.g., hutil.log) :param imds_data_getter: IMDS data getter function (e.g., get_imds_data) :param int logging_interval_in_minutes: Logging interval in minutes """ self._ext_name = ext_name self._ext_ver = ext_ver self._ext_op_type = ext_op_type self._ext_logger = ext_logger # E.g., hutil.log self._ext_event_logger = ext_event_logger # E.g., waagent.AddExtensionEvent self._last_log_time = datetime.datetime.fromordinal(1) self._imds_data_getter = imds_data_getter self._logging_interval = datetime.timedelta(minutes=logging_interval_in_minutes) def _ext_log_if_enabled(self, msg): """ Log an extension message if logger is specified. :param str msg: Message to log :return: None """ if self._ext_logger: self._ext_logger(msg) def log_imds_data_if_right_time(self, log_as_ext_event=False): """ Query and log IMDS data if it's right time to do so. :param bool log_as_ext_event: Indicates whether to log IMDS data as a waagent/extension event. 
:return: None """ now = datetime.datetime.now() if now < self._last_log_time + self._logging_interval: return try: imds_data = self._imds_data_getter('/metadata/instance/') except Exception as e: self._ext_log_if_enabled('Exception occurred while getting IMDS data: {0}\n' 'stacktrace: {1}'.format(e, traceback.format_exc())) imds_data = '{0}'.format(e) msg = 'IMDS instance data = {0}'.format(imds_data) if log_as_ext_event: self._ext_event_logger(name=self._ext_name, op=self._ext_op_type, isSuccess=True, version=self._ext_ver, message=msg) self._ext_log_if_enabled(msg) self._last_log_time = now if __name__ == '__main__': def fake_get_imds_data(node, json=True): result = 'fake_get_imds_data(node="{0}", json="{1}")'.format(node, json) print result return result def default_ext_logger(msg): print 'default_ext_logger(msg="{0}")'.format(msg) def default_ext_event_logger(*args, **kwargs): print 'default_ext_event_logger(*args, **kwargs)' print 'args:' for arg in args: print arg print 'kwargs:' for k in kwargs: print('"{0}"="{1}"'.format(k, kwargs[k])) imds_logger = ImdsLogger('Microsoft.OSTCExtensions.LinuxDiagnostic', '2.3.9021', 'Heartbeat', ext_logger=default_ext_logger, ext_event_logger=default_ext_event_logger, imds_data_getter=fake_get_imds_data, logging_interval_in_minutes=1) start_time = datetime.datetime.now() done = False while not done: now = datetime.datetime.now() print 'Test loop iteration starting at {0}'.format(now) imds_logger.log_imds_data_if_right_time() if now >= start_time + datetime.timedelta(minutes=2): done = True else: print 'Sleeping 10 seconds' time.sleep(10) ================================================ FILE: Diagnostic/Utils/lad_exceptions.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Copyright (c) Microsoft Corporation # All rights reserved. 
# MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit # persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the # Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. class LadLoggingConfigException(Exception): """ Custom exception class for LAD logging (syslog & filelogs) config errors """ pass class LadPerfCfgConfigException(Exception): """ Custom exception class for LAD perfCfg (raw OMI queries) config errors """ pass ================================================ FILE: Diagnostic/Utils/lad_ext_settings.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Linux Azure Diagnostic Extension (Current version is specified in manifest.xml) # Copyright (c) Microsoft Corporation # All rights reserved. 
class ExtSettings(object):
    """
    Wrapper class around any generic Azure extension settings Json objects.
    TODO This class may better go to some place else (e.g., HandlerUtil.py).
    """

    def __init__(self, handler_settings):
        """
        Constructor
        :param handler_settings: Json object (dictionary) decoded from the extension settings Json string.
        """
        # 'or {}' collapses both None and missing sub-objects to an empty dict.
        self._handler_settings = handler_settings or {}
        self._public_settings = self._handler_settings.get('publicSettings') or {}
        self._protected_settings = self._handler_settings.get('protectedSettings') or {}

    def get_handler_settings(self):
        """
        Handler settings (Json dictionary) getter
        :return: Handler settings Json object
        """
        return self._handler_settings

    def has_public_config(self, key):
        """
        Determine if a particular setting is present in the public config
        :param str key: The setting to look for
        :return: True if the setting is present (regardless of its value)
        :rtype: bool
        """
        return key in self._public_settings

    def read_public_config(self, key):
        """
        Return the value of a particular public config setting
        :param str key: The setting to retrieve
        :return: The value of the setting if present; an empty string (*not* None)
                 if the setting is not present
        :rtype: str
        """
        return self._public_settings.get(key, '')

    def read_protected_config(self, key):
        """
        Return the value of a particular protected config setting
        :param str key: The setting to retrieve
        :return: The value of the setting if present; an empty string (*not* None)
                 if the setting is not present
        :rtype: str
        """
        return self._protected_settings.get(key, '')


class LadExtSettings(ExtSettings):
    """
    LAD-specific extension settings object that supports LAD-specific member functions
    """

    def __init__(self, handler_settings):
        super(LadExtSettings, self).__init__(handler_settings)
""" # The logic below could have been a general-purpose JSON tree walker, but since the specific # knowledge of where secrets are needs be applied anyway, it's coded for this specific schema anyway. # Secrets are stored only in the following paths: .storageAccountSasToken, and .sinksConfig.sink[].sasURL. # LAD 2.3 used to support storageAccountKey; although LAD 3.0 does not support it, some users might mistakenly # supply it. We redact it, if present, even though we're going to throw an error later on; the protected # settings are logged before we inspect them to pull out the credentials. # Get and work on a copy of the handler settings dict. Note that it must be a deep copy! # dict(self.get_handler_settings()) doesn't work! handler_settings = copy.deepcopy(self.get_handler_settings()) protected_settings = handler_settings['protectedSettings'] if protected_settings: if 'storageAccountSasToken' in protected_settings: protected_settings['storageAccountSasToken'] = 'REDACTED_SECRET' if 'storageAccountKey' in protected_settings: protected_settings['storageAccountKey'] = 'REDACTED_SECRET' if 'sinksConfig' in protected_settings and 'sink' in protected_settings['sinksConfig']: for each_sink_dict in protected_settings['sinksConfig']['sink']: if 'sasURL' in each_sink_dict: each_sink_dict['sasURL'] = 'REDACTED_SECRET' return json.dumps(handler_settings, sort_keys=True) def log_ext_settings_with_secrets_redacted(self, logger_log, logger_err): """ Log entire extension settings with secrets redacted. This was introduced to help ourselves find any misconfiguration issues related to the storageAccountEndPoint easier, and later extended to log all extension settings with secrets redacted, for better diagnostics. 
:param logger_log: Normal logging function (e.g., hutil.log) :param logger_err: Error logging function (e.g., hutil.error) :return: None """ try: msg = "LAD settings with secrets redacted: {0}".format( self.redacted_handler_settings()) logger_log(msg) except Exception as e: logger_err("Failed to log LAD settings with secrets redacted. Error:{0}\n" "Stacktrace: {1}".format(e, traceback.format_exc())) def get_resource_id(self): """ Try to get resourceId from LadCfg. If not present, try to fetch from xmlCfg. """ lad_cfg = self.read_public_config('ladCfg') resource_id = LadUtil.getResourceIdFromLadCfg(lad_cfg) if not resource_id: encoded_xml_cfg = self.read_public_config('xmlCfg').strip() if encoded_xml_cfg: xml_cfg = base64.b64decode(encoded_xml_cfg) resource_id = XmlUtil.getXmlValue(XmlUtil.createElement(xml_cfg), 'diagnosticMonitorConfiguration/metrics', 'resourceId') # Azure portal uses xmlCfg which contains WadCfg which is pascal case # Currently we will support both casing and deprecate one later if not resource_id: resource_id = XmlUtil.getXmlValue(XmlUtil.createElement(xml_cfg), 'DiagnosticMonitorConfiguration/Metrics', 'resourceId') return resource_id def get_syslogEvents_setting(self): """ Get 'ladCfg/syslogEvents' setting from LAD 3.0 public settings. :return: A dictionary of syslog facility and minSeverity to monitor/ Refer to README.md for more details. """ return LadUtil.getDiagnosticsMonitorConfigurationElement(self.read_public_config('ladCfg'), 'syslogEvents') def get_fileLogs_setting(self): """ Get 'fileLogs' setting from LAD 3.0 public settings. :return: List of dictionaries specifying file to monitor and Azure table name for destinations of the monitored file. 
Refer to README.md for more details """ return self.read_public_config('fileLogs') def get_mdsd_trace_option(self): """ Return traceFlags, if any, from public config :rtype: str :return: trace flags or an empty string """ flags = self.read_public_config('traceFlags') if flags: return " -T {0}".format(flags) else: return "" ================================================ FILE: Diagnostic/Utils/lad_logging_config.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Copyright (c) Microsoft Corporation # All rights reserved. # MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit # persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the # Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
from xml.etree import ElementTree as ET

import Utils.LadDiagnosticUtil as LadUtil
from Utils.lad_exceptions import LadLoggingConfigException
import Utils.mdsd_xml_templates as mxt
from Utils.omsagent_util import get_syslog_ng_src_name

# Unified fluentd/mdsd source name for all LAD syslog events.
syslog_src_name = 'mdsd.syslog'


class LadLoggingConfig:
    """
    Utility class for obtaining syslog (rsyslog or syslog-ng) configurations for use with fluentd
    (currently omsagent), and corresponding omsagent & mdsd configurations, based on the LAD 3.0
    syslog config schema. This class also generates omsagent (fluentd) config for LAD 3.0's
    fileLogs settings (using the fluentd tail plugin).
    """

    def __init__(self, syslogEvents, fileLogs, sinksConfig, pkey_path, cert_path, encrypt_secret):
        """
        Constructor to receive/store necessary LAD settings for the desired configuration generation.

        :param dict syslogEvents: LAD 3.0 "ladCfg" - "syslogEvents" JSON object, or a False object
            if it's not given in the extension settings. An example is as follows:

            "ladCfg": {
                "syslogEvents": {
                    "sinks": "SyslogSinkName0",
                    "syslogEventConfiguration": {
                        "facilityName1": "minSeverity1",
                        "facilityName2": "minSeverity2"
                    }
                }
            }

            Only the JSON object corresponding to "syslogEvents" key should be passed.
            facilityName1/2 is a syslog facility name (e.g., "LOG_USER", "LOG_LOCAL0").
            minSeverity1/2 is a syslog severity level (e.g., "LOG_ERR", "LOG_CRIT") or "NONE".
            "NONE" means no logs from the facility will be captured (thus it's equivalent to
            not specifying the facility at all).
        :param dict fileLogs: LAD 3.0 "fileLogs" JSON object, or a False object if it's not given
            in the ext settings. An example is as follows:

            "fileLogs": {
                "fileLogConfiguration": [
                    {"file": "/var/log/mydaemonlog", "table": "MyDaemonEvents", "sinks": "FilelogSinkName1"},
                    {"file": "/var/log/myotherdaemonelog", "table": "MyOtherDaemonEvents", "sinks": "FilelogSinkName2"}
                ]
            }

            Only the JSON array corresponding to "fileLogConfiguration" key should be passed.
            "file" is the full path of the log file to be watched and captured. "table" is for the
            Azure storage table into which the lines of the watched file will be placed
            (one row per line).
        :param LadUtil.SinkConfiguration sinksConfig: SinkConfiguration object that's created out of
            "sinksConfig" LAD 3.0 JSON setting. Refer to LadUtil.SinkConfiguration documentation.
        :param str pkey_path: Path to the VM's private key that should be passed to mdsd XML for
            decrypting encrypted secrets (EH SAS URL)
        :param str cert_path: Path to the VM's certificate that should be used to encrypt secrets (EH SAS URL)
        :param encrypt_secret: Function to encrypt a secret (string, 2nd param) with the provided
            cert path param (1st)
        """
        self._syslogEvents = syslogEvents
        self._fileLogs = fileLogs
        self._sinksConfig = sinksConfig
        self._pkey_path = pkey_path
        self._cert_path = cert_path
        self._encrypt_secret = encrypt_secret
        self._fac_sev_map = None
        try:
            # Create facility-severity map. E.g.: { "LOG_USER" : "LOG_ERR", "LOG_LOCAL0", "LOG_CRIT" }
            if self._syslogEvents:
                self._fac_sev_map = self._syslogEvents['syslogEventConfiguration']
            self._syslog_disabled = not self._fac_sev_map  # A convenience predicate
            if self._fileLogs:
                # Convert the 'fileLogs' JSON object array into a Python dictionary of 'file' - 'table'
                # E.g., [{ 'file': '/var/log/mydaemonlog1', 'table': 'MyDaemon1Events', 'sinks': 'File1Sink'},
                #        { 'file': '/var/log/mydaemonlog2', 'table': 'MyDaemon2Events', 'sinks': 'File2SinkA,File2SinkB'}]
                self._file_table_map = dict([(entry['file'], entry['table'] if 'table' in entry else '')
                                             for entry in self._fileLogs])
                self._file_sinks_map = dict([(entry['file'], entry['sinks'] if 'sinks' in entry else '')
                                             for entry in self._fileLogs])
            # Lazily generated config caches; computed on first getter call.
            self._rsyslog_config = None
            self._syslog_ng_config = None
            self._mdsd_syslog_config = None
            self._mdsd_telegraf_config = None
            self._mdsd_filelog_config = None
        except KeyError as e:
            raise LadLoggingConfigException("Invalid setting name provided (KeyError). Exception msg: {0}".format(e))
    def get_rsyslog_config(self):
        """
        Returns rsyslog config (for use with omsagent) that corresponds to the syslogEvents or the
        syslogCfg JSON object given in the construction parameters.

        :rtype: str
        :return: rsyslog config string that should be appended to /etc/rsyslog.d/95-omsagent.conf
                 (new rsyslog) or to /etc/rsyslog.conf (old rsyslog)
        """
        # Cached: generated on first call only.
        if not self._rsyslog_config:
            if self._syslog_disabled:
                self._rsyslog_config = ''
            else:
                # Generate/save/return rsyslog config string for the facility-severity pairs.
                # E.g.: "user.err @127.0.0.1:%SYSLOG_PORT%\nlocal0.crit @127.0.0.1:%SYSLOG_PORT%\n'
                # NOTE: dict.iteritems() — this module targets the Python 2 runtime.
                self._rsyslog_config = \
                    '\n'.join('{0}.{1} @127.0.0.1:%SYSLOG_PORT%'.format(syslog_name_to_rsyslog_name(fac),
                                                                        syslog_name_to_rsyslog_name(sev))
                              for fac, sev in self._fac_sev_map.iteritems()) + '\n'
        return self._rsyslog_config

    def get_syslog_ng_config(self):
        """
        Returns syslog-ng config (for use with omsagent) that corresponds to the syslogEvents or the
        syslogCfg JSON object given in the construction parameters.

        :rtype: str
        :return: syslog-ng config string that should be appended to /etc/syslog-ng/syslog-ng.conf
        """
        if not self._syslog_ng_config:
            if self._syslog_disabled:
                self._syslog_ng_config = ''
            else:
                # Generate/save/return syslog-ng config string for the facility-severity pairs.
                # E.g.: "log { source(src); filter(f_LAD_oms_f_user); filter(f_LAD_oms_ml_err); destination(d_LAD_oms); };\nlog { source(src); filter(f_LAD_oms_f_local0); filter(f_LAD_oms_ml_crit); destination(d_LAD_oms); };\n"
                self._syslog_ng_config = \
                    '\n'.join('log {{ source({0}); filter(f_LAD_oms_f_{1}); filter(f_LAD_oms_ml_{2}); '
                              'destination(d_LAD_oms); }};'.format(get_syslog_ng_src_name(),
                                                                   syslog_name_to_rsyslog_name(fac),
                                                                   syslog_name_to_rsyslog_name(sev))
                              for fac, sev in self._fac_sev_map.iteritems()) + '\n'
        return self._syslog_ng_config

    def parse_pt_duration(self, duration):
        """
        Convert the ISO8601 Time Duration into seconds. for ex PT2H3M20S will be 7400 seconds
        :param duration: The ISO8601 duration string to be converted into seconds
        :return: A string of the total number of seconds suffixed with 's' (e.g. "7400s")
        """
        total_seconds = 0
        count = ""
        # Single pass: accumulate digit characters, flush on a unit letter.
        # Non-digit, non-unit characters (including the leading 'P'/'T') are ignored.
        for ch in duration:
            if ch.lower() == 'h':
                total_seconds += int(count)*3600
                count = ""
            elif ch.lower() == 'm':
                total_seconds += int(count)*60
                count = ""
            elif ch.lower() == 's':
                total_seconds += int(count)
                count = ""
            elif ch in ["0","1","2","3","4","5","6","7","8","9"]:
                count += ch
        return str(total_seconds)+"s"

    def parse_lad_perf_settings(self, ladconfig):
        """
        Parses the LAD json config to create a list of entries per metric along with its
        configuration as required by telegraf config parser.

        :param ladconfig: The lad json config element

        Sample OMI metric json config (taken from the .settings file) can be of two types: it can
        have a per-counter "sampleRate" key; if not, it defaults to the "sampleRateInSeconds" key
        in the larger lad_cfg element.

            {
                "counterSpecifier": "/builtin/network/packetstransmitted",
                "counter": "packetstransmitted",
                "class": "network",
                "sampleRate": "PT15S",
                "type": "builtin",
                "annotation": [{"locale": "en-us", "displayName": "Packets sent"}],
                "unit": "Count"
            }

        :return: list of dicts, one per counter, each with "displayName" and "interval" keys.
        """
        if not ladconfig:
            return []
        data = []
        default_sample_rate = "15s" #Lowest supported time interval
        if "sampleRateInSeconds" in ladconfig and ladconfig["sampleRateInSeconds"] != "":
            default_sample_rate = str(ladconfig["sampleRateInSeconds"]) + "s" #Example, converting 15 to 15s
        if 'diagnosticMonitorConfiguration' in ladconfig and "performanceCounters" in ladconfig['diagnosticMonitorConfiguration']:
            data = ladconfig['diagnosticMonitorConfiguration']["performanceCounters"]
        else:
            return []
        if "performanceCounterConfiguration" not in data or len(data["performanceCounterConfiguration"]) == 0:
            return []
        parsed_settings = []
        perfconf = data["performanceCounterConfiguration"]
        for item in perfconf:
            counter = {}
            # Key format is "<class>-><display name>", all lower-cased.
            counter["displayName"] = item["class"].strip().lower() + "->" + item["annotation"][0]["displayName"].strip().lower()
            if "sampleRate" in item:
                counter["interval"] = self.parse_pt_duration(item["sampleRate"]) #Converting ISO8601 to seconds string
            else:
                counter["interval"] = default_sample_rate
            parsed_settings.append(counter)
        """
        Sample output after parsing the OMI metric
        [
            {
                "displayName" : "Network->Packets sent",
                "interval" : "15s"
            },
        ]
        """
        return parsed_settings
    def get_mdsd_syslog_config(self, disableStorageAccount = False):
        """
        Get mdsd XML config string for syslog use with omsagent in LAD 3.0.
        :rtype: str
        :return: XML string that should be added to the mdsd config XML tree for syslog use with
                 omsagent in LAD 3.0.
        """
        # Cached: generated on first call only.
        if not self._mdsd_syslog_config:
            self._mdsd_syslog_config = self.__generate_mdsd_syslog_config(disableStorageAccount)
        return self._mdsd_syslog_config

    def __generate_mdsd_syslog_config(self, disableStorageAccount = False):
        """
        Helper method to generate oms_mdsd_syslog_config
        """
        if self._syslog_disabled:
            return ''
        # For basic syslog conf (single dest table): Source name is unified as 'mdsd.syslog' and
        # dest table (eventName) is 'LinuxSyslog'. This is currently the only supported syslog conf scheme.
        syslog_routeevents = ''
        if not disableStorageAccount:
            syslog_routeevents = mxt.per_RouteEvent_tmpl.format(event_name='LinuxSyslog', opt_store_type='')
        # Add RouteEvent elements for specified "sinks" for "syslogEvents" feature
        # Also add EventStreamingAnnotation for EventHub sinks
        syslog_eh_urls = ''
        for sink_name in LadUtil.getSinkList(self._syslogEvents):
            if sink_name == 'LinuxSyslog':
                raise LadLoggingConfigException("'LinuxSyslog' can't be used as a sink name. "
                                                "It's reserved for default Azure Table name for syslog events.")
            routeevent, eh_url = self.__generate_routeevent_and_eh_url_for_extra_sink(sink_name, syslog_src_name)
            syslog_routeevents += routeevent
            syslog_eh_urls += eh_url
        mdsd_event_source = ''
        # Do not add MdsdEventSource element if there's no associated RouteEvent generated.
        if syslog_routeevents:
            mdsd_event_source = mxt.per_MdsdEventSource_tmpl.format(source=syslog_src_name,
                                                                    routeevents=syslog_routeevents)
        return mxt.top_level_tmpl_for_logging_only.format(
            sources=mxt.per_source_tmpl.format(name=syslog_src_name),
            events=mdsd_event_source,
            eh_urls=syslog_eh_urls)

    def get_mdsd_telegraf_config(self, namespaces):
        """
        Get mdsd XML config string for telegraf use with mdsd in LAD 3.0. This method is called
        during config generation to create source tags for mdsd xml.
        :param namespaces: The list of telegraf plugins being used to source the metrics requested
                           by the user
        :rtype: str
        :return: XML string that should be added to the mdsd config XML tree for telegraf use with
                 mdsd in LAD 3.0.
        """
        if not self._mdsd_telegraf_config:
            self._mdsd_telegraf_config = self.__generate_mdsd_telegraf_config(namespaces)
        return self._mdsd_telegraf_config

    def __generate_mdsd_telegraf_config(self, namespaces):
        """
        Helper method to generate mdsd_telegraf_config
        """
        if len(namespaces) == 0:
            return ''
        telegraf_sources = ""
        # For telegraf conf we create a Source for each of the measurements (plugins) sent from telegraf.
        for plugin in namespaces:
            lad_specific_storage_plugin = "storage-" + plugin
            telegraf_sources += mxt.per_source_tmpl.format(name=lad_specific_storage_plugin)
        return mxt.top_level_tmpl_for_logging_only.format(sources=telegraf_sources, events="", eh_urls="")
    def __generate_routeevent_and_eh_url_for_extra_sink(self, sink_name, src_name):
        """
        Helper method to generate one RouteEvent element for each extra sink given.
        Also generates an EventStreamingAnnotation element for EventHub sinks.
        :param str sink_name: The name of the sink for the RouteEvent.
        :param str src_name: The name of the ingested source that should be used for
                             EventStreamingAnnotation.
        :rtype str,str:
        :return: A pair of the XML RouteEvent element string for the sink and the
                 EventHubStreamingAnnotation XML string.
        """
        sink = self._sinksConfig.get_sink_by_name(sink_name)
        if not sink:
            raise LadLoggingConfigException('Sink name "{0}" is not defined in sinksConfig'.format(sink_name))
        sink_type = sink['type']
        if not sink_type:
            raise LadLoggingConfigException('Sink type for sink "{0}" is not defined in sinksConfig'.format(sink_name))
        if sink_type == 'JsonBlob':
            # No EventStreamingAnnotation for JsonBlob.
            return mxt.per_RouteEvent_tmpl.format(event_name=sink_name, opt_store_type='storeType="JsonBlob"'),\
                   ''
        elif sink_type == 'EventHub':
            if 'sasURL' not in sink:
                raise LadLoggingConfigException('sasURL is not specified for EventHub sink_name={0}'.format(sink_name))
            # For syslog/filelogs (ingested events), the source name should be used for
            # EventStreamingAnnotation name. The SAS URL is encrypted with the VM cert.
            eh_url = mxt.per_eh_url_tmpl.format(eh_name=src_name,
                                                key_path=self._pkey_path,
                                                enc_eh_url=self._encrypt_secret(self._cert_path, sink['sasURL']))
            # No RouteEvent for logging event's EventHub sink.
            return '', eh_url
        else:
            raise LadLoggingConfigException('{0} sink type (for sink_name={1}) is not supported'.format(sink_type, sink_name))

    def get_mdsd_filelog_config(self):
        """
        Get mdsd XML config string for filelog (tail) use with omsagent in LAD 3.0.
        :rtype: str
        :return: XML string that should be added to the mdsd config XML tree for filelog use with
                 omsagent in LAD 3.0.
        """
        # Cached: generated on first call only.
        if not self._mdsd_filelog_config:
            self._mdsd_filelog_config = self.__generate_mdsd_filelog_config()
        return self._mdsd_filelog_config

    def __generate_mdsd_filelog_config(self):
        """
        Helper method to generate oms_mdsd_filelog_config
        """
        if not self._fileLogs:
            return ''
        # Per-file source name is 'mdsd.filelog<.path.to.file>' where '<.path.to.file>' is a full path
        # with all '/' replaced by '.'.
        filelogs_sources = ''
        filelogs_mdsd_event_sources = ''
        filelogs_eh_urls = ''
        # sorted() gives a deterministic element order in the generated XML.
        for file_key in sorted(self._file_table_map):
            if not self._file_table_map[file_key] and not self._file_sinks_map[file_key]:
                raise LadLoggingConfigException('Neither "table" nor "sinks" defined for file "{0}"'.format(file_key))
            source_name = 'mdsd.filelog{0}'.format(file_key.replace('/', '.'))
            filelogs_sources += mxt.per_source_tmpl.format(name=source_name)
            per_file_routeevents = ''
            if self._file_table_map[file_key]:
                per_file_routeevents += mxt.per_RouteEvent_tmpl.format(event_name=self._file_table_map[file_key],
                                                                       opt_store_type='')
            if self._file_sinks_map[file_key]:
                # "sinks" is a comma-separated list of sink names.
                for sink_name in self._file_sinks_map[file_key].split(','):
                    routeevent, eh_url = self.__generate_routeevent_and_eh_url_for_extra_sink(sink_name, source_name)
                    per_file_routeevents += routeevent
                    filelogs_eh_urls += eh_url
            # Do not add MdsdEventSource element if there's no associated RouteEvent generated.
            if per_file_routeevents:
                filelogs_mdsd_event_sources += \
                    mxt.per_MdsdEventSource_tmpl.format(source=source_name, routeevents=per_file_routeevents)
        return mxt.top_level_tmpl_for_logging_only.format(sources=filelogs_sources,
                                                          events=filelogs_mdsd_event_sources,
                                                          eh_urls=filelogs_eh_urls)
:rtype: str :return: Fluentd config string that should be overwritten to /etc/opt/microsoft/omsagent/LAD/conf/omsagent.d/syslog.conf (after replacing '%SYSLOG_PORT%' with the assigned/picked port number) """ fluentd_syslog_src_config = """ type syslog port %SYSLOG_PORT% bind 127.0.0.1 protocol_type udp include_source_host true tag mdsd.syslog # Generate fields expected for existing mdsd syslog collection schema. type record_transformer enable_ruby # Fields for backward compatibility with Azure Shoebox V1 (Table storage) Ignore "syslog" Facility ${tag_parts[2]} Severity ${tag_parts[3]} EventTime ${time.strftime('%Y-%m-%dT%H:%M:%S%z')} SendingHost ${record["source_host"]} Msg ${record["message"]} # Rename 'host' key, as mdsd will add 'Host' for Azure Table and it'll be confusing hostname ${record["host"]} remove_keys host,message,source_host # Renamed (duplicated) fields, so just remove """ return '' if self._syslog_disabled else fluentd_syslog_src_config def get_fluentd_filelog_src_config(self): """ Get Fluentd's filelog (tail) source config that should be used for this LAD's fileLogs settings. :rtype: str :return: Fluentd config string that should be overwritten to /etc/opt/microsoft/omsagent/LAD/conf/omsagent.d/file.conf """ if not self._fileLogs: return '' fluentd_tail_src_config_template = """ # For all monitored files @type tail path {file_paths} pos_file /var/opt/microsoft/omsagent/LAD/tmp/filelogs.pos tag mdsd.filelog.* format none message_key Msg # LAD uses "Msg" as the field name # Add FileTag field (existing LAD behavior) @type record_transformer FileTag ${{tag_suffix[2]}} """ return fluentd_tail_src_config_template.format(file_paths=','.join(self._file_table_map.keys())) def get_fluentd_out_mdsd_config(self): """ Get Fluentd's out_mdsd output config that should be used for LAD. TODO This is not really syslog-specific, so should be moved outside from here. 
# Mapping from syslog.h-style names (as used in LAD settings, e.g. 'LOG_USER')
# to the short names rsyslog/syslog-ng use in selector lines (e.g. 'user').
syslog_name_to_rsyslog_name_map = {
    # facilities
    'LOG_AUTH': 'auth',
    'LOG_AUTHPRIV': 'authpriv',
    'LOG_CRON': 'cron',
    'LOG_DAEMON': 'daemon',
    'LOG_FTP': 'ftp',
    'LOG_KERN': 'kern',
    'LOG_LOCAL0': 'local0',
    'LOG_LOCAL1': 'local1',
    'LOG_LOCAL2': 'local2',
    'LOG_LOCAL3': 'local3',
    'LOG_LOCAL4': 'local4',
    'LOG_LOCAL5': 'local5',
    'LOG_LOCAL6': 'local6',
    'LOG_LOCAL7': 'local7',
    'LOG_LPR': 'lpr',
    'LOG_MAIL': 'mail',
    'LOG_NEWS': 'news',
    'LOG_SYSLOG': 'syslog',
    'LOG_USER': 'user',
    'LOG_UUCP': 'uucp',
    # severities
    'LOG_EMERG': 'emerg',
    'LOG_ALERT': 'alert',
    'LOG_CRIT': 'crit',
    'LOG_ERR': 'err',
    'LOG_WARNING': 'warning',
    'LOG_NOTICE': 'notice',
    'LOG_INFO': 'info',
    'LOG_DEBUG': 'debug'
}


def syslog_name_to_rsyslog_name(syslog_name):
    """
    Convert a syslog name (e.g., "LOG_USER") to the corresponding rsyslog name (e.g., "user")
    :param str syslog_name: A syslog name for a facility (e.g., "LOG_USER") or a severity (e.g., "LOG_ERR")
    :rtype: str
    :return: Corresponding rsyslog name (e.g., "user" or "err")
    :raises LadLoggingConfigException: if syslog_name is not '*' and not a known facility/severity name
    """
    # We accept '*' as a facility name (also as a severity name, though it's not required)
    # to allow customers to collect for reserved syslog facility numeric IDs (12-15)
    if syslog_name == '*':
        return '*'
    if syslog_name not in syslog_name_to_rsyslog_name_map:
        raise LadLoggingConfigException('Invalid syslog name given: {0}'.format(syslog_name))
    return syslog_name_to_rsyslog_name_map[syslog_name]


def copy_sub_elems(dst_xml, src_xml, path):
    """
    Copy sub-elements of the element at 'path' in src_xml to the element at the same path in dst_xml.
    No-op if the element is missing from either tree.
    :param xml.etree.ElementTree.ElementTree dst_xml: Python xml tree object to which sub-elements will be copied.
    :param xml.etree.ElementTree.ElementTree src_xml: Python xml tree object from which sub-elements will be copied.
    :param str path: The path of the element whose sub-elements will be copied.
    :return: None. dst_xml will be updated with copied sub-elements
    """
    src_elem = src_xml.find(path)
    if src_elem is None:
        return  # Nothing to copy
    dst_elem = dst_xml.find(path)
    if dst_elem is None:
        # Destination element missing: nothing to attach to.
        # (Previously this raised AttributeError by calling append on None.)
        return
    for sub_elem in src_elem:
        dst_elem.append(sub_elem)


def copy_source_mdsdevent_eh_url_elems(mdsd_xml_tree, mdsd_logging_xml_string):
    """
    Copy MonitoringManagement/Sources/Source, MonitoringManagement/Events/MdsdEvents/MdsdEventSource,
    and MonitoringManagement/EventStreamingAnnotations/EventStreamingAnnotation elements from
    mdsd_logging_xml_string to mdsd_xml_tree. Used to actually add generated logging mdsd config XML
    elements to the mdsd config XML tree.
    :param xml.etree.ElementTree.ElementTree mdsd_xml_tree: Python xml.etree.ElementTree object that's generated from
                                                            mdsd config XML template
    :param str mdsd_logging_xml_string: XML string containing the generated logging (syslog/filelog) mdsd config XML
                                        elements. See oms_syslog_mdsd_*_expected_xpaths member variables in
                                        test_lad_logging_config.py for examples in XPATHS format.
    :return: None. mdsd_xml_tree object will contain the added elements.
    """
    if not mdsd_logging_xml_string:
        return

    mdsd_logging_xml_tree = ET.ElementTree(ET.fromstring(mdsd_logging_xml_string))

    # Copy Source elements (sub-elements of Sources element)
    copy_sub_elems(mdsd_xml_tree, mdsd_logging_xml_tree, 'Sources')

    # Copy MdsdEventSource elements (sub-elements of Events/MdsdEvents element)
    copy_sub_elems(mdsd_xml_tree, mdsd_logging_xml_tree, 'Events/MdsdEvents')

    # Copy EventStreamingAnnotation elements (sub-elements of EventStreamingAnnotations element)
    copy_sub_elems(mdsd_xml_tree, mdsd_logging_xml_tree, 'EventStreamingAnnotations')
per_eh_url_tmpl = """ {enc_eh_url} """ top_level_tmpl_for_logging_only = """ {sources} {events} {eh_urls} """ per_source_tmpl = """ """ per_MdsdEventSource_tmpl = """ {routeevents} """ per_RouteEvent_tmpl = """ """ derived_event = """ """ lad_query = '' obo_field = '' # OMI is not used anymore entire_xml_cfg_tmpl = """ """ ================================================ FILE: Diagnostic/Utils/misc_helpers.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Linux Azure Diagnostic Extension (Current version is specified in manifest.xml) # Copyright (c) Microsoft Corporation # All rights reserved. # MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
def get_extension_operation_type(command):
    """
    Map an extension command-line verb (optionally prefixed with '-' or '/'
    characters) to the corresponding WALA event operation type.
    Returns None for an unrecognized command.
    """
    def _is(verb):
        # Accept any number of leading '-' or '/' characters before the verb.
        return re.match("^([-/]*)({0})".format(verb), command)

    if _is("enable"):
        return waagent.WALAEventOperation.Enable
    if _is("daemon"):
        # LAD-specific extension operation (invoked from "./diagnostic.py -enable")
        return "Daemon"
    if _is("install"):
        return waagent.WALAEventOperation.Install
    if _is("disable"):
        return waagent.WALAEventOperation.Disable
    if _is("uninstall"):
        return waagent.WALAEventOperation.Uninstall
    if _is("update"):
        return waagent.WALAEventOperation.Update


def wala_event_type_for_telemetry(ext_op_type):
    """Daemon operations are reported as 'HeartBeat' telemetry; all others pass through unchanged."""
    if ext_op_type == "Daemon":
        return "HeartBeat"
    return ext_op_type


def get_storage_endpoints_with_account(account, endpoint_without_account):
    """
    Build the (table, blob) endpoint URL pair for a storage account.
    If no endpoint is given, the default core.windows.net endpoints are used.
    :param account: Storage account name
    :param endpoint_without_account: Endpoint host (with or without a scheme), or falsy for the default
    :return: (table endpoint URL, blob endpoint URL) tuple
    """
    if not endpoint_without_account:
        return ('https://' + account + '.table.core.windows.net',
                'https://' + account + '.blob.core.windows.net')

    parts = endpoint_without_account.split('//', 1)
    if len(parts) > 1:
        # Endpoint already carries a scheme ('https:', 'http:', ...): keep it.
        table_endpoint = parts[0] + '//' + account + ".table." + parts[1]
        blob_endpoint = parts[0] + '//' + account + ".blob." + parts[1]
    else:
        # Bare host: default to https.
        table_endpoint = 'https://' + account + ".table." + parts[0]
        blob_endpoint = 'https://' + account + ".blob." + parts[0]
    return (table_endpoint, blob_endpoint)
class LadLogHelper(object):
    """
    Various LAD log helper functions encapsulated here, so that we don't have to tag along all the parameters.
    """

    def __init__(self, logger_log, logger_error, waagent_event_adder, status_reporter, ext_name, ext_ver):
        """
        Constructor
        :param logger_log: Normal logging function (e.g., hutil.log)
        :param logger_error: Error logging function (e.g., hutil.error)
        :param waagent_event_adder: waagent event add function (waagent.AddExtensionEvent)
        :param status_reporter: waagent/extension status report function (hutil.do_status_report)
        :param ext_name: Extension name (hutil.get_name())
        :param ext_ver: Extension version (hutil.get_extension_version())
        """
        self._logger_log = logger_log
        self._logger_error = logger_error
        self._waagent_event_adder = waagent_event_adder
        self._status_reporter = status_reporter
        self._ext_name = ext_name
        self._ext_ver = ext_ver

    def log_suspected_memory_leak_and_kill_mdsd(self, memory_usage_in_KB, mdsd_process, ext_op):
        """
        Log suspected-memory-leak message both in ext logs and as a waagent event, then kill mdsd.
        :param memory_usage_in_KB: Memory usage in KB (to be included in the log)
        :param mdsd_process: Python Process object for the mdsd process to kill
        :param ext_op: Extension operation type to use for waagent event (waagent.WALAEventOperation.HeartBeat)
        :return: None
        """
        # Round KB up to whole MB for the message.
        memory_leak_msg = "Suspected mdsd memory leak (Virtual memory usage: {0}MB). " \
                          "Recycling mdsd to self-mitigate.".format(int((memory_usage_in_KB + 1023) / 1024))
        self._logger_log(memory_leak_msg)
        # Add a telemetry for a possible statistical analysis
        self._waagent_event_adder(name=self._ext_name,
                                  op=ext_op,
                                  isSuccess=True,
                                  version=self._ext_ver,
                                  message=memory_leak_msg)
        mdsd_process.kill()

    def report_mdsd_dependency_setup_failure(self, ext_event_type, failure_msg):
        """
        Report mdsd dependency setup failure to 3 destinations (ext log, status report, agent event)
        :param ext_event_type: Type of extension event being performed (e.g., 'HeartBeat')
        :param failure_msg: Dependency setup failure message to be added to the logs
        :return: None
        """
        dependencies_err_log_msg = "Failed to set up mdsd dependencies: {0}".format(failure_msg)
        self._logger_error(dependencies_err_log_msg)
        self._status_reporter(ext_event_type, 'error', '1', dependencies_err_log_msg)
        self._waagent_event_adder(name=self._ext_name,
                                  op=ext_event_type,
                                  isSuccess=False,
                                  version=self._ext_ver,
                                  message=dependencies_err_log_msg)

    def log_and_report_failed_config_generation(self, ext_event_type, config_invalid_reason,
                                                redacted_handler_settings):
        """
        Report failed config generation from configurator.generate_all_configs().
        :param str ext_event_type: Type of extension event being performed (most likely 'HeartBeat')
        :param str config_invalid_reason: Msg from configurator.generate_all_configs()
        :param str redacted_handler_settings: JSON string for the extension's protected/public settings after
               redacting secrets in the protected settings. This is for logging to Geneva for diagnostic purposes.
        :return: None
        """
        config_invalid_log = "Invalid config settings given: " + config_invalid_reason + \
                             ". Can't proceed, although this install/enable operation is reported as successful so " \
                             "the VM can complete successful startup."
        self._logger_log(config_invalid_log)
        # Status is reported as 'success' so the VM doesn't get blocked on a user config error.
        self._status_reporter(ext_event_type, 'success', '0', config_invalid_log)
        self._waagent_event_adder(name=self._ext_name,
                                  op=ext_event_type,
                                  isSuccess=True,  # Note this is True, because it is a user error.
                                  version=self._ext_ver,
                                  message="Invalid handler settings encountered: {0}".format(
                                      redacted_handler_settings))

    def log_and_report_invalid_mdsd_cfg(self, ext_event_type, config_validate_cmd_msg, mdsd_cfg_xml):
        """
        Report invalid result from 'mdsd -v -c xmlCfg.xml'
        :param ext_event_type: Type of extension event being performed (most likely 'HeartBeat')
        :param config_validate_cmd_msg: Output of 'mdsd -v -c xmlCfg.xml'
        :param mdsd_cfg_xml: Content of xmlCfg.xml to be sent to Geneva
        :return: None
        """
        message = "Problem(s) detected in generated mdsd configuration. Can't enable, although this install/enable " \
                  "operation is reported as successful so the VM can complete successful startup. Linux Diagnostic " \
                  "Extension will exit. Config validation message: {0}".format(config_validate_cmd_msg)
        self._logger_log(message)
        self._status_reporter(ext_event_type, 'success', '0', message)
        self._waagent_event_adder(name=self._ext_name,
                                  op=ext_event_type,
                                  isSuccess=True,  # Note this is True, because it is a user error.
                                  version=self._ext_ver,
                                  message="Problem(s) detected in generated mdsd configuration: {0}".format(
                                      mdsd_cfg_xml))


def read_uuid():
    """
    Read the VM's product UUID from DMI.
    :rtype: str
    :return: UUID string read from /sys/class/dmi/id/product_uuid
    :raises LadLoggingConfigException: if the uuid file can't be opened or is empty
    """
    uuid = ''
    uuid_file_path = '/sys/class/dmi/id/product_uuid'
    try:
        with open(uuid_file_path) as f:
            uuid = f.readline().strip()
    except Exception as e:
        raise LadLoggingConfigException('read_uuid() failed: Unable to open uuid file {0}'.format(uuid_file_path))
    if not uuid:
        raise LadLoggingConfigException('read_uuid() failed: Empty content in uuid file {0}'.format(uuid_file_path))
    return uuid


def encrypt_secret_with_cert(run_command, logger, cert_path, secret):
    """
    update_account_settings() helper.
    :param run_command: Function to run an arbitrary command
    :param logger: Function to log error messages
    :param cert_path: Cert file path
    :param secret: Secret to encrypt
    :return: Encrypted secret string. None if openssl command exec fails.
    """
    # Have openssl write to our temporary file (on Linux we don't have an exclusive lock on the temp file).
    # openssl smime, when asked to put output in a file, simply overwrites the file; it does not unlink/creat or
    # creat/rename.
    f = tempfile.NamedTemporaryFile(suffix='mdsd', delete=True)
    try:
        cmd = "echo -n '{0}' | openssl smime -aes256 -encrypt -outform DER -out {1} {2}"
        cmd_to_run = cmd.format(secret, f.name, cert_path)
        ret_status, ret_msg = run_command(cmd_to_run, should_log=False)
        if ret_status != 0:  # Fixed: was 'is not 0' (identity comparison against an int literal)
            logger("Encrypting storage secret failed with the following message: " + ret_msg)
            return None
        encrypted_secret = f.read()
    finally:
        # delete=True: closing unlinks the temp file even on the failure path (it was leaked before).
        f.close()
    return binascii.b2a_hex(encrypted_secret).upper()


def tail(log_file, output_size=1024):
    """
    Return up to the last output_size printable characters of log_file ('' if the file doesn't exist).
    NOTE(review): Python-2-only code path -- the negative seek on a text-mode file and
    str.decode() both fail on Python 3; confirm interpreter version before reusing.
    """
    if not os.path.exists(log_file):
        return ""
    pos = min(output_size, os.path.getsize(log_file))
    with open(log_file, "r") as log:
        log.seek(-pos, 2)
        buf = log.read(output_size)
        buf = filter(lambda x: x in string.printable, buf)
        return buf.decode("ascii", "ignore")
def get_mdsd_proxy_config(waagent_setting, ext_settings, logger):
    """
    Determine the mdsd http proxy setting with the documented precedence:
    waagent.HttpProxyConfigString from /etc/waagent.conf first, then the
    protected extension setting, then the public extension setting.
    Returns '' when nothing usable is configured.
    """
    proxy_setting_name = 'mdsdHttpProxy'

    # Highest priority: waagent.conf's HttpProxyConfigString.
    proxy_config = waagent_setting
    if not proxy_config:
        # Next: protected extension setting.
        proxy_config = ext_settings.read_protected_config(proxy_setting_name)
    if not proxy_config:
        # Last: public extension setting.
        proxy_config = ext_settings.read_public_config(proxy_setting_name)

    if not isinstance(proxy_config, basestring):
        # Also hit when nothing was configured at all (proxy_config is None).
        logger('Error: mdsdHttpProxy config is not a string. Ignored.')
        return ''

    proxy_config = proxy_config.strip()
    if not proxy_config:
        return ''
    logger("mdsdHttpProxy setting was given and will be passed to mdsd, "
           "but not logged here in case there's a password in it")
    return proxy_config


def escape_nonalphanumerics(data):
    """Escape every non-alphanumeric character in data as ':XXXX' (4-hex-digit, uppercase code point)."""
    escaped = []
    for ch in data:
        escaped.append(ch if ch.isalnum() else ":{0:04X}".format(ord(ch)))
    return ''.join(escaped)
""" identity = "unknown" env_cfg_path = os.path.join(waagent_dir, "HostingEnvironmentConfig.xml") if not os.path.exists(env_cfg_path): logger_log("No Deployment ID (not running in a hosted environment") return identity try: with open(env_cfg_path, 'r') as env_cfg_file: xml_text = env_cfg_file.read() dom = xml.dom.minidom.parseString(xml_text) deployment = dom.getElementsByTagName("Deployment") name = deployment[0].getAttribute("name") if name: identity = name logger_log("Deployment ID found: {0}.".format(identity)) except Exception as e: # use fallback identity logger_error("Failed to retrieve deployment ID. Error:{0}\nStacktrace: {1}".format(e, traceback.format_exc())) return identity def write_lad_pids_to_file(pid_file_path, py_pid, mdsd_pid=None): """ Write LAD process IDs to file :param int py_pid: PID of diagnostic.py :param int mdsd_pid: PID of mdsd or None (when called before mdsd is started) :param str pid_file_path: Path of the file to be written :return: None """ with open(pid_file_path, 'w') as f: f.write(str(py_pid) + '\n') if mdsd_pid is not None: f.write(str(mdsd_pid) + '\n') def append_string_to_file(string, filepath): """ Append string content to file :param string: A str object that holds the content to be appended to the file :param filepath: Path to the file to be appended :return: None """ with open(filepath, 'a') as f: f.write(string) def read_file_to_string(filepath): """ Read entire file and return it as string. If file can't be read, return "Can't read " :param str filepath: Path of the file to read :rtype: str :return: Content of the file in a single string, or "Can't read " if file can't be read. """ try: with open(filepath) as f: return f.read() except Exception as e: return "Can't read {0}. 
Exception thrown: {1}".format(filepath, e) ================================================ FILE: Diagnostic/Utils/omsagent_util.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Copyright (c) Microsoft Corporation # All rights reserved. # MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above # copyright notice and this permission notice shall be included in all copies or substantial portions of the # Software. THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT # LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT # SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. import os import re import socket import time from Utils.misc_helpers import append_string_to_file # op is either '--upgrade' or '--remove' omsagent_universal_sh_cmd_template = 'sh omsagent-*.universal.x64.sh {op}' # args is either '-w LAD' or '-x LAD' or '-l' omsagent_lad_workspace_cmd_template = 'sh /opt/microsoft/omsagent/bin/omsadmin.sh {args}' omsagent_lad_dir = '/etc/opt/microsoft/omsagent/LAD/' def setup_omsagent_for_lad(run_command): """ Install omsagent by executing the universal shell bundle. Also onboard omsagent for LAD. 
omsagent_control_cmd_template = '/opt/microsoft/omsagent/bin/service_control {op} LAD'


def control_omsagent(op, run_command):
    """
    Start/stop/restart the omsagent service for the LAD workspace via the
    omsagent service_control script.
    :param op: Operation type. Must be 'start', 'stop', or 'restart'
    :param run_command: External command execution function (e.g., RunGetOutput)
    :rtype: int, str
    :return: 2-tuple of exit code (0 on success) and message
    """
    exit_code, output = run_command(omsagent_control_cmd_template.format(op=op))
    if exit_code == 0:
        return 0, 'control_omsagent({0}) succeeded'.format(op)
    return 1, 'control_omsagent({0}) failed. Output: {1}'.format(op, output)
# Top-level rsyslog config file; its existence is how is_rsyslog_installed() detects rsyslog.
rsyslog_top_conf_path = '/etc/rsyslog.conf'
# Drop-in config dir of newer rsyslog versions; its existence distinguishes new vs old rsyslog.
rsyslog_d_path = '/etc/rsyslog.d/'
rsyslog_d_omsagent_conf_path = '/etc/rsyslog.d/95-omsagent.conf'  # hard-coded by omsagent
# syslog-ng config file; its existence is how is_syslog_ng_installed() detects syslog-ng.
syslog_ng_conf_path = '/etc/syslog-ng/syslog-ng.conf'
""" return os.path.exists(rsyslog_top_conf_path) def is_new_rsyslog_installed(): """ Returns true iff newer version of rsyslog (that has /etc/rsyslog.d/) is installed on the machine. :rtype: bool :return: True if /etc/rsyslog.d/ exists. False otherwise. """ return os.path.exists(rsyslog_d_path) def is_syslog_ng_installed(): """ Returns true iff syslog-ng is installed on the machine. :rtype: bool :return: True if syslog-ng is installed. False otherwise. """ return os.path.exists(syslog_ng_conf_path) def get_syslog_ng_src_name(): """ Some syslog-ng distributions use different source name ("s_src" vs "src"), causing syslog-ng restarts to fail when we provide a non-existent source name. Need to search the syslog-ng.conf file and retrieve the source name as below. :rtype: str :return: syslog-ng source name retrieved from syslog-ng.conf. 'src' if none available. """ syslog_ng_src_name = 'src' try: with open(syslog_ng_conf_path, 'r') as f: syslog_ng_cfg = f.read() src_match = re.search(r'\n\s*source\s+([^\s]+)\s*{', syslog_ng_cfg) if src_match: syslog_ng_src_name = src_match.group(1) except Exception as e: pass # Ignore any errors, because the default ('src') will do. return syslog_ng_src_name def get_fluentd_syslog_src_port(): """ Returns a TCP/UDP port number that'll be supplied to the fluentd syslog src plugin (for it to listen to for syslog events from rsyslog/syslog-ng). Ports from 25224 to 25423 will be tried for bind() and the first available one will be returned. 25224 is the default port number that's picked by omsagent. This is definitely not 100% correct with potential races. The correct solution would be to let fluentd syslog src plugin bind to 0 and write the resulting bound port number to a file, so that we can get the port number from the file. However, the current fluentd in_syslog.rb doesn't write to a file, so that method won't work. And yet we still want to minimize possibility of binding to an already-in-use port, so here's a workaround. 
omsagent_config_syslog_sh_cmd_template = 'sh /opt/microsoft/omsagent/bin/configure_syslog.sh {op} LAD {port}'


def run_omsagent_config_syslog_sh(run_command, op, port=''):
    """
    Run omsagent's configure_syslog.sh script for the LAD workspace.
    :param run_command: External command execution function (e.g., RunGetOutput)
    :param op: Type of operation. Must be one of 'configure', 'unconfigure', and 'restart'
    :param port: TCP/UDP port number to supply as fluentd in_syslog plugin listen port
                 ('' when the operation doesn't need one)
    :rtype: int, str
    :return: 2-tuple of the process exit code and the resulting output string
             (basically run_command's return values)
    """
    cmd = omsagent_config_syslog_sh_cmd_template.format(op=op, port=port)
    return run_command(cmd)
Should be overwritten to omsagent.d/syslog.conf :param rsyslog_cfg: rsyslog config that's generated by LAD syslog configurator, that should be appended to /etc/rsyslog.d/95-omsagent.conf or /etc/rsyslog.conf :param syslog_ng_cfg: syslog-ng config that's generated by LAD syslog configurator, that should be appended to /etc/syslog-ng/syslog-ng.conf :rtype: int, str :return: 2-tuple of the process exit code and the resulting output string (run_command's return values) """ if not is_rsyslog_installed() and not is_syslog_ng_installed(): return 0, 'configure_syslog(): Nothing to do: Neither rsyslog nor syslog-ng is installed on the system' # 1. Unconfigure existing syslog instance (if any) to avoid duplicates # Continue even if this step fails (not critical) cmd_exit_code, cmd_output = unconfigure_syslog(run_command) extra_msg = '' if cmd_exit_code != 0: extra_msg = 'configure_syslog(): configure_syslog.sh unconfigure failed (still proceeding): ' + cmd_output # 2. Configure new syslog instance with port number. # Ordering is very tricky. This must be done before modifying /etc/syslog-ng/syslog-ng.conf # or /etc/rsyslog.d/95-omsagent.conf below! cmd_exit_code, cmd_output = run_omsagent_config_syslog_sh(run_command, 'configure', port) if cmd_exit_code != 0: return 2, 'configure_syslog(): configure_syslog.sh configure failed: ' + cmd_output # 2.5. Replace '%SYSLOG_PORT%' in all passed syslog configs with the obtained port number in_syslog_cfg = in_syslog_cfg.replace(syslog_port_pattern_marker, str(port)) rsyslog_cfg = rsyslog_cfg.replace(syslog_port_pattern_marker, str(port)) syslog_ng_cfg = syslog_ng_cfg.replace(syslog_port_pattern_marker, str(port)) # 3. Configure fluentd in_syslog plugin (write the fluentd plugin config file) try: with open(fluentd_syslog_src_cfg_path, 'w') as f: f.write(in_syslog_cfg) except Exception as e: return 3, 'configure_syslog(): Writing to omsagent.d/syslog.conf failed: {0}'.format(e) # 4. 
fluentd_tail_src_cfg_path = '/etc/opt/microsoft/omsagent/LAD/conf/omsagent.d/tail.conf'


def configure_filelog(in_tail_cfg):
    """
    Configure fluentd's in_tail plugin for LAD file logging by writing the passed
    config to omsagent.d/tail.conf.
    :param in_tail_cfg: Fluentd's in_tail plugin cfg for LAD filelog setting (obtained from LadConfigAll obj)
    :rtype: int, str
    :return: A 2-tuple of process exit code (0 for success) and output message
    """
    try:
        with open(fluentd_tail_src_cfg_path, 'w') as f:
            f.write(in_tail_cfg)
    except Exception as e:
        # Include the exception detail (it was captured but silently dropped before).
        return 1, 'configure_filelog(): Failed writing fluentd in_tail config file: {0}'.format(e)
    return 0, 'configure_filelog(): Succeeded writing fluentd in_tail config file'
:param out_mdsd_cfg: Fluentd's out_mdsd plugin cfg for the entire LAD setting (obtained from LadConfigAll obj) :rtype: str, int :return: A 2-tuple of process exit code and output """ # Just needs to write to the omsagent.d/tail.conf file try: with open(fluentd_out_mdsd_cfg_path, 'w') as f: f.write(out_mdsd_cfg) except Exception as e: return 1, 'configure_out_mdsd(): Failed writing fluentd out_mdsd config file' return 0, 'configure_out_mdsd(): Succeeded writing fluentd out_mdsd config file' def unconfigure_syslog(run_command): """ Unconfigure rsyslog/syslog-ng and fluentd's in_syslog for LAD. rsyslog/syslog-ng unconfig is done by omsagent's configure_syslog.sh. :param run_command: External command execution function (e.g., RunGetOutput) :rtype: int, str :return: 2-tuple of the process exit code and the resulting output string (run_command's return values) """ # 1. Find the port number in fluentd's in_syslog conf.. if not os.path.isfile(fluentd_syslog_src_cfg_path): return 0, "unconfigure_syslog(): Nothing to unconfigure: omsagent fluentd's in_syslog is not configured" # 2. Read fluentd's in_syslog config try: with open(fluentd_syslog_src_cfg_path) as f: fluentd_syslog_src_cfg = f.read() except Exception as e: return 1, "unconfigure_syslog(): Failed reading fluentd's in_syslog config: {0}".format(e) # 3. Extract the port number and run omsagent's configure_syslog.sh to unconfigure port_match = re.search(r'port\s+(\d+)', fluentd_syslog_src_cfg) if not port_match: return 2, 'unconfigure_syslog(): Invalid fluentd in_syslog config: port number setting not found' port = int(port_match.group(1)) cmd_exit_code, cmd_output = run_omsagent_config_syslog_sh(run_command, 'unconfigure', port) if cmd_exit_code != 0: return 3, 'unconfigure_syslog(): configure_syslog.sh failed: ' + cmd_output # 4. 
Remove fluentd's in_syslog conf file try: os.remove(fluentd_syslog_src_cfg_path) except Exception as e: return 4, 'unconfigure_syslog(): Removing omsagent.d/syslog.conf failed: {0}'.format(e) #5. All succeeded return 0, 'unconfigure_syslog(): Succeeded' def restart_syslog(run_command): """ Restart rsyslog/syslog-ng (so that any new config will be applied) :param run_command: External command execution function (e.g., RunGetOutput) :rtype: int, str :return: 2-tuple of the process exit code and the resulting output string (run_command's return values) """ return run_omsagent_config_syslog_sh(run_command, 'restart') # port param is dummy here. def restart_omiserver(run_command): """ Restart omiserver as needed (it crashes sometimes, and doesn't restart automatically yet) :param run_command: External command execution function (e.g., RunGetOutput) :rtype: int, str :return: 2-tuple of the process exit code and the resulting output string (run_command's return values) """ return run_command('/opt/omi/bin/service_control restart') def setup_omsagent(configurator, run_command, logger_log, logger_error): """ Set up omsagent. Install necessary components, configure them as needed, and start the agent. :param configurator: A LadConfigAll object that's obtained from a valid LAD JSON settings config. This is needed to retrieve the syslog (rsyslog/syslog-ng) and the fluentd configs. :param run_command: External command executor (e.g., RunGetOutput) :param logger_log: Logger for normal logging messages (e.g., hutil.log) :param logger_error: Logger for error loggin messages (e.g., hutil.error) :return: Pair of status code and message. 0 status code for success. Non-zero status code for a failure and the associated failure message. """ # Remember whether OMI (not omsagent) needs to be freshly installed. # This is needed later to determine whether to reconfigure the omiserver.conf or not for security purpose. 
need_fresh_install_omi = not os.path.exists('/opt/omi/bin/omiserver') logger_log("Begin omsagent setup.") # 1. Install omsagent, onboard to LAD workspace # We now try to install/setup all the time. If it's already installed. Any additional install is a no-op. is_omsagent_setup_correctly = False maxTries = 5 # Try up to 5 times to install omsagent for trialNum in range(1, maxTries + 1): cmd_exit_code, cmd_output = setup_omsagent_for_lad(run_command) if cmd_exit_code == 0: # Successfully set up is_omsagent_setup_correctly = True break logger_error("omsagent setup failed (trial #" + str(trialNum) + ").") if trialNum < maxTries: logger_error("Retrying in 30 seconds...") time.sleep(30) if not is_omsagent_setup_correctly: logger_error("omsagent setup failed " + str(maxTries) + " times. Giving up...") return 1, "omsagent setup failed {0} times. " \ "Last exit code={1}, Output={2}".format(maxTries, cmd_exit_code, cmd_output) # Issue #265. OMI httpsport shouldn't be reconfigured when LAD is re-enabled or just upgraded. # In other words, OMI httpsport config should be updated only on a fresh OMI install. if need_fresh_install_omi: # Check if OMI is configured to listen to any non-zero port and reconfigure if so. omi_listens_to_nonzero_port = run_command(r"grep '^\s*httpsport\s*=' /etc/opt/omi/conf/omiserver.conf " r"| grep -v '^\s*httpsport\s*=\s*0\s*$'")[0] is 0 if omi_listens_to_nonzero_port: run_command("/opt/omi/bin/omiconfigeditor httpsport -s 0 < /etc/opt/omi/conf/omiserver.conf " "> /etc/opt/omi/conf/omiserver.conf_temp") run_command("mv /etc/opt/omi/conf/omiserver.conf_temp /etc/opt/omi/conf/omiserver.conf") # 2. Configure all fluentd plugins (in_syslog, in_tail, out_mdsd) # 2.1. First get a free TCP/UDP port for fluentd in_syslog plugin. port = get_fluentd_syslog_src_port() if port < 0: return 3, 'setup_omsagent(): Failed at getting a free TCP/UDP port for fluentd in_syslog' # 2.2. 
Configure syslog cmd_exit_code, cmd_output = configure_syslog(run_command, port, configurator.get_fluentd_syslog_src_config(), configurator.get_rsyslog_config(), configurator.get_syslog_ng_config()) if cmd_exit_code != 0: return 4, 'setup_omsagent(): Failed at configuring in_syslog. Exit code={0}, Output={1}'.format(cmd_exit_code, cmd_output) # 2.3. Configure filelog cmd_exit_code, cmd_output = configure_filelog(configurator.get_fluentd_tail_src_config()) if cmd_exit_code != 0: return 5, 'setup_omsagent(): Failed at configuring in_tail. Exit code={0}, Output={1}'.format(cmd_exit_code, cmd_output) # 2.4. Configure out_mdsd cmd_exit_code, cmd_output = configure_out_mdsd(configurator.get_fluentd_out_mdsd_config()) if cmd_exit_code != 0: return 6, 'setup_omsagent(): Failed at configuring out_mdsd. Exit code={0}, Output={1}'.format(cmd_exit_code, cmd_output) # 3. Restart omsagent cmd_exit_code, cmd_output = control_omsagent('restart', run_command) if cmd_exit_code != 0: return 8, 'setup_omsagent(): Failed at restarting omsagent (fluentd). ' \ 'Exit code={0}, Output={1}'.format(cmd_exit_code, cmd_output) # All done... 
return 0, "setup_omsagent(): Succeeded" ================================================ FILE: Diagnostic/__init__.py ================================================ ================================================ FILE: Diagnostic/decrypt_protected_settings.sh ================================================ #!/bin/bash # A shell script utility to decrypt the extension's protected settings for debugging purpose # Must be run at /var/lib/waagent/Microsoft.Azure.Diagnostics.LinuxDiagnostic-.../ # with the settings file path (e.g., config/0.settings) as the only cmdline arg if [ $# -lt 1 ]; then echo "Usage: $0 " exit 1 fi thumbprint=$(jq -r '.runtimeSettings[].handlerSettings.protectedSettingsCertThumbprint' $1) jq -r '.runtimeSettings[].handlerSettings.protectedSettings' $1 | base64 --decode | openssl smime -inform DER -decrypt -recip ../$thumbprint.crt -inkey ../$thumbprint.prv | jq . ================================================ FILE: Diagnostic/diagnostic.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Linux Azure Diagnostic Extension (Current version is specified in manifest.xml) # Copyright (c) Microsoft Corporation All rights reserved. # MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of # this software and associated documentation files (the ""Software""), to deal in the Software without restriction, # including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following # conditions: The above copyright notice and this permission notice shall be included in all copies or substantial # portions of the Software. 
THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. import datetime import exceptions import os.path import platform import signal import subprocess import sys import syslog import threading import time import traceback import xml.etree.ElementTree as ET import json # Just wanted to be able to run 'python diagnostic.py ...' from a local dev box where there's no waagent. # Actually waagent import can succeed even on a Linux machine without waagent installed, # by setting PYTHONPATH env var to the azure-linux-extensions/Common/WALinuxAgent-2.0.16, # but let's just keep this try-except here on them for any potential local imports that may throw. try: # waagent, ext handler from Utils.WAAgentUtil import waagent import Utils.HandlerUtil as Util # Old LAD utils import Utils.LadDiagnosticUtil as LadUtil import Utils.XmlUtil as XmlUtil # New LAD utils import DistroSpecific import watcherutil from Utils.lad_ext_settings import LadExtSettings from Utils.misc_helpers import * import lad_config_all as lad_cfg from Utils.imds_util import ImdsLogger import Utils.omsagent_util as oms import telegraf_utils.telegraf_config_handler as telhandler import metrics_ext_utils.metrics_ext_handler as me_handler import metrics_ext_utils.metrics_constants as metrics_constants except Exception as e: print('A local import (e.g., waagent) failed. Exception: {0}\nStacktrace: {1}'.format(e, traceback.format_exc())) print("Can't proceed. Exiting with a special exit code 119.") sys.exit(119) # This is the only thing we can do, as all logging depends on waagent/hutil. 
# Globals declaration/initialization (with const values only) for IDE
g_ext_settings = None             # LAD extension settings object
g_lad_log_helper = None           # LAD logging helper object
g_dist_config = None              # Distro config object
g_ext_dir = ''                    # Extension directory (e.g., /var/lib/waagent/Microsoft.OSTCExtensions.LinuxDiagnostic-x.y.zzzz)
g_mdsd_file_resources_dir = '/var/run/mdsd'
g_mdsd_role_name = 'lad_mdsd'     # Different mdsd role name for multiple mdsd process instances
g_mdsd_file_resources_prefix = ''  # Eventually '/var/run/mdsd/lad_mdsd'
g_lad_pids_filepath = ''          # LAD process IDs (diagnostic.py, mdsd) file path. g_ext_dir + '/lad.pids'
g_ext_op_type = None              # Extension operation type (e.g., Install, Enable, HeartBeat, ...)
g_mdsd_bin_path = '/usr/local/lad/bin/mdsd'  # mdsd binary path. Fixed w/ lad-mdsd-*.{deb,rpm} pkgs
g_diagnostic_py_filepath = ''     # Full path of this script. g_ext_dir + '/diagnostic.py'
# Only 2 globals not following 'g_...' naming convention, for legacy readability...
RunGetOutput = None               # External command executor callable
hutil = None                      # Handler util object
enable_metrics_ext = False        # Flag to enable/disable MetricsExtension
enable_telegraf = False           # Flag to enable/disable Telegraf
me_msi_token_expiry_epoch = None  # Cached 'expires_on' epoch of the ME MSI auth token


def init_distro_specific_actions():
    """
    Identify the specific Linux distribution in use. Set the global distConfig
    to point to the corresponding implementation class. If the distribution
    isn't supported, set the extension status appropriately and exit.
    Expects the global hutil to already be initialized.
    """
    # TODO Exit immediately if distro is unknown
    global g_dist_config, RunGetOutput
    # NOTE(review): platform.dist() is Python-2-only (removed in Python 3.8); this module
    # also imports the Python-2 'exceptions' module below, so this file targets Python 2.
    dist = platform.dist()
    name = ''
    version = ''
    try:
        if dist[0] != '':
            name = dist[0]
            version = dist[1]
        else:
            try:
                # platform.dist() in python 2.7.15 does not recognize SLES/OpenSUSE 15.
                # Fall back to parsing /etc/os-release for the distro name/major version.
                with open("/etc/os-release", "r") as fp:
                    for line in fp:
                        if line.startswith("ID="):
                            name = line.split("=")[1]
                            name = name.split("-")[0]
                            name = name.replace("\"", "").replace("\n", "")
                        elif line.startswith("VERSION_ID="):
                            version = line.split("=")[1]
                            version = version.split(".")[0]
                            version = version.replace("\"", "").replace("\n", "")
            except:
                raise
        hutil.log("os version: {0}:{1}".format(name.lower(), version))
        g_dist_config = DistroSpecific.get_distro_actions(name.lower(), version, hutil.log)
        RunGetOutput = g_dist_config.log_run_get_output
    except exceptions.LookupError as ex:
        # Unsupported distro: leave g_dist_config as None; main() checks for this and bails out.
        hutil.error("os version: {0}:{1} not supported".format(dist[0], dist[1]))
        # TODO Exit immediately if distro is unknown. This is currently done in main().
        g_dist_config = None


def init_extension_settings():
    """Initialize extension's public & private settings. hutil must be already initialized prior to calling this."""
    global g_ext_settings

    # Need to read/parse the Json extension settings (context) first.
    hutil.try_parse_context()
    hutil.set_verbose_log(False)  # This is default, but this choice will be made explicit and logged.
    g_ext_settings = LadExtSettings(hutil.get_handler_settings())


def init_globals():
    """Initialize all the globals in a function so that we can catch any exceptions that might be raised."""
    global hutil, g_ext_dir, g_mdsd_file_resources_prefix, g_lad_pids_filepath
    global g_diagnostic_py_filepath, g_lad_log_helper

    waagent.LoggerInit('/var/log/waagent.log', '/dev/stdout')
    waagent.Log("LinuxDiagnostic started to handle.")
    hutil = Util.HandlerUtility(waagent.Log, waagent.Error)
    init_extension_settings()
    init_distro_specific_actions()
    g_ext_dir = os.getcwd()
    g_mdsd_file_resources_prefix = os.path.join(g_mdsd_file_resources_dir, g_mdsd_role_name)
    g_lad_pids_filepath = os.path.join(g_ext_dir, 'lad.pids')
    g_diagnostic_py_filepath = os.path.join(os.getcwd(), __file__)
    g_lad_log_helper = LadLogHelper(hutil.log, hutil.error, waagent.AddExtensionEvent, hutil.do_status_report,
                                    hutil.get_name(), hutil.get_extension_version())


def setup_dependencies_and_mdsd(configurator):
    """
    Set up dependencies for mdsd, such as following:
    1) Distro-specific packages (see DistroSpecific.py)
    2) Set up omsagent (fluentd), syslog (rsyslog or syslog-ng) for mdsd
    :return: Status code and message
    """
    # Retry the distro package install up to 3 times, sleeping 60s between attempts.
    install_package_error = ""
    retry = 3
    while retry > 0:
        error, msg = g_dist_config.install_required_packages()
        hutil.log(msg)
        if error == 0:
            break
        else:
            retry -= 1
            hutil.log("Sleep 60 retry " + str(retry))
            install_package_error = msg
            time.sleep(60)
    if install_package_error:
        # Truncate very long error output to first/last 512 chars so status reporting stays bounded.
        if len(install_package_error) > 1024:
            install_package_error = install_package_error[0:512] + install_package_error[-512:-1]
        hutil.error(install_package_error)
        return 2, install_package_error

    # Run mdsd prep commands
    g_dist_config.prepare_for_mdsd_install()

    # Set up omsagent
    omsagent_setup_exit_code, omsagent_setup_output = oms.setup_omsagent(configurator, RunGetOutput,
                                                                         hutil.log, hutil.error)
    # NOTE(review): 'is not 0' is an identity test; should be '!= 0'.
    if omsagent_setup_exit_code is not 0:
        return 3, omsagent_setup_output

    # Install lad-mdsd pkg (/usr/local/lad/bin/mdsd).
    # Must be done after omsagent install because of dependencies
    cmd_exit_code, cmd_output = g_dist_config.install_lad_mdsd()
    if cmd_exit_code != 0:
        return 4, 'lad-mdsd pkg install failed. Exit code={0}, Output={1}'.format(cmd_exit_code, cmd_output)

    return 0, 'success'


def install_lad_as_systemd_service():
    """
    Install LAD as a systemd service on systemd-enabled distros/versions (e.g., Ubuntu 16.04)
    :return: None
    """
    # Substitute {WORKDIR} in the unit template with the extension dir, then reload systemd.
    RunGetOutput('sed s#{WORKDIR}#' + g_ext_dir + '# ' + g_ext_dir +
                 '/services/mdsd-lde.service > /lib/systemd/system/mdsd-lde.service')
    RunGetOutput('systemctl daemon-reload')


def create_core_components_configs():
    """
    Entry point to creating all configs of LAD's core components (mdsd, omsagent, rsyslog/syslog-ng, ...).
    This function shouldn't be called on Install/Enable. Only Daemon op needs to call this.
    :rtype: LadConfigAll
    :return: A valid LadConfigAll object if config is valid. None otherwise.
    """
    deployment_id = get_deployment_id_from_hosting_env_cfg(waagent.LibDir, hutil.log, hutil.error)

    # Define wrappers around a couple misc_helpers. These can easily be mocked out in tests. PEP-8 says use
    # def, don't assign a lambda to a variable. *shrug*
    def encrypt_string(cert, secret):
        return encrypt_secret_with_cert(RunGetOutput, hutil.error, cert, secret)

    configurator = lad_cfg.LadConfigAll(g_ext_settings, g_ext_dir, waagent.LibDir, deployment_id,
                                        read_uuid, encrypt_string, hutil.log, hutil.error)
    try:
        config_valid, config_invalid_reason = configurator.generate_all_configs()
    except Exception as e:
        config_invalid_reason =\
            'Exception while generating configs: {0}. Traceback: {1}'.format(e, traceback.format_exc())
        hutil.error(config_invalid_reason)
        config_valid = False

    if not config_valid:
        g_lad_log_helper.log_and_report_failed_config_generation(
            g_ext_op_type, config_invalid_reason, g_ext_settings.redacted_handler_settings())
        return None

    global enable_metrics_ext
    global enable_telegraf
    ladconfig = configurator._ladCfg()

    # verify metrics extension should be enabled
    sink = configurator._sink_configs_public.get_sink_by_name("AzMonSink")
    if sink is not None:
        if sink['name'] == 'AzMonSink':
            enable_metrics_ext = True

    # verify telegraf should be enabled (either metrics intervals or performance counters configured)
    metrics_intervals = LadUtil.getAggregationPeriodsFromLadCfg(ladconfig)
    perf_counter_config = LadUtil.getDiagnosticsMonitorConfigurationElement(ladconfig, 'performanceCounters')
    if ((metrics_intervals != []) or (perf_counter_config)):
        enable_telegraf = True

    return configurator


def check_for_supported_waagent_and_distro_version():
    """
    Checks & returns if the installed waagent and the Linux distro/version are supported by this LAD.
    :rtype: bool
    :return: True iff so.
    """
    # Reject known-too-old waagent versions by grepping the installed agent's version string.
    for notsupport in ('WALinuxAgent-2.0.5', 'WALinuxAgent-2.0.4', 'WALinuxAgent-1'):
        code, str_ret = RunGetOutput("grep 'GuestAgentVersion.*" + notsupport + "' /usr/sbin/waagent",
                                     should_log=False)
        if code == 0 and str_ret.find(notsupport) > -1:
            hutil.log("cannot run this extension on " + notsupport)
            hutil.do_status_report(g_ext_op_type, "error", '1', "cannot run this extension on " + notsupport)
            return False

    # g_dist_config is None when init_distro_specific_actions() couldn't match the distro.
    if g_dist_config is None:
        msg = ("LAD does not support distro/version ({0}); not installed. This extension install/enable operation is "
               "still considered a success as it's an external error.").format(str(platform.dist()))
        hutil.log(msg)
        hutil.do_status_report(g_ext_op_type, "success", '0', msg)
        waagent.AddExtensionEvent(name=hutil.get_name(),
                                  op=g_ext_op_type,
                                  isSuccess=True,
                                  version=hutil.get_extension_version(),
                                  message="Can't be installed on this OS " + str(platform.dist()))
        return False
    return True


def main(command):
    """Extension entry point: dispatch on the handler operation (Install/Enable/Disable/...)."""
    init_globals()

    global g_ext_op_type
    global me_msi_token_expiry_epoch

    g_ext_op_type = get_extension_operation_type(command)
    waagent_ext_event_type = wala_event_type_for_telemetry(g_ext_op_type)

    if not check_for_supported_waagent_and_distro_version():
        return

    try:
        hutil.log("Dispatching command:" + command)

        # NOTE(review): the op-type comparisons below use 'is' (identity) instead of '=='.
        # They work only because the operands are the same interned string objects; '==' would be safer.
        if g_ext_op_type is waagent.WALAEventOperation.Disable:
            if g_dist_config.use_systemd():
                RunGetOutput('systemctl stop mdsd-lde && systemctl disable mdsd-lde')
            else:
                stop_mdsd()
            oms.tear_down_omsagent_for_lad(RunGetOutput, False)
            # Stop the telegraf and ME services
            tel_out, tel_msg = telhandler.stop_telegraf_service(is_lad=True)
            if tel_out:
                hutil.log(tel_msg)
            else:
                hutil.error(tel_msg)
            me_out, me_msg = me_handler.stop_metrics_service(is_lad=True)
            if me_out:
                hutil.log(me_msg)
            else:
                hutil.error(me_msg)
            hutil.do_status_report(g_ext_op_type, "success", '0', "Disable succeeded")

        elif g_ext_op_type is waagent.WALAEventOperation.Uninstall:
            if g_dist_config.use_systemd():
                RunGetOutput('systemctl stop mdsd-lde && systemctl disable mdsd-lde ' +
                             '&& rm /lib/systemd/system/mdsd-lde.service')
            else:
                stop_mdsd()
            # Must remove lad-mdsd package first because of the dependencies
            cmd_exit_code, cmd_output = g_dist_config.remove_lad_mdsd()
            if cmd_exit_code != 0:
                hutil.error('lad-mdsd remove failed. Still proceeding to uninstall. '
                            'Exit code={0}, Output={1}'.format(cmd_exit_code, cmd_output))
            oms.tear_down_omsagent_for_lad(RunGetOutput, True)
            # Delete the telegraf and ME services
            tel_rm_out, tel_rm_msg = telhandler.remove_telegraf_service(is_lad=True)
            if tel_rm_out:
                hutil.log(tel_rm_msg)
            else:
                hutil.error(tel_rm_msg)
            me_rm_out, me_rm_msg = me_handler.remove_metrics_service(is_lad=True)
            if me_rm_out:
                hutil.log(me_rm_msg)
            else:
                hutil.error(me_rm_msg)
            hutil.do_status_report(g_ext_op_type, "success", '0', "Uninstall succeeded")

        elif g_ext_op_type is waagent.WALAEventOperation.Install:
            # Install dependencies (omsagent, which includes omi, scx).
            configurator = create_core_components_configs()
            dependencies_err, dependencies_msg = setup_dependencies_and_mdsd(configurator)
            if dependencies_err != 0:
                g_lad_log_helper.report_mdsd_dependency_setup_failure(waagent_ext_event_type, dependencies_msg)
                hutil.do_status_report(g_ext_op_type, "error", '-1', "Install failed")
                return
            if g_dist_config.use_systemd():
                install_lad_as_systemd_service()
            hutil.do_status_report(g_ext_op_type, "success", '0', "Install succeeded")

        elif g_ext_op_type is waagent.WALAEventOperation.Enable:
            # Re-run dependency setup only when a newer config sequence number arrived.
            if hutil.is_current_config_seq_greater_inused():
                configurator = create_core_components_configs()
                dependencies_err, dependencies_msg = setup_dependencies_and_mdsd(configurator)
                if dependencies_err != 0:
                    g_lad_log_helper.report_mdsd_dependency_setup_failure(waagent_ext_event_type, dependencies_msg)
                    hutil.do_status_report(g_ext_op_type, "error", '-1', "Enabled failed")
                    return
            # Start the Telegraf and ME services on enable after installation is complete
            start_telegraf_res, log_messages = telhandler.start_telegraf(is_lad=True)
            if start_telegraf_res:
                hutil.log("Successfully started metrics-sourcer.")
            else:
                hutil.error(log_messages)

            if enable_metrics_ext:
                # Generate/regenerate MSI Token required by ME
                generate_token = False
                me_token_path = g_ext_dir + "/metrics_configs/AuthToken-MSI.json"
                if me_msi_token_expiry_epoch is None or me_msi_token_expiry_epoch == "":
                    if os.path.isfile(me_token_path):
                        with open(me_token_path, "r") as f:
                            authtoken_content = f.read()
                        # NOTE(review): unlike the daemon loop in start_mdsd() (which uses
                        # json.loads), authtoken_content here is a raw str, so the 'in' test is a
                        # substring check and the subscript below would raise TypeError. Likely a bug;
                        # should parse with json.loads first — confirm against the token file format.
                        if authtoken_content and "expires_on" in authtoken_content:
                            me_msi_token_expiry_epoch = authtoken_content["expires_on"]
                        else:
                            generate_token = True
                    else:
                        generate_token = True
                if me_msi_token_expiry_epoch:
                    currentTime = datetime.datetime.now()
                    # NOTE(review): no float() conversion here, unlike start_mdsd(); fromtimestamp()
                    # would fail if me_msi_token_expiry_epoch is a string — verify.
                    token_expiry_time = datetime.datetime.fromtimestamp(me_msi_token_expiry_epoch)
                    if token_expiry_time - currentTime < datetime.timedelta(minutes=30):
                        # The MSI Token will expire within 30 minutes. We need to refresh the token
                        generate_token = True
                if generate_token:
                    generate_token = False
                    msi_token_generated, me_msi_token_expiry_epoch, log_messages = me_handler.generate_MSI_token()
                    if msi_token_generated:
                        hutil.log("Successfully refreshed metrics-extension MSI Auth token.")
                    else:
                        hutil.error(log_messages)
                start_metrics_out, log_messages = me_handler.start_metrics(is_lad=True)
                if start_metrics_out:
                    hutil.log("Successfully started metrics-extension.")
                else:
                    hutil.error(log_messages)

            if g_dist_config.use_systemd():
                install_lad_as_systemd_service()
                RunGetOutput('systemctl enable mdsd-lde')
                # NOTE(review): 'is 0' identity comparison again; should be '== 0'.
                mdsd_lde_active = RunGetOutput('systemctl status mdsd-lde')[0] is 0
                if not mdsd_lde_active or hutil.is_current_config_seq_greater_inused():
                    RunGetOutput('systemctl restart mdsd-lde')
            else:
                # if daemon process not runs
                lad_pids = get_lad_pids()
                hutil.log("get pids:" + str(lad_pids))
                if len(lad_pids) != 2 or hutil.is_current_config_seq_greater_inused():
                    stop_mdsd()
                    start_daemon()
            hutil.set_inused_config_seq(hutil.get_seq_no())
            hutil.do_status_report(g_ext_op_type, "success", '0', "Enable succeeded, extension daemon started")
            # If the -daemon detects a problem, e.g. bad configuration, it will overwrite this status with a more
            # informative one. If it succeeds, all is well.
elif g_ext_op_type is "Daemon": configurator = create_core_components_configs() if configurator: start_mdsd(configurator) elif g_ext_op_type is waagent.WALAEventOperation.Update: hutil.do_status_report(g_ext_op_type, "success", '0', "Update succeeded") except Exception as e: hutil.error("Failed to perform extension operation {0} with error:{1}, {2}".format(g_ext_op_type, e, traceback.format_exc())) hutil.do_status_report(g_ext_op_type, 'error', '0', 'Extension operation {0} failed:{1}'.format(g_ext_op_type, e)) def start_daemon(): """ Start diagnostic.py as a daemon for scheduled tasks and to monitor mdsd daemon. If Popen() has a problem it will raise an exception (often OSError) :return: None """ args = ['python2', g_diagnostic_py_filepath, "-daemon"] log = open(os.path.join(os.getcwd(), 'daemon.log'), 'w') hutil.log('start daemon ' + str(args)) subprocess.Popen(args, stdout=log, stderr=log) def start_watcher_thread(): """ Start watcher thread that performs periodic monitoring activities (other than mdsd) :return: None """ # Create monitor object that encapsulates monitoring activities watcher = watcherutil.Watcher(hutil.error, hutil.log, log_to_console=True) # Create an IMDS data logger and set it to the monitor object imds_logger = ImdsLogger(hutil.get_name(), hutil.get_extension_version(), waagent.WALAEventOperation.HeartBeat, waagent.AddExtensionEvent) watcher.set_imds_logger(imds_logger) # Start a thread to perform periodic monitoring activity (e.g., /etc/fstab watcher, IMDS data logging) thread_obj = threading.Thread(target=watcher.watch) thread_obj.daemon = True thread_obj.start() def start_mdsd(configurator): """ Start mdsd and monitor its activities. Report if it crashes or emits error logs. :param configurator: A valid LadConfigAll object that was obtained by create_core_components_config(). 
This will be used for configuring rsyslog/syslog-ng/fluentd/in_syslog/out_mdsd components :return: None """ # This must be done first, so that extension enable completion doesn't get delayed. write_lad_pids_to_file(g_lad_pids_filepath, os.getpid()) # Need 'HeartBeat' instead of 'Daemon' waagent_ext_event_type = wala_event_type_for_telemetry(g_ext_op_type) # mdsd http proxy setting proxy_config = get_mdsd_proxy_config(waagent.HttpProxyConfigString, g_ext_settings, hutil.log) if proxy_config: # Add MDSD_http_proxy to current environment. Child processes will inherit its value. os.environ['MDSD_http_proxy'] = proxy_config copy_env = os.environ.copy() # Add MDSD_CONFIG_DIR as an env variable since new mdsd master branch LAD doesnt create this dir mdsd_config_cache_dir = os.path.join(g_ext_dir, "config") copy_env["MDSD_CONFIG_DIR"] = mdsd_config_cache_dir # We then validate the mdsd config and proceed only when it succeeds. xml_file = os.path.join(g_ext_dir, 'xmlCfg.xml') tmp_env_dict = {} # Need to get the additionally needed env vars (SSL_CERT_*) for this mdsd run as well... g_dist_config.extend_environment(tmp_env_dict) added_env_str = ' '.join('{0}={1}'.format(k, tmp_env_dict[k]) for k in tmp_env_dict) config_validate_cmd = '{0}{1}{2} -v -c {3} -r {4}'.format(added_env_str, ' ' if added_env_str else '', g_mdsd_bin_path, xml_file, g_ext_dir) config_validate_cmd_status, config_validate_cmd_msg = RunGetOutput(config_validate_cmd) if config_validate_cmd_status is not 0: # Invalid config. Log error and report success. g_lad_log_helper.log_and_report_invalid_mdsd_cfg(g_ext_op_type, config_validate_cmd_msg, read_file_to_string(xml_file)) return # Start OMI if it's not running. # This shouldn't happen, but this measure is put in place just in case (e.g., Ubuntu 16.04 systemd). # Don't check if starting succeeded, as it'll be done in the loop below anyway. 
omi_running = RunGetOutput("/opt/omi/bin/service_control is-running", should_log=False)[0] is 1 if not omi_running: hutil.log("OMI is not running. Restarting it.") RunGetOutput("/opt/omi/bin/service_control restart") log_dir = hutil.get_log_dir() err_file_path = os.path.join(log_dir, 'mdsd.err') info_file_path = os.path.join(log_dir, 'mdsd.info') warn_file_path = os.path.join(log_dir, 'mdsd.warn') qos_file_path = os.path.join(log_dir, 'mdsd.qos') # Need to provide EH events and Rsyslog spool path since the new mdsd master branch LAD doesnt create the directory needed eh_spool_path = os.path.join(log_dir, 'eh') update_selinux_settings_for_rsyslogomazuremds(RunGetOutput, g_ext_dir) mdsd_stdout_redirect_path = os.path.join(g_ext_dir, "mdsd.log") mdsd_stdout_stream = None g_dist_config.extend_environment(copy_env) # Now prepare actual mdsd cmdline. command = '{0} -A -C -c {1} -R -r {2} -e {3} -w {4} -q {8} -S {7} -o {5}{6}'.format( g_mdsd_bin_path, xml_file, g_mdsd_role_name, err_file_path, warn_file_path, info_file_path, g_ext_settings.get_mdsd_trace_option(), eh_spool_path, qos_file_path).split(" ") try: start_watcher_thread() num_quick_consecutive_crashes = 0 mdsd_crash_msg = '' while num_quick_consecutive_crashes < 3: # We consider only quick & consecutive crashes for retries RunGetOutput('rm -f ' + g_mdsd_file_resources_prefix + '.pidport') # Must delete any existing port num file mdsd_stdout_stream = open(mdsd_stdout_redirect_path, "w") hutil.log("Start mdsd " + str(command)) mdsd = subprocess.Popen(command, cwd=g_ext_dir, stdout=mdsd_stdout_stream, stderr=mdsd_stdout_stream, env=copy_env) write_lad_pids_to_file(g_lad_pids_filepath, os.getpid(), mdsd.pid) last_mdsd_start_time = datetime.datetime.now() last_error_time = last_mdsd_start_time omi_installed = True # Remembers if OMI is installed at each iteration telegraf_restart_retries = 0 me_restart_retries = 0 max_restart_retries = 10 # Continuously monitors mdsd process while True: time.sleep(30) if " 
".join(get_lad_pids()).find(str(mdsd.pid)) < 0 and len(get_lad_pids()) >= 2: mdsd.kill() hutil.log("Another process is started, now exit") return if mdsd.poll() is not None: # if mdsd has terminated time.sleep(60) mdsd_stdout_stream.flush() break # mdsd is now up for at least 30 seconds. Do some monitoring activities. # 1. Mitigate if memory leak is suspected. mdsd_memory_leak_suspected, mdsd_memory_usage_in_KB = check_suspected_memory_leak(mdsd.pid, hutil.error) if mdsd_memory_leak_suspected: g_lad_log_helper.log_suspected_memory_leak_and_kill_mdsd(mdsd_memory_usage_in_KB, mdsd, waagent_ext_event_type) break # 2. Restart OMI if it crashed (Issue #128) omi_installed = restart_omi_if_crashed(omi_installed, mdsd) # 3. Check if there's any new logs in mdsd.err and report last_error_time = report_new_mdsd_errors(err_file_path, last_error_time) # 4. Check if telegraf is running, if not, then restart if enable_telegraf and not telhandler.is_running(is_lad=True): if telegraf_restart_retries < max_restart_retries: telegraf_restart_retries += 1 hutil.log("Telegraf binary process is not running. Restarting telegraf now. Retry count - {0}".format(telegraf_restart_retries)) tel_out, tel_msg = telhandler.stop_telegraf_service(is_lad=True) if tel_out: hutil.log(tel_msg) else: hutil.error(tel_msg) start_telegraf_res, log_messages = telhandler.start_telegraf(is_lad=True) if start_telegraf_res: hutil.log("Successfully started metrics-sourcer.") else: hutil.error(log_messages) else: hutil.error("Telegraf binary process is not running. Failed to restart after {0} retries. Please check telegraf.log at {1}".format(max_restart_retries, log_dir)) else: telegraf_restart_retries = 0 # 5. Check if ME is running, if not, then restart if enable_metrics_ext: if not me_handler.is_running(is_lad=True): if me_restart_retries < max_restart_retries: me_restart_retries += 1 hutil.log("MetricsExtension binary process is not running. Restarting MetricsExtension now. 
Retry count - {0}".format(me_restart_retries)) me_out, me_msg = me_handler.stop_metrics_service(is_lad=True) if me_out: hutil.log(me_msg) else: hutil.error(me_msg) start_metrics_out, log_messages = me_handler.start_metrics(is_lad=True) if start_metrics_out: hutil.log("Successfully started metrics-extension.") else: hutil.error(log_messages) else: hutil.error("MetricsExtension binary process is not running. Failed to restart after {0} retries. Please check /var/log/syslog for ME logs".format(max_restart_retries)) else: me_restart_retries = 0 # 6. Regenerate the MSI auth token required for ME if it is nearing expiration # Generate/regenerate MSI Token required by ME global me_msi_token_expiry_epoch generate_token = False me_token_path = g_ext_dir + "/config/metrics_configs/AuthToken-MSI.json" if me_msi_token_expiry_epoch is None or me_msi_token_expiry_epoch == "": if os.path.isfile(me_token_path): with open(me_token_path, "r") as f: authtoken_content = json.loads(f.read()) if authtoken_content and "expires_on" in authtoken_content: me_msi_token_expiry_epoch = authtoken_content["expires_on"] else: generate_token = True else: generate_token = True if me_msi_token_expiry_epoch: currentTime = datetime.datetime.now() token_expiry_time = datetime.datetime.fromtimestamp(float(me_msi_token_expiry_epoch)) if token_expiry_time - currentTime < datetime.timedelta(minutes=30): # The MSI Token will expire within 30 minutes. We need to refresh the token generate_token = True if generate_token: generate_token = False msi_token_generated, me_msi_token_expiry_epoch, log_messages = me_handler.generate_MSI_token() if msi_token_generated: hutil.log("Successfully refreshed metrics-extension MSI Auth token.") else: hutil.error(log_messages) # Out of the inner while loop: mdsd terminated. if mdsd_stdout_stream: mdsd_stdout_stream.close() mdsd_stdout_stream = None # Check if this is NOT a quick crash -- we consider a crash quick # if it's within 30 minutes from the start time. 
def report_new_mdsd_errors(err_file_path, last_error_time):
    """
    Monitors if there's any new stuff in mdsd.err and reports it, if any, through the
    agent/ext status report mechanism.
    :param err_file_path: Path of the mdsd.err file
    :param last_error_time: datetime when the last error was reported.
    :return: datetime when the last error was reported. Same as the argument if there's no
             error reported in this call. A new time (error file ctime) if a new error is reported.
    """
    if not os.path.exists(err_file_path):
        return last_error_time

    # File ctime truncated to whole seconds as a naive local datetime.
    # BUGFIX: the previous strptime(time.ctime(...), "%a %b %d %H:%M:%S %Y") round-trip is
    # locale-sensitive (strptime's %a/%b parse per LC_TIME while ctime always emits English
    # abbreviations) and could raise ValueError on non-English locales. fromtimestamp(int(...))
    # yields the identical naive, second-truncated datetime without that hazard.
    err_file_ctime = datetime.datetime.fromtimestamp(int(os.path.getctime(err_file_path)))
    if last_error_time >= err_file_ctime:
        return last_error_time

    # No new error above. A new error below.
    last_error_time = err_file_ctime
    last_error = tail(err_file_path)
    if len(last_error) > 0 and (datetime.datetime.now() - last_error_time) < datetime.timedelta(minutes=30):
        # Only recent error logs (within 30 minutes) are reported.
        hutil.log("Error in MDSD:" + last_error)
        hutil.do_status_report(g_ext_op_type, "success", '1',
                               "message in mdsd.err:" + str(last_error_time) + ":" + last_error)
    return last_error_time
def stop_mdsd():
    """
    Stop the LAD processes (diagnostic.py and mdsd) recorded in the pid file.
    :return: (0, status message) tuple.
    """
    pids = get_lad_pids()
    if not pids:
        return 0, "Already stopped"

    # Ask the processes to terminate gracefully first (SIGTERM).
    kill_cmd = "kill " + " ".join(pids)
    hutil.log(kill_cmd)
    RunGetOutput(kill_cmd)

    # Poll up to 10 times, 2 seconds apart, waiting for the processes to go away.
    terminated = False
    for attempt in range(1, 11):
        time.sleep(2)
        if not get_lad_pids():
            hutil.log("stop_mdsd(): All processes successfully terminated")
            terminated = True
            break
        hutil.log("stop_mdsd() terminate check #{0}: Processes not terminated yet, rechecking in 2 seconds".format(
            attempt))

    if not terminated:
        # Escalate to SIGKILL and remove the stale pid file.
        kill_cmd = "kill -9 " + " ".join(get_lad_pids())
        hutil.log("stop_mdsd(): Processes not terminated in 20 seconds. Sending SIGKILL (" + kill_cmd + ")")
        RunGetOutput(kill_cmd)
        RunGetOutput("rm " + g_lad_pids_filepath)

    return 0, "Terminated" if terminated else "SIGKILL'ed"
def get_lad_pids():
    """
    Get LAD PIDs from the previously written file.
    :return: List of 2 PIDs. One for diagnostic.py, the other for mdsd
    """
    alive_pids = []
    if not os.path.exists(g_lad_pids_filepath):
        return alive_pids

    with open(g_lad_pids_filepath, "r") as pid_file:
        for line in pid_file:
            recorded_pid = line.strip()
            # A recorded PID is considered still ours only if that process's cmdline mentions
            # the waagent extension directory (guards against PID reuse by unrelated processes).
            cmdline = RunGetOutput("cat /proc/" + recorded_pid + "/cmdline", should_log=False)[1]
            if cmdline.find('/waagent/') > 0:
                alive_pids.append(recorded_pid)
            else:
                hutil.log("return not alive " + cmdline.strip())
    return alive_pids
# Issue #128 LAD should restart OMI if it crashes
def restart_omi_if_crashed(omi_installed, mdsd):
    """
    Restart OMI if it crashed. Called from the main monitoring loop.
    :param omi_installed: bool indicating whether OMI was installed at the previous iteration.
    :param mdsd: Python Process object for the mdsd process, because it might need to be signaled.
    :return: bool indicating whether OMI was installed at this iteration (from this call)
    """
    omicli_path = "/opt/omi/bin/omicli"
    omicli_noop_query_cmd = omicli_path + " noop"
    omi_was_installed = omi_installed  # Remember the OMI install status from the last iteration
    omi_installed = os.path.isfile(omicli_path)

    if omi_was_installed and not omi_installed:
        hutil.log("OMI is uninstalled. This must have been intentional and externally done. "
                  "Will no longer check if OMI is up and running.")

    omi_reinstalled = not omi_was_installed and omi_installed
    if omi_reinstalled:
        hutil.log("OMI is reinstalled. Will resume checking if OMI is up and running.")

    should_restart_omi = False
    if omi_installed:
        cmd_exit_status, cmd_output = RunGetOutput(cmd=omicli_noop_query_cmd, should_log=False)
        # BUGFIX: compare int exit codes by value (!=), not identity ('is not'). Identity comparison
        # with int literals is a CPython small-int accident and a SyntaxWarning on Python >= 3.8.
        should_restart_omi = cmd_exit_status != 0
        if should_restart_omi:
            hutil.error("OMI noop query failed. Output: " + cmd_output + ". OMI crash suspected. "
                        "Restarting OMI and sending SIGHUP to mdsd after 5 seconds.")
            omi_restart_msg = RunGetOutput("/opt/omi/bin/service_control restart")[1]
            hutil.log("OMI restart result: " + omi_restart_msg)
            time.sleep(10)

            # Query OMI once again to make sure restart fixed the issue.
            # If not, attempt to re-install OMI as last resort.
            cmd_exit_status, cmd_output = RunGetOutput(cmd=omicli_noop_query_cmd, should_log=False)
            should_reinstall_omi = cmd_exit_status != 0  # BUGFIX: was 'is not 0'
            if should_reinstall_omi:
                hutil.error("OMI noop query failed even after OMI was restarted. Attempting to re-install the components.")
                configurator = create_core_components_configs()
                dependencies_err, dependencies_msg = setup_dependencies_and_mdsd(configurator)
                if dependencies_err != 0:
                    hutil.error("Re-installing the components failed with error code: " + str(dependencies_err)
                                + ", error message: " + dependencies_msg)
                    return omi_installed
                else:
                    omi_reinstalled = True

    # mdsd needs to be signaled if OMI was restarted or reinstalled because mdsd used to give up connecting to OMI
    # if it fails first time, and never retried until signaled. mdsd was fixed to retry now, but it's still
    # limited (stops retrying beyond 30 minutes or so) and backoff-ed exponentially
    # so it's still better to signal anyway.
    should_signal_mdsd = should_restart_omi or omi_reinstalled
    if should_signal_mdsd:
        omi_up_and_running = RunGetOutput(omicli_noop_query_cmd)[0] == 0  # BUGFIX: was 'is 0'
        if omi_up_and_running:
            mdsd.send_signal(signal.SIGHUP)
            hutil.log("SIGHUP sent to mdsd")
        else:  # OMI restarted but not staying up...
            log_msg = "OMI restarted but not staying up. Will be restarted in the next iteration."
            hutil.error(log_msg)
            # Also log this issue on syslog as well
            syslog.openlog('diagnostic.py', syslog.LOG_PID,
                           syslog.LOG_DAEMON)  # syslog.openlog(ident, logoption, facility) -- not taking kw args in Python 2.6
            syslog.syslog(syslog.LOG_ALERT, log_msg)  # syslog.syslog(priority, message) -- not taking kw args
            syslog.closelog()

    return omi_installed


if __name__ == '__main__':
    if len(sys.argv) <= 1:
        print('No command line argument was specified.\nYou must be executing this program manually for testing.\n'
              'In that case, one of "install", "enable", "disable", "uninstall", or "update" should be given.')
    else:
        try:
            main(sys.argv[1])
        except Exception as e:
            ext_version = ET.parse('manifest.xml').find('{http://schemas.microsoft.com/windowsazure}Version').text
            msg = "Unknown exception thrown from diagnostic.py.\n" \
                  "Error: {0}\nStackTrace: {1}".format(e, traceback.format_exc())
            wala_event_type = wala_event_type_for_telemetry(get_extension_operation_type(sys.argv[1]))
            if len(sys.argv) == 2:
                # Add a telemetry only if this is executed through waagent (in which
                # we are guaranteed to have just one cmdline arg './diagnostic -xxx').
                waagent.AddExtensionEvent(name="Microsoft.Azure.Diagnostic.LinuxDiagnostic",
                                          op=wala_event_type, isSuccess=False, version=ext_version, message=msg)
            else:
                # Trick to print backtrace in case we execute './diagnostic.py -xxx yyy' from a terminal for testing.
                # By just adding one more cmdline arg with any content, the above if condition becomes false,
                # thus allowing us to run code here, printing the exception message with the stack trace.
                print(msg)
            # Need to exit with an error code, so that this situation can be detected by waagent and also
            # reported to customer through agent/extension status blob.
            hutil.do_exit(42, wala_event_type, 'Error', '42', msg)  # What's 42? Ask Abhi.
================================================ FILE: Diagnostic/lad_config_all.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Linux Azure Diagnostic Extension (Current version is specified in manifest.xml) # Copyright (c) Microsoft Corporation # All rights reserved. # MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the # Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS # OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
import os import traceback import xml.etree.ElementTree as ET import Providers.Builtin as BuiltIn import Utils.ProviderUtil as ProvUtil import Utils.LadDiagnosticUtil as LadUtil import Utils.XmlUtil as XmlUtil import Utils.mdsd_xml_templates as mxt import telegraf_utils.telegraf_config_handler as telhandler import metrics_ext_utils.metrics_constants as metrics_constants import metrics_ext_utils.metrics_ext_handler as me_handler from Utils.lad_exceptions import LadLoggingConfigException, LadPerfCfgConfigException from Utils.lad_logging_config import LadLoggingConfig, copy_source_mdsdevent_eh_url_elems from Utils.misc_helpers import get_storage_endpoints_with_account, escape_nonalphanumerics class LadConfigAll: """ A class to generate configs for all 3 core components of LAD: mdsd, omsagent (fluentd), and syslog (rsyslog or syslog-ng) based on LAD's JSON extension settings. The mdsd XML config file generated will be /var/lib/waagent/Microsoft. ...-x.y.zzzz/xmlCfg.xml (hard-coded). Other config files whose contents are generated by this class are as follows: - /etc/opt/microsoft/omsagent/LAD/conf/omsagent.d/syslog.conf : fluentd's syslog source config - /etc/opt/microsoft/omsagent/LAD/conf/omsagent.d/tail.conf : fluentd's tail source config (fileLogs) - /etc/opt/microsoft/omsagent/LAD/conf/omsagent.d/z_out_mdsd.conf : fluentd's out_mdsd out plugin config - /etc/rsyslog.conf or /etc/rsyslog.d/95-omsagent.conf: rsyslog config for LAD's syslog settings The content should be appended to the corresponding file, not overwritten. After that, the file should be processed so that the '%SYSLOG_PORT%' pattern is replaced with the assigned TCP port number. - /etc/syslog-ng.conf: syslog-ng config for LAD's syslog settings. The content should be appended, not overwritten. 
""" _default_perf_cfgs = [ {"query": "SELECT PercentAvailableMemory, AvailableMemory, UsedMemory, PercentUsedSwap " "FROM SCX_MemoryStatisticalInformation", "table": "LinuxMemory"}, {"query": "SELECT PercentProcessorTime, PercentIOWaitTime, PercentIdleTime " "FROM SCX_ProcessorStatisticalInformation WHERE Name='_TOTAL'", "table": "LinuxCpu"}, {"query": "SELECT AverageWriteTime,AverageReadTime,ReadBytesPerSecond,WriteBytesPerSecond " "FROM SCX_DiskDriveStatisticalInformation WHERE Name='_TOTAL'", "table": "LinuxDisk"} ] def __init__(self, ext_settings, ext_dir, waagent_dir, deployment_id, fetch_uuid, encrypt_string, logger_log, logger_error): """ Constructor. :param ext_settings: A LadExtSettings (in Utils/lad_ext_settings.py) obj wrapping the Json extension settings. :param ext_dir: Extension directory (e.g., /var/lib/waagent/Microsoft.OSTCExtensions.LinuxDiagnostic-2.3.xxxx) :param waagent_dir: WAAgent directory (e.g., /var/lib/waagent) :param deployment_id: Deployment ID string (or None) that should be obtained & passed by the caller from waagent's HostingEnvironmentCfg.xml. :param fetch_uuid: A function which fetches the UUID for the VM :param encrypt_string: A function which encrypts a string, given a cert_path :param logger_log: Normal logging function (e.g., hutil.log) that takes only one param for the logged msg. :param logger_error: Error logging function (e.g., hutil.error) that takes only one param for the logged msg. 
""" self._ext_settings = ext_settings self._ext_dir = ext_dir self._waagent_dir = waagent_dir self._deployment_id = deployment_id self._fetch_uuid = fetch_uuid self._encrypt_secret = encrypt_string self._logger_log = logger_log self._logger_error = logger_error self._telegraf_me_url = metrics_constants.lad_metrics_extension_influx_udp_url self._telegraf_mdsd_url = metrics_constants.telegraf_influx_url # Generated logging configs place holders self._fluentd_syslog_src_config = None self._fluentd_tail_src_config = None self._fluentd_out_mdsd_config = None self._rsyslog_config = None self._syslog_ng_config = None self._telegraf_config = None self._telegraf_namespaces = None self._mdsd_config_xml_tree = ET.ElementTree(ET.fromstring(mxt.entire_xml_cfg_tmpl)) self._sink_configs = LadUtil.SinkConfiguration() self._sink_configs.insert_from_config(self._ext_settings.read_protected_config('sinksConfig')) # Reading the AzMonSink info from the public config. self._sink_configs_public = LadUtil.SinkConfiguration() self._sink_configs_public.insert_from_config(self._ext_settings.read_public_config('sinksConfig')) # If we decide to also read sinksConfig from ladCfg, do it first, so that private settings override # Get encryption settings handlerSettings = ext_settings.get_handler_settings() if handlerSettings['protectedSettings'] is None: errorMsg = "Settings did not contain protectedSettings. For information on protected settings, " \ "visit https://docs.microsoft.com/en-us/azure/virtual-machines/extensions/diagnostics-linux#protected-settings." self._logger_error(errorMsg) raise LadLoggingConfigException(errorMsg) if handlerSettings['protectedSettingsCertThumbprint'] is None: errorMsg = "Settings did not contain protectedSettingsCertThumbprint. For information on protected settings, " \ "visit https://docs.microsoft.com/en-us/azure/virtual-machines/extensions/diagnostics-linux#protected-settings." 
self._logger_error(errorMsg) raise LadLoggingConfigException(errorMsg) thumbprint = handlerSettings['protectedSettingsCertThumbprint'] self._cert_path = os.path.join(waagent_dir, thumbprint + '.crt') self._pkey_path = os.path.join(waagent_dir, thumbprint + '.prv') def _ladCfg(self): return self._ext_settings.read_public_config('ladCfg') @staticmethod def _wad_table_name(interval): """ Build the name and storetype of a metrics table based on the aggregation interval and presence/absence of sinks :param str interval: String representation of aggregation interval :return: table name :rtype: str """ return 'WADMetrics{0}P10DV2S'.format(interval) def _add_element_from_string(self, path, xml_string, add_only_once=True): """ Add an XML fragment to the mdsd config document in accordance with path :param str path: Where to add the fragment :param str xml_string: A string containing the XML element to add :param bool add_only_once: Indicates whether to perform the addition only to the first match of the path. """ XmlUtil.addElement(xml=self._mdsd_config_xml_tree, path=path, el=ET.fromstring(xml_string), addOnlyOnce=add_only_once) def _add_element_from_element(self, path, xml_elem, add_only_once=True): """ Add an XML fragment to the mdsd config document in accordance with path :param str path: Where to add the fragment :param ElementTree xml_elem: An ElementTree object XML fragment that should be added to the path. :param bool add_only_once: Indicates whether to perform the addition only to the first match of the path. 
""" XmlUtil.addElement(xml=self._mdsd_config_xml_tree, path=path, el=xml_elem, addOnlyOnce=add_only_once) def _add_derived_event(self, interval, source, event_name, store_type, add_lad_query=False): """ Add a element to the configuration :param str interval: Interval at which this DerivedEvent should be run :param str source: Local table from which this DerivedEvent should pull :param str event_name: Destination table to which this DerivedEvent should push :param str store_type: The storage type of the destination table, e.g. Local, Central, JsonBlob :param bool add_lad_query: True if a subelement should be added to this element """ derived_event = mxt.derived_event.format(interval=interval, source=source, target=event_name, type=store_type) element = ET.fromstring(derived_event) if add_lad_query: XmlUtil.addElement(element, ".", ET.fromstring(mxt.lad_query)) self._add_element_from_element('Events/DerivedEvents', element) def _add_obo_field(self, name, value): """ Add an element to the element. :param name: Name of the field :param value: Value for the field """ self._add_element_from_string('Management', mxt.obo_field.format(name=name, value=value)) def _update_metric_collection_settings(self, ladCfg, namespaces): """ Update mdsd_config_xml_tree for Azure Portal metric collection. This method builds the necessary aggregation queries that grind the ingested data and push it to the WADmetric table. :param ladCfg: ladCfg object from extension config :param namespaces: list of telegraf plugins sources obtained after parsing lad metrics config :return: None """ # Aggregation is done by within a . If there are no alternate sinks, the DerivedQuery # can send output directly to the WAD metrics table. If there *are* alternate sinks, have the LADQuery send # output to a new local table, then arrange for additional derived queries to pull from that. 
intervals = LadUtil.getAggregationPeriodsFromLadCfg(ladCfg) sinks = LadUtil.getFeatureWideSinksFromLadCfg(ladCfg, 'performanceCounters') for plugin in namespaces: lad_specific_storage_plugin = "storage-" + plugin for aggregation_interval in intervals: if sinks: local_table_name = ProvUtil.MakeUniqueEventName('aggregationLocal') self._add_derived_event(aggregation_interval, lad_specific_storage_plugin, local_table_name, 'Local', add_lad_query=True) self._handle_alternate_sinks(aggregation_interval, sinks, local_table_name) else: self._add_derived_event(aggregation_interval, lad_specific_storage_plugin, LadConfigAll._wad_table_name(aggregation_interval), 'Central', add_lad_query=True) def _handle_alternate_sinks(self, interval, sinks, source): """ Update the XML config to accommodate alternate data sinks. Start by pumping the data from the local source to the actual wad table; then run through the sinks and add annotations or additional DerivedEvents as needed. :param str interval: Aggregation interval :param [str] sinks: List of alternate destinations :param str source: Name of local table from which data is to be pumped :return: """ self._add_derived_event(interval, source, LadConfigAll._wad_table_name(interval), 'Central') for name in sinks: sink = self._sink_configs.get_sink_by_name(name) if sink is None: self._logger_log("Ignoring sink '{0}' for which no definition was found".format(name)) elif sink['type'] == 'EventHub': if 'sasURL' in sink: self._add_streaming_annotation(source, sink['sasURL']) else: self._logger_error("Ignoring EventHub sink '{0}': no 'sasURL' was supplied".format(name)) elif sink['type'] == 'JsonBlob': self._add_derived_event(interval, source, name, 'JsonBlob') else: self._logger_log("Ignoring sink '{0}': unknown type '{1}'".format(name, sink['type'])) def _add_streaming_annotation(self, sink_name, sas_url): """ Helper to add an EventStreamingAnnotation element for the given sink_name and sas_url :param str sink_name: Name of the EventHub 
sink name for the SAS URL :param str sas_url: Raw SAS URL string for the EventHub sink """ self._add_element_from_string('EventStreamingAnnotations', mxt.per_eh_url_tmpl.format(eh_name=sink_name, key_path=self._pkey_path, enc_eh_url=self._encrypt_secret_with_cert(sas_url))) def _encrypt_secret_with_cert(self, secret): """ update_account_settings() helper. :param secret: Secret to encrypt :return: Encrypted secret string. None if openssl command exec fails. """ return self._encrypt_secret(self._cert_path, secret) def _update_account_settings(self, account, token, endpoints): """ Update the MDSD configuration Account element with Azure table storage properties. Exactly one of (key, token) must be provided. :param account: Storage account to which LAD should write data :param token: SAS token to access the storage account :param endpoints: Identifies the Azure storage endpoints (public or specific sovereign cloud) where the storage account is """ assert token, "Token must be given." assert self._mdsd_config_xml_tree is not None token = self._encrypt_secret_with_cert(token) assert token, "Could not encrypt token" XmlUtil.setXmlValue(self._mdsd_config_xml_tree, 'Accounts/SharedAccessSignature', "account", account, ['isDefault', 'true']) XmlUtil.setXmlValue(self._mdsd_config_xml_tree, 'Accounts/SharedAccessSignature', "key", token, ['isDefault', 'true']) XmlUtil.setXmlValue(self._mdsd_config_xml_tree, 'Accounts/SharedAccessSignature', "decryptKeyPath", self._pkey_path, ['isDefault', 'true']) XmlUtil.setXmlValue(self._mdsd_config_xml_tree, 'Accounts/SharedAccessSignature', "tableEndpoint", endpoints[0], ['isDefault', 'true']) XmlUtil.setXmlValue(self._mdsd_config_xml_tree, 'Accounts/SharedAccessSignature', "blobEndpoint", endpoints[1], ['isDefault', 'true']) XmlUtil.removeElement(self._mdsd_config_xml_tree, 'Accounts', 'Account') def _set_xml_attr(self, key, value, xml_path, selector=[]): """ Set XML attribute on the element specified with xml_path. 
:param key: The attribute name to set on the XML element. :param value: The default value to be set, if there's no public config for that attribute. :param xml_path: The path of the XML element(s) to which the attribute is applied. :param selector: Selector for finding the actual XML element (see XmlUtil.setXmlValue) :return: None. Change is directly applied to mdsd_config_xml_tree XML member object. """ assert self._mdsd_config_xml_tree is not None v = self._ext_settings.read_public_config(key) if not v: v = value XmlUtil.setXmlValue(self._mdsd_config_xml_tree, xml_path, key, v, selector) def _set_event_volume(self, lad_cfg): """ Set event volume in mdsd config. Check if desired event volume is specified, first in ladCfg then in public config. If in neither then default to Medium. :param lad_cfg: 'ladCfg' Json object to look up for the event volume setting. :return: None. The mdsd config XML tree's eventVolume attribute is directly updated. :rtype: str """ assert self._mdsd_config_xml_tree is not None event_volume = LadUtil.getEventVolumeFromLadCfg(lad_cfg) if event_volume: self._logger_log("Event volume found in ladCfg: " + event_volume) else: event_volume = self._ext_settings.read_public_config("eventVolume") if event_volume: self._logger_log("Event volume found in public config: " + event_volume) else: event_volume = "Medium" self._logger_log("Event volume not found in config. Using default value: " + event_volume) XmlUtil.setXmlValue(self._mdsd_config_xml_tree, "Management", "eventVolume", event_volume) ###################################################################### # This is the main API that's called by user. All other methods are # actually helpers for this, thus made private by convention. ###################################################################### def generate_all_configs(self): """ Generates configs for all components required by LAD. Generates XML cfg file for mdsd, from JSON config settings (public & private). 
Also generates rsyslog/syslog-ng configs corresponding to 'syslogEvents' or 'syslogCfg' setting. Also generates fluentd's syslog/tail src configs and out_mdsd configs. The rsyslog/syslog-ng and fluentd configs are not yet saved to files. They are available through the corresponding getter methods of this class (get_fluentd_*_config(), get_*syslog*_config()). Returns (True, '') if config was valid and proper xmlCfg.xml was generated. Returns (False, '...') if config was invalid and the error message. """ # 1. Add DeploymentId (if available) to identity columns if self._deployment_id: XmlUtil.setXmlValue(self._mdsd_config_xml_tree, "Management/Identity/IdentityComponent", "", self._deployment_id, ["name", "DeploymentId"]) # 2. Generate telegraf, MetricsExtension, omsagent (fluentd) configs, rsyslog/syslog-ng config, and update corresponding mdsd config XML try: lad_cfg = self._ladCfg() if not lad_cfg: return False, 'Unable to find Ladcfg element. Failed to generate configs for fluentd, syslog, and mdsd ' \ '(see extension error logs for more details)' syslogEvents_setting = self._ext_settings.get_syslogEvents_setting() fileLogs_setting = self._ext_settings.get_fileLogs_setting() lad_logging_config_helper = LadLoggingConfig(syslogEvents_setting, fileLogs_setting, self._sink_configs, self._pkey_path, self._cert_path, self._encrypt_secret) mdsd_syslog_config = lad_logging_config_helper.get_mdsd_syslog_config(self._ext_settings.read_protected_config('disableStorageAccount') == True) mdsd_filelog_config = lad_logging_config_helper.get_mdsd_filelog_config() copy_source_mdsdevent_eh_url_elems(self._mdsd_config_xml_tree, mdsd_syslog_config) copy_source_mdsdevent_eh_url_elems(self._mdsd_config_xml_tree, mdsd_filelog_config) self._fluentd_syslog_src_config = lad_logging_config_helper.get_fluentd_syslog_src_config() self._fluentd_tail_src_config = lad_logging_config_helper.get_fluentd_filelog_src_config() self._fluentd_out_mdsd_config = 
lad_logging_config_helper.get_fluentd_out_mdsd_config() self._rsyslog_config = lad_logging_config_helper.get_rsyslog_config() self._syslog_ng_config = lad_logging_config_helper.get_syslog_ng_config() parsed_perf_settings = lad_logging_config_helper.parse_lad_perf_settings(lad_cfg) if len(parsed_perf_settings) > 0: self._telegraf_config, self._telegraf_namespaces = telhandler.handle_config(parsed_perf_settings, self._telegraf_me_url, self._telegraf_mdsd_url, True) #Handle the EH, JsonBlob and AzMonSink logic self._update_metric_collection_settings(lad_cfg, self._telegraf_namespaces) mdsd_telegraf_config = lad_logging_config_helper.get_mdsd_telegraf_config(self._telegraf_namespaces) copy_source_mdsdevent_eh_url_elems(self._mdsd_config_xml_tree, mdsd_telegraf_config) resource_id = self._ext_settings.get_resource_id() if resource_id: # Set JsonBlob sink-related elements uuid_for_instance_id = self._fetch_uuid() self._add_obo_field(name='resourceId', value=resource_id) self._add_obo_field(name='agentIdentityHash', value=uuid_for_instance_id) XmlUtil.setXmlValue(self._mdsd_config_xml_tree, 'Events/DerivedEvents/DerivedEvent/LADQuery', 'partitionKey', escape_nonalphanumerics(resource_id)) lad_query_instance_id = "" if resource_id.find("providers/Microsoft.Compute/virtualMachineScaleSets") >= 0: lad_query_instance_id = uuid_for_instance_id self._set_xml_attr("instanceID", lad_query_instance_id, "Events/DerivedEvents/DerivedEvent/LADQuery") else: self._logger_log('Unable to find resource id in the config. Failed to generate configs for Metrics in mdsd ' \ '(see extension error logs for more details)') #Only enable Metrics if AzMonSink is in the config azmonsink = self._sink_configs_public.get_sink_by_name("AzMonSink") if azmonsink is None: self._logger_log("Did not find AzMonSink in public config. Will not set up custom metrics through ME.") else: self._logger_log("Found AzMonSink in public config. 
Setting up custom metrics through ME.") me_handler.setup_me(True) except Exception as e: self._logger_error("Failed to create omsagent (fluentd), rsyslog/syslog-ng configs, telegraf config or to update " "corresponding mdsd config XML. Error: {0}\nStacktrace: {1}" .format(e, traceback.format_exc())) return False, 'Failed to generate configs for fluentd, syslog, and mdsd; see extension.log for more details.' # 3. Before starting to update the storage account settings, log extension's entire settings # with secrets redacted, for diagnostic purpose. self._ext_settings.log_ext_settings_with_secrets_redacted(self._logger_log, self._logger_error) # 4. Actually update the storage account settings on mdsd config XML tree (based on extension's # protectedSettings). account = self._ext_settings.read_protected_config('storageAccountName').strip() if not account: return False, "Configuration Error: Must specify storageAccountName in protected settings. For information on protected settings, " \ "visit https://docs.microsoft.com/en-us/azure/virtual-machines/extensions/diagnostics-linux#protected-settings." if self._ext_settings.read_protected_config('storageAccountKey'): return False, "Configuration Error: The storageAccountKey protected setting is deprecated in LAD 3.0 and cannot be used. " \ "Instead, use the storageAccountSasToken setting. For documentation of this setting and instructions for generating " \ "a SAS token, visit https://docs.microsoft.com/en-us/azure/virtual-machines/extensions/diagnostics-linux#protected-settings." token = self._ext_settings.read_protected_config('storageAccountSasToken').strip() if not token or token == '?': return False, "Configuration Error: Must specify storageAccountSasToken in the protected settings. For documentation of this setting and instructions " \ "for generating a SAS token, visit https://docs.microsoft.com/en-us/azure/virtual-machines/extensions/diagnostics-linux#protected-settings." if '?' 
== token[0]: token = token[1:] endpoints = get_storage_endpoints_with_account(account, self._ext_settings.read_protected_config('storageAccountEndPoint')) self._update_account_settings(account, token, endpoints) # 5. Update mdsd config XML's eventVolume attribute based on the logic specified in the helper. self._set_event_volume(lad_cfg) # 6. Finally generate mdsd config XML file out of the constructed XML tree object. self._mdsd_config_xml_tree.write(os.path.join(self._ext_dir, 'xmlCfg.xml')) return True, "" @staticmethod def __throw_if_output_is_none(output): """ Helper to check if output is already generated (not None) and throw if it's not (None). :return: None """ if output is None: raise LadLoggingConfigException('LadConfigAll.get_*_config() should be called after ' 'LadConfigAll.generate_mdsd_omsagent_syslog_config() is called') def get_fluentd_syslog_src_config(self): """ Returns the obtained Fluentd's syslog src config. This getter (and all that follow) should be called after self.generate_mdsd_omsagent_syslog_config() is called. The return value should be overwritten to /etc/opt/microsoft/omsagent/LAD/conf/omsagent.d/syslog.conf after replacing '%SYSLOG_PORT%' with the assigned TCP port number. :rtype: str :return: Fluentd syslog src config string """ LadConfigAll.__throw_if_output_is_none(self._fluentd_syslog_src_config) return self._fluentd_syslog_src_config def get_fluentd_tail_src_config(self): """ Returns the obtained Fluentd's tail src config. This getter (and all that follow) should be called after self.generate_mdsd_omsagent_syslog_config() is called. The return value should be overwritten to /etc/opt/microsoft/omsagent/LAD/conf/omsagent.d/tail.conf. :rtype: str :return: Fluentd tail src config string """ LadConfigAll.__throw_if_output_is_none(self._fluentd_tail_src_config) return self._fluentd_tail_src_config def get_fluentd_out_mdsd_config(self): """_fluentd_out_mdsd_config Returns the obtained Fluentd's out_mdsd config. 
This getter (and all that follow) should be called after self.generate_mdsd_omsagent_syslog_config() is called. The return value should be overwritten to /etc/opt/microsoft/omsagent/LAD/conf/omsagent.d/z_out_mdsd.conf. :rtype: str :return: Fluentd out_mdsd config string """ LadConfigAll.__throw_if_output_is_none(self._fluentd_out_mdsd_config) return self._fluentd_out_mdsd_config def get_rsyslog_config(self): """ Returns the obtained rsyslog config. This getter (and all that follow) should be called after self.generate_mdsd_omsagent_syslog_config() is called. The return value should be appended to /etc/rsyslog.d/95-omsagent.conf if rsyslog ver is new (that is, if /etc/rsyslog.d/ exists). It should be appended to /etc/rsyslog.conf if rsyslog ver is old (no /etc/rsyslog.d/). The appended file (either /etc/rsyslog.d/95-omsagent.conf or /etc/rsyslog.conf) should be processed so that the '%SYSLOG_PORT%' pattern in the file is replaced with the assigned TCP port number. :rtype: str :return: rsyslog config string """ LadConfigAll.__throw_if_output_is_none(self._rsyslog_config) return self._rsyslog_config def get_syslog_ng_config(self): """ Returns the obtained syslog-ng config. This getter (and all that follow) should be called after self.generate_mdsd_omsagent_syslog_config() is called. The return value should be appended to /etc/syslog-ng.conf. The appended file (/etc/syslog-ng.conf) should be processed so that the '%SYSLOG_PORT%' pattern in the file is replaced with the assigned TCP port number. 
:rtype: str :return: syslog-ng config string """ LadConfigAll.__throw_if_output_is_none(self._syslog_ng_config) return self._syslog_ng_config ================================================ FILE: Diagnostic/lad_mdsd.te ================================================ # SELinux policy for mdsd on LAD, obtained by "grep mdsd /var/log/audit/audit.log | audit2allow -m lad_mdsd.te" # Note it combines different types (unconfined_t and initrc_t) to support both Redhat policy and CentOS policy module lad_mdsd 1.0; require { type unconfined_t; type initrc_t; type syslogd_t; type var_run_t; class sock_file write; class unix_stream_socket connectto; } #============= syslogd_t ============== allow syslogd_t unconfined_t:unix_stream_socket connectto; allow syslogd_t initrc_t:unix_stream_socket connectto; allow syslogd_t var_run_t:sock_file write; ================================================ FILE: Diagnostic/license.txt ================================================ Linux Azure Diagnostic Extension v.2.3.9 Copyright (c) Microsoft Corporation All rights reserved. MIT License Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: Diagnostic/manifest.xml ================================================ Microsoft.Azure.Diagnostics LinuxDiagnostic 4.1.12 VmRole Microsoft Azure Diagnostic Extension for Linux Virtual Machines true https://github.com/Azure/azure-linux-extensions/blob/master/LICENSE-2_0.txt http://www.microsoft.com/privacystatement/en-us/OnlineServices/Default.aspx https://github.com/Azure/azure-linux-extensions true Linux Microsoft ================================================ FILE: Diagnostic/mdsd/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 2.6) project(mdsd) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/Modules/") # Platform (not compiler) specific settings if(UNIX) # This includes Linux message("Build for Unix/Linux OS") else() message("-- Unsupported Build Platform.") endif() # Compiler (not platform) specific settings if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") message("-- Setting clang options") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++") set(LINKSTDLIB "c++") set(LIBSUFFIX "-clang") set(WARNINGS "${WARNINGS} -Wno-deprecated-register") elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") message("-- Setting gcc options") set(WARNINGS "${WARNINGS} -Wno-unused-local-typedefs") else() message("-- Unknown compiler, success is doubtful.") endif() # To turn off the option from cmdline, run: cmake -DBUILD_TESTS=OFF ... option(BUILD_TESTS "Build tests." ON) # To add code coverage build options option(BUILD_COV "Build with code coverage." 
OFF) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") # Common flags for both C and C++ set(COMM_FLAGS "${COMM_FLAGS} -fstack-protector-all") set(COMM_FLAGS "${COMM_FLAGS} -fPIC") set(COMM_FLAGS "${COMM_FLAGS} -D_FORTIFY_SOURCE=2") set(COMM_FLAGS "${COMM_FLAGS} -ffunction-sections") if(BUILD_COV) set(COMM_FLAGS "${COMM_FLAGS} -fprofile-arcs") set(COMM_FLAGS "${COMM_FLAGS} -ftest-coverage") endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${COMM_FLAGS}") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${COMM_FLAGS}") set(WARNINGS "${WARNINGS} -Wall") set(WARNINGS "${WARNINGS} -Wextra") set(WARNINGS "${WARNINGS} -Wno-unknown-pragmas") set(WARNINGS "${WARNINGS} -Wno-unused-parameter") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${WARNINGS}") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -ggdb") set(CMAKE_C_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -ggdb") set(LINKER_FLAGS "-Wl,-z,relro -Wl,-z,now") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${LINKER_FLAGS}") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${LINKER_FLAGS}") # Build static library only option(BUILD_SHARED_LIBS "Build shared Libraries." 
OFF) set(OMI_INCLUDE_DIRS /usr/include/omi /usr/include/omi/common /usr/include/omi/output/include /usr/include/omi/micxx ) set(OMI_LIB_PATH "/opt/omi/lib") set(CASABLANCA_INCLUDE_DIRS "/usr/include/cpprest") set(CASABLANCA_LIBRARIES "/usr/lib/x86_64-linux-gnu/libcpprest${LIBSUFFIX}.a") set(STORAGE_INCLUDE_DIRS "/usr/include/azurestorage") set(STORAGE_LIBRARIES "/usr/lib/x86_64-linux-gnu/libazurestorage${LIBSUFFIX}.a") set(MDSD_LIB_NAME mdsd-lib${LIBSUFFIX}) set(LOG_LIB_NAME mdsdlog${LIBSUFFIX}) set(UTIL_LIB_NAME mdsdutil${LIBSUFFIX}) set(CMD_LIB_NAME mdscommands${LIBSUFFIX}) set(INPUT_LIB_NAME mdsdinput${LIBSUFFIX}) set(MDSDCFG_LIB_NAME mdsdcfg${LIBSUFFIX}) set(MDSREST_LIB_NAME mdsrest${LIBSUFFIX}) # Set rpath for all executables including mdsd, tests, etc SET(CMAKE_SKIP_BUILD_RPATH FALSE) SET(CMAKE_BUILD_WITH_INSTALL_RPATH FALSE) SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) add_subdirectory(mdsdlog) add_subdirectory(mdsdutil) add_subdirectory(mdscommands) add_subdirectory(mdsdinput) add_subdirectory(mdsdcfg) add_subdirectory(mdsrest) add_subdirectory(mdsd) ================================================ FILE: Diagnostic/mdsd/Dockerfile ================================================ FROM ubuntu:trusty RUN apt-get update && apt-get install -y software-properties-common RUN apt-get update && \ apt-get install -y sudo apt-utils openssh-server wget unzip git build-essential libtool && \ apt-get upgrade -y && apt-get dist-upgrade -y EXPOSE 22 RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-4.8 50 && \ update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-4.8 50 RUN apt-get update && \ apt-get install -y psmisc libxml++2.6-dev uuid-dev python-software-properties zlib1g-dev \ libssl1.0.0 libssl-dev cmake rpm liblzma-dev libjson-c-dev libjson-c2 RUN apt-get update && ver=1.55 && \ apt-get install -y libboost$ver-dev libboost-system$ver-dev libboost-thread$ver-dev \ libboost-filesystem$ver-dev libboost-random$ver-dev libboost-locale$ver-dev \ 
libboost-regex$ver-dev libboost-iostreams$ver-dev libboost-log$ver-dev RUN apt-get update && ver=1.55.0 && \ apt-get install -y libboost-system$ver libboost-thread$ver libboost-filesystem$ver \ libboost-random$ver libboost-locale$ver libboost-regex$ver \ libboost-iostreams$ver libboost-log$ver ADD azure.list /etc/apt/sources.list.d/azure.list RUN apt-key adv --keyserver packages.microsoft.com --recv-keys B02C46DF417A0893 && \ apt-get install apt-transport-https RUN apt-get update && \ apt-get install -y libcpprest-dev libazurestorage-dev libomi-dev libcpprest \ libazurestorage omi libbond-dev ================================================ FILE: Diagnostic/mdsd/LICENSE.txt ================================================ ------------------------------------------ START OF LICENSE ----------------------------------------- Linux mdsd Agent Copyright (c) Microsoft Corporation All rights reserved.  MIT License Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
----------------------------------------------- END OF LICENSE ------------------------------------------ ================================================ FILE: Diagnostic/mdsd/README.md ================================================ # mdsd agent The mdsd agent is the workhorse binary for the Linux Diagnostic Extension. The LAD extension constructs an mdsd configuration file based on the LAD configuration. ## Dependencies The Dockerfile defines an environment sufficient to build the mdsd binary. Most dependencies are satisfied by the Ubuntu "trusty" repositories. The exceptions are for open-source components released by Microsoft. These components are available in source form from github. For convenience, Microsoft has made installable .deb packages available in a public "azurecore" repository, which the Dockerfile references. These components are: - [CPPrest, a.k.a. "Casablanca"](https://github.com/Microsoft/cpprestsdk) - [Azure Storage SDK](https://github.com/Azure/azure-storage-cpp) - [Microsoft bond](https://github.com/Microsoft/bond) - [Open Management Infrastructure (OMI)](https://github.com/Microsoft/omi) ## Building the program Run `buildcmake.sh` with the appropriate options. This will build all the necessary Makefiles, then build the program, then construct .deb and .rpm packages containing the built binary. For maximum portability across distros, the mdsd binary is built to use static libraries whenever possible. Build artifacts are dropped under `builddir` (which is symlinked to the actual directory hierarchy, which will differ based on the choice of debug vs optimized build). The release packages appear under the `lad-mdsd` directory. ## Future direction Over time, the capabilities of this monolithic binary are likely be broken out into fluentd plug-ins. This will significantly reduce the amount of code involved and will enable more flexible growth of the LAD extension. 
================================================ FILE: Diagnostic/mdsd/SampleConfig-LAD-SAS.xml ================================================ c9c8552dc3a1421da8ecc0c284082a39 \Memory\AvailableMemory \Memory\PercentAvailableMemory \Memory\UsedMemory \Memory\PercentUsedMemory \Memory\PercentUsedByCache \Memory\PagesPerSec \Memory\PagesReadPerSec \Memory\PagesWrittenPerSec \Memory\AvailableSwap \Memory\PercentAvailableSwap \Memory\UsedSwap \Memory\PercentUsedSwap \Processor\PercentIdleTime \Processor\PercentUserTime \Processor\PercentNiceTime \Processor\PercentPrivilegedTime \Processor\PercentInterruptTime \Processor\PercentDPCTime \Processor\PercentProcessorTime \Processor\PercentIOWaitTime \PhysicalDisk\BytesPerSecond \PhysicalDisk\ReadBytesPerSecond \PhysicalDisk\WriteBytesPerSecond \PhysicalDisk\TransfersPerSecond \PhysicalDisk\ReadsPerSecond \PhysicalDisk\WritesPerSecond \PhysicalDisk\AverageReadTime \PhysicalDisk\AverageWriteTime \PhysicalDisk\AverageTransferTime \PhysicalDisk\AverageDiskQueueLength \NetworkInterface\BytesTransmitted \NetworkInterface\BytesReceived \NetworkInterface\PacketsTransmitted \NetworkInterface\PacketsReceived \NetworkInterface\BytesTotal \NetworkInterface\TotalRxErrors \NetworkInterface\TotalTxErrors \NetworkInterface\TotalCollisions ================================================ FILE: Diagnostic/mdsd/azure.list ================================================ deb [arch=amd64] https://packages.microsoft.com/repos/azurecore/ trusty main ================================================ FILE: Diagnostic/mdsd/buildcmake.sh ================================================ #!/bin/bash # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT license. # This will build mdsd and its libraries # Usage: see Usage() # TotalErrors=0 BuildType= CCompiler=gcc CXXCompiler=g++ BuildName=dev BUILDDIR=builddir MakeFileOnly=0 Parallelism="-j4" # If CodeCoverage=1, build with code coverage options. 
# NOTE: only gcc is supported and it must be debug build. CodeCoverage=OFF Usage() { echo "Usage: $0 <-a> | <-d|-o> <-c|-g> [-b buildname] [-mC] [-p parallelism] [-s] [-t]" echo " -b: use buildname. Default: timestamp." echo " -C: capture code coverage." echo " -d: build debug build." echo " -m: create makefiles only. After done, run 'make help' for options." echo " -o: build optimized(release) build." echo " -p: specify number of parallel compile operations (default 4)." } if [ "$#" == "0" ]; then Usage exit 1 fi args=`getopt b:Cdhmop: $*` if [ $? != 0 ]; then Usage exit 1 fi set -- $args for i; do case "$i" in -b) BuildName=$2 shift ; shift ;; -C) CodeCoverage=ON shift ;; -d) if [ -z "${BuildType}" ]; then BuildType=d else echo "Error: build type is already set to be ${BuildType}." exit 1 fi shift ;; -h) Usage exit 0 shift ;; -m) MakeFileOnly=1 shift ;; -o) if [ -z "${BuildType}" ]; then BuildType=o else echo "Error: build type is already set to be ${BuildType}." exit 1 fi shift ;; -p) declare -i numJobs # This variable is an integer, guaranteed by the shell numJobs=$2 if [ $numJobs -gt 1 ]; then Parallelism="-j$numJobs" echo "Setting parallelism to $Parallelism" else Parallelism="" echo "Disabling parallel compilation" fi shift; shift ;; --) shift; break ;; esac done if [ -z "${BuildType}" ]; then echo "Error: missing build type. -d or -o is required." exit 1 fi if [ "${CodeCoverage}" == "ON" ]; then if [ "${BuildType}" != "d" ]; then echo "Error: only debug build is supported for code coverage." exit 1 fi fi BuildWithCMake() { echo echo Start to build source code. BuildType=${BuildType} ... BinDropDir=${BUILDDIR}.${BuildType}.${CCompiler} rm -rf ${BUILDDIR} ${BinDropDir} mkdir ${BinDropDir} ln -s ${BinDropDir} ${BUILDDIR} pushd ${BinDropDir} DefBuildNumber= if [ ! 
# Verify that built binaries don't require GLIBC symbols newer than the oldest
# supported distro provides.
# Usage: ParseGlibcVer <dirname> [filename]   (filename is optional)
ParseGlibcVer()
{
    # Maximum GLIBC version supported by oldest supported distro
    glibcver=2.15
    ParserScript=./parseglibc.py
    dirname=$1
    filename=$2 # optional, can be NULL
    echo
    # Fix: the filename test/arguments previously used a garbled "$(unknown)"
    # command substitution, which always evaluated empty, so the single-file
    # (-f) code path could never be taken. Use ${filename} as intended.
    if [ -n "${filename}" ]; then
        echo python ${ParserScript} -f ${dirname}/${filename} -v ${glibcver}
        python ${ParserScript} -f ${dirname}/${filename} -v ${glibcver}
    else
        echo python ${ParserScript} -d ${dirname} -v ${glibcver}
        python ${ParserScript} -d ${dirname} -v ${glibcver}
    fi
    # $? is the exit status of the python invocation above.
    if [ $? != 0 ]; then
        let TotalErrors+=1
        echo Error: ParseGlibcVer failed: maximum supported GLIBC version is ${glibcver}.
        exit ${TotalErrors}
    fi
}
We need to use /usr/local/ssl as the # top-level OpenSSL directory for the libraries, to make them work on all distros # (especially SUSE 11, which is already done that way). BuildOpenSsl() { opensslDir=openssl-1.0.2* # Grab the only (which must be latest) OpenSSL 1.0.2 release tgzFile=$opensslDir.tar.gz wget ftp://ftp.openssl.org/source/$tgzFile || exit 1 InstallOpenSSL=1 if [ -e /usr/local/lib/libcrypto.a -a -e /usr/local/lib/libssl.a ]; then OpenSSLVersion=$(strings /usr/local/lib/libssl.a | egrep "^OpenSSL " | awk '{ print $2 }') DownloadedTGZName=$(ls $tgzFile) if [ "$DownloadedTGZName" == "openssl-$OpenSSLVersion.tar.gz" ]; then # Already latest InstallOpenSSL=0 fi fi if [ "$InstallOpenSSL" == "1" ]; then tar xfz $tgzFile cd $opensslDir # Need to make the lib*.a linkable to .so as well (for AI SDK lib*.so) by adding -fPIC. export CC="gcc -fPIC" ./config --prefix=/usr/local --openssldir=/usr/lib/ssl zlib make CheckCmdError "openssl make" sudo make install_sw CheckCmdError "openssl make install_sw" cd .. fi } echo Start build at `date`. BuildType=${BuildType} CC=${CCompiler} ... BuildOpenSsl BuildWithCMake # Remaining steps should be run only on a non-static build except ParseGlibcVer on bin build. ParseGlibcVer ./${BUILDDIR}/release/bin ParseGlibcVer ./${BUILDDIR}/release/lib echo echo Finished all builds at `date`. error = ${TotalErrors} exit ${TotalErrors} ================================================ FILE: Diagnostic/mdsd/lad-mdsd/Makefile.in.version ================================================ VERSION_NUM=1.6.100 ================================================ FILE: Diagnostic/mdsd/lad-mdsd/README.txt ================================================ This directory contains files to create the Debian package and the RPM package for the mdsd static binary executable that'll be bundled in LAD 3.0. 
LAD 3.0 depends on omsagent, scx, omi packages (that are installed through the omsagent shell bundle), and we shouldn't let these packages be removed when the OMS Agent for Linux extension is uninstalled (the OMS Agent extension also uses the omsagent shell bundle). The Debian/RPM packages include just the mdsd binary at /usr/local/lad/bin, and specify the dependencies. To run the Makefile on Ubuntu, the rpm package must be installed first: $ sudo apt-get install rpm Then simply run 'make' at this directory, and collect the **/lad-mdsd-*.deb and the **/lad-mdsd-*.rpm files. NOTE: Version number conventions are different on dpkg and rpm, so that's why now VERSION_NUM is separately defined in Makefile.in.version, and actual version strings are composed for different deb/rpm packaging directories. ================================================ FILE: Diagnostic/mdsd/lad-mdsd/changelog ================================================ PACKAGE (1.4.101) stable; urgency=low * Bug fix: Emit schema md5 hashes at the end of Event Hub Notification event bodies. -- Azure Linux Wed Jun 21 16:30:00 UTC 2017 PACKAGE (1.4.100) stable; urgency=low * Release mdsd binary with libraries static-linked as much as possible. Gcc-built azure-mdsd deb pkg has more libraries statically linked than clang-built azure-mdsd-clang deb pkg. * Fixed mdsd http proxy bug. -- Azure Linux Thur Jun 15 16:30:00 UTC 2017 PACKAGE (1.3.101) stable; urgency=low * Mdsd daemon pidfile is changed from /var/run/mdsd.pid to .pid, default is /var/run/mdsd/default.pid. -- Azure Linux Thu Apr 27 17:40:00 UTC 2017 PACKAGE (1.3.100) stable; urgency=low * New feature: support new store type CentralJson. Data are uploaded to Azure storage as JSON blob. * New feature: support EventHub publishing with embedded SAS keys. Data are uploaded to Azure EventHub service directly. * New feature: environment variables MDSD_CONFIG_DIR, MDSD_RUN_DIR, and MDSD_LOG_DIR. 
* Bug fix: print clear error when mdsd pidport file was already locked. * Bug fix: suppress transient rsyslog-mdsd OM connect() error log. * Bug fix: fix mdsd SysV script reload bug. -- Azure Linux Mon Apr 10 22:20:00 UTC 2017 PACKAGE (1.2.109) stable; urgency=low * Bug fix: parse double number properly for dynamic json data. * Bug fix: handle metadata conflicts for dynamic schema data. * Bug fix: add UNIX socket filepath length validation. * Bug fix: add EventHub max data size validation. -- Azure Linux Thur Feb 16 22:20:00 UTC 2017 PACKAGE (1.2.108) stable; urgency=low * Bug fix: retry OMI task up to 30-minute if start-up fails. * Bug fix: fix EventHub reliability issue when EventHub blob is not found. -- Azure Linux Wed Jan 11 23:20:00 UTC 2017 PACKAGE (1.2.107) stable; urgency=low * Don't load EventHub SAS keys when mdsd.xml doesn't have related storetype. * Bug fix: fix ETW event SchemaID issue. -- Azure Linux Wed Dec 6 00:40:00 UTC 2016 PACKAGE (1.2.106) stable; urgency=low * Bug fix: add EventHub blob download failure retry. * When SAS key loads fails, retry in 1-minute instead of 6-hours. * Refactor Centralbond sink and request code. * Modified bond and djson protocols so that mdsd sets PreciseTimeStamp value. * Account SAS support for LAD shared storage key. -- Azure Linux Fri Nov 11 23:40:00 UTC 2016 PACKAGE (1.2.105) stable; urgency=low * Bug fix: remove /tmp dependency for autokey downloading and parsing. * Bug fix: rm ucf in mdsd debian pkg to avoid unwanted prompt. * Improvement: FileSink open()s file lazilly, close on flush(). * Test improvement: add ingest stress tests; rm unwanted credentials. -- Azure Linux Thu Oct 20 23:40:00 UTC 2016 PACKAGE (1.2.104) stable; urgency=low * Bug fix for LAD2AI config validation crash. * Bug fix for "too many open files" error during EventHub file parsing. * Bug fix for a memory corruption in StreamListener() when ProcessLoop() has error. 
* Performance improvement: defer destroying entries until after unlocking LocalSink. * Restart mdsd per N-hour in cron job. * Change openssl to latest release in mdsd-static. -- Azure Linux Tue Oct 4 23:40:00 UTC 2016 PACKAGE (1.2.103) stable; urgency=low * Enable finer-grained tracing of ingest and bond. * Add sanity checks to old-style JSON ingest. * rsyslog module: enforce a maximum event size of 1MB, bigger ones will be dropped. * Bug fix: EventHub async fire-and-forget model may use object out of lifetime. * Bug fix: EventHub async task is waiting for wrong task. * Buf fix: EventHub SAS key should be set by AutoKey reload timer. * Bug fix in rsyslog module: handle partial send(); handle concurrent send(). * Bug fix in rsyslog module: handle ack msg from mdsd in separate thread to avoid livelock. * Bug fix in rsyslog module: throttle event resend to a peak of 20MBps. * Bug fix: Skip sending an ACK for an ingested event if sending it would block. -- Azure Linux Fri Sep 23 17:00:00 UTC 2016 PACKAGE (1.2.102) stable; urgency=low * Performance improvement at uploading EventHub message. * Bug fix: restore identity columns after LADQuery stage. -- Azure Linux Tue Aug 30 17:00:00 UTC 2016 PACKAGE (1.2.101) stable; urgency=low * Add unix socket support for rsyslog module -- Azure Linux Tue Aug 2 19:00:00 UTC 2016 PACKAGE (1.2.100) stable; urgency=low * Increase version number. -- Azure Linux Tue Aug 2 12:00:00 UTC 2016 PACKAGE (1.1.106) stable; urgency=low * Added support for input over unix domain sockets. 
* Added support for two new input encoding/protocols (bond & json) that allow dynamic schema definition -- Azure Linux Fri Jul 14 12:00:00 UTC 2016 PACKAGE (1.1.105) stable; urgency=low * Add identity columns on CentralBond as well -- Azure Linux Fri Jul 8 23:50:00 UTC 2016 PACKAGE (1.1.104) stable; urgency=low * Revert cJSON library source code due to regression (missing messages due to JSON parsing errors) -- Azure Linux Tue Jun 29 23:50:00 UTC 2016 PACKAGE (1.1.103) stable; urgency=low * Make CentralBond type to send schemas to SchemasTable. -- Azure Linux Tue Jun 28 23:50:00 UTC 2016 PACKAGE (1.1.102) stable; urgency=low * Updated cJSON source code to fix a memory corruption bug * Aborts main loop when accept() fails, to avoid spin loop. * Avoid cascaded SIGABRT handler calls that could cause deadlock on malloc/free * LocalSink lock scope improvement -- Azure Linux Mon Jun 20 23:50:00 UTC 2016 PACKAGE (1.1.101) stable; urgency=low * Add signal handler for SIGPIPE. * Fix error handling when mdsd echos back to event sender and fails. 
-- Azure Linux Wed Jun 15 23:50:00 UTC 2016 PACKAGE (1.1.100) stable; urgency=low * Supports remote update of agent XML primary config file from Geneva based on namespace/tenant/role/roleinstance * Supports log rotation via SIGUSR2 * Supports mapped storage monikers * Enables use of a random JSON-listener port if the requested port is unavailable * Reports actual listening port via a “pid and port” file -- Azure Linux Mon Jun 14 20:00:00 UTC 2016 PACKAGE (1.0.100) unstable; urgency=low * CentralBond support * EventHub support for some Geneva pipeline services (dgrep, kusto, cosmos/coldpath) * Statically link required libraries when possible -- Azure Linux Mon May 16 19:52:43 UTC 2016 PACKAGE (0.9.5) unstable; urgency=low * Proxy support * CPPREST 2.8, Storage C++ SDK 2.3 upgrades (for proxy support) -- Azure Linux Mon Mar 7 11:26:35 PST 2016 PACKAGE (0.9.4) unstable; urgency=low * Fix JSON parsing error when the last character in the buffer is backslash * Improve reporting of XML parse errors and warnings * Enable stack trace on crash earlier in startup * AppInsights: Add metadata for metrics and traces -- Azure Linux Fri Mar 4 19:24:45 UTC 2016 PACKAGE (0.9.3) unstable; urgency=low * Integration with hotfixed OMI/SCX -- Azure Linux Fri Jan 29 23:14:28 UTC 2016 PACKAGE (0.9.2) unstable; urgency=low * Add -C option to enable dropping a core file on fatal signal * Improve logging by adding timestamp to all logs. -- Azure Linux Thu Jan 14 18:00:00 UTC 2016 PACKAGE (0.9.1) unstable; urgency=low * Write MDS metadata table entry correctly for MDS tables with long names (see 0.8.1) * Lookup typeconverters by string instead of ustring * Show known type converters when an unsupported type conversion is requested * Improved error message at JSON event parsing against schema from config file. 
-- Azure Linux Wed Dec 02 23:00:00 UTC 2015 PACKAGE (0.9.0) unstable; urgency=low * Add support for AISDK library with graceful fail if not present -- Azure Linux Mon Nov 23 12:04:40 UTC 2015 PACKAGE (0.8.3) unstable; urgency=low * Fix the inverse-timestamp in shoebox rowkeys -- Azure Linux Wed Oct 21 21:04:40 UTC 2015 PACKAGE (0.8.2) unstable; urgency=low * Resolve a compatibility conflict with omazuremds.so -- Azure Linux Wed Sep 17 00:44:00 UTC 2015 PACKAGE (0.8.1) unstable; urgency=low * When ing a config file, ignore the attributes of elements contained therein. * Create correct SchemasTable entries for tables whose full names (with prefix and all suffixes) exceed 63 characters in length. * Eliminate some compiler warnings. * Funnel all use of write() to a single WriteWithNewline() function that checks return status. -- Azure Linux Wed Sep 9 23:19:00 UTC 2015 PACKAGE (0.8.0) unstable; urgency=low * Add support for full "shoebox" rowkey schema via instanceID attribute on the element. * Minor corrections to config-file parse error messages. -- Azure Linux Fri Aug 28 00:37:45 UTC 2015 PACKAGE (0.7.10) unstable; urgency=low * Replace sprintf with snprintf * Remove execvp from daemon execution; daemon is no longer started via a second command line invocation. * Add secure compilation options (PIC/PIE, stack protection, immediate binding) to mdsd and autokey. -- Azure Linux Tue Aug 18 11:10:01 UTC 2015 PACKAGE (0.7.9) unstable; urgency=low * Enforce actual XTable limits (column size, total row size). * Fixed a rare case that could result in events being uploaded twice. * Expunge expired tags on an open ingest connection even if no new events are arriving on that connection. * Add dupeWindowSeconds to element; specifies the time window during which duplicate events must be detected. Min 60; max 3600. * Fixed a rare problem in which receipt of a partial JSON event from a sender corrupts the reassembly buffer during buffer expansion. 
-- Azure Linux Fri Jun 26 23:18:01 UTC 2015 PACKAGE (0.7.8) unstable; urgency=low * Close event-ingest connection if sync is lost while trying to find JSON. Event senders are required to detect the closed connection and resend any event that was not acknowledged (i.e. for which it did not see the TAG echoed back on the connection). -- Azure Linux Fri Jun 19 23:18:01 UTC 2015 PACKAGE (0.7.7) unstable; urgency=low * Fix a regression in the generation of the MDS Table Search schema. -- Azure Linux Wed Jun 10 23:52:26 UTC 2015 PACKAGE (0.7.6) unstable; urgency=low * Flush unneeded event data from local tables held in memory. * Fix a crash (SIGSEGV) on process exit (seen only when using the -v option). -- Azure Linux Fri May 18 01:55:45 UTC 2015 PACKAGE (0.7.5) unstable; urgency=low * Add scaleUp and scaleDown attributes to to scale specific values retrieved from OMI and unpivoted. transforms more than just the column name and is thus somewhat misnamed, but changing it is a breaking schema change. -- Azure Linux Fri May 08 01:55:45 UTC 2015 PACKAGE (0.7.4) unstable; urgency=low * Build against libazurestorage 1.0.0. Stop suppressing certain warnings during build. -- Azure Linux Fri May 01 01:55:45 UTC 2015 PACKAGE (0.7.3) unstable; urgency=low * Don't emit an error message when creating a missing table. -- Azure Linux Sat Apr 25 01:55:45 UTC 2015 PACKAGE (0.7.2) unstable; urgency=low * Force mt_int32 values to remain 32 bits and not scaled up by the storage API. * Store TIMESTAMP as a true DateTime (as implemented in PPLX utility::datetime and as expected by the storage SDK). * RowKey for the LAD Query has metric name and timestamp separated by only two underscores, not the three that MDS uses when combining strings. Also, the hex expansions of non-alphanumerics are expected to use all uppercase hex digits. -- Azure Linux Sat Apr 25 01:55:45 UTC 2015 PACKAGE (0.7.0) unstable; urgency=low * Store events to unpersisted local tables (storeType="local"). 
* element within will unpivot specified columns into separate rows. * within will rename specific unpivoted datum names (e.g. change "AvailableMemory" to "MEMORY\Available"). * and enable querying of data from a local table to produce a set of aggregates. The query is fixed to meet the needs of Linux Azure Diagnostics and includes customized partition and row keys. * Create missing tables if full storage account credentials are suppplied. * Enable mocking of MDS through storing events in a disk file (storeType="file"). -- Azure Linux Wed Apr 22 01:55:45 UTC 2015 ================================================ FILE: Diagnostic/mdsd/lad-mdsd/copyright ================================================ Format: http://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ Source: https://msazure.visualstudio.com/One/_git/Compute-Runtime-Tux/ Files: * Copyright: 2015 Microsoft Corporation License: Microsoft Internal Use ONLY ================================================ FILE: Diagnostic/mdsd/lad-mdsd/deb/Makefile ================================================ include ../Makefile.in.version VERSION=${VERSION_NUM} PACKAGE=lad-mdsd LABEL?=~dev ARCH?=amd64 VER=$(VERSION)-$(LABEL) FAKEROOT=./data-root DOCDIR=$(FAKEROOT)/usr/share/doc/$(PACKAGE) SHAREDIR=$(FAKEROOT)/usr/share/$(PACKAGE) MDSD_BIN_DIR=$(FAKEROOT)/usr/local/lad/bin MDSD_BUILT_BIN=../../builddir/release/bin/mdsd DEB=$(PACKAGE)-$(VER).$(ARCH).deb package: $(DEB) signed-package: _gpgorigin $(DEB) ar r $(DEB) $< _gpgorigin: $(DEB) -rm -f $@ ar p $(DEB) debian-binary control.tar.gz data.tar.gz | gpg -abs -o _gpgorigin $(DEB): tarballs debian-binary -rm -f $@ ar rc $@ debian-binary control.tar.gz data.tar.gz $(DOCDIR): mkdir -p $@ $(DOCDIR)/changelog.Debian.gz: ../changelog $(DOCDIR) cat $< | gzip -9 > $@ $(DOCDIR)/copyright: ../copyright $(DOCDIR) cp $< $@ debian-binary: echo 2.0 > debian-binary tarballs: data.tar.gz control.tar.gz control.tar.gz: md5sums control -rm -rf control-root -mkdir -p control-root cp 
control md5sums control-root chmod 644 control-root/* sed -i '/^Version:/c Version: $(VER)' control-root/control sed -i '/^Package:/c Package: $(PACKAGE)' control-root/control sed -i '/^Architecture:/c Architecture: $(ARCH)' control-root/control cd control-root && tar -czf ../$@ --owner=root --group=root . md5sums: install-deps (cd $(FAKEROOT) && md5sum `find -type f`) > $@ chmod 0644 $@ data.tar.gz: install-deps \ $(DOCDIR)/changelog.Debian.gz \ $(DOCDIR)/copyright \ $(LINTIANOVERRIDES) find $(FAKEROOT) -type d | xargs chmod 0755 find $(FAKEROOT) -type d | xargs chmod ug-s find $(FAKEROOT)/usr/share/doc -type f | xargs chmod 0644 cd $(FAKEROOT) && tar -czf ../$@ --owner=root --group=root --mode=go-w * .PHONY: clean install-clean install-deps clean: install-clean -rm -rf control-root -rm -f debian-binary *.tar.gz _gpgorigin md5sums -rm -f $(PACKAGE)*.deb install-clean: -rm -rf $(FAKEROOT) install-deps: install-clean mkdir -p $(MDSD_BIN_DIR) install -m 755 $(MDSD_BUILT_BIN) $(MDSD_BIN_DIR)/mdsd ================================================ FILE: Diagnostic/mdsd/lad-mdsd/deb/control ================================================ Package: PACKAGE Version: VERSION Section: admin Priority: optional Architecture: ARCH Depends: libc6, scx (>=1.6.2.169), omi, omsagent Maintainer: Azure Linux Team Description: MDS monitoring agent daemon for Linux Azure Diagnostic extension MDS monitoring daemon for Linux Azure Diagnostic extension ================================================ FILE: Diagnostic/mdsd/lad-mdsd/rpm/Makefile ================================================ include ../Makefile.in.version VERSION=${VERSION_NUM} PACKAGE=lad-mdsd LABEL?=dev DATAROOT=./data-root FAKEROOT=$(DATAROOT)/$(PACKAGE)-$(VERSION) DOCDIR=$(FAKEROOT)/usr/share/doc/$(PACKAGE) SHAREDIR=$(FAKEROOT)/usr/share/$(PACKAGE) MDSD_BIN_DIR=$(FAKEROOT)/usr/local/lad/bin MDSD_BUILT_BIN=../../builddir/release/bin/mdsd RPM=RPMS/x86_64/$(PACKAGE)-$(VERSION)-$(LABEL).x86_64.rpm 
TARBALL=$(PACKAGE)-$(VERSION).tgz RPM: $(TARBALL) rpmbuild -v -bb --clean --define "_topdir $(realpath .)" SPECS/lad-mdsd.spec $(TARBALL): rpm_prepare install-deps $(DOCDIR)/ChangeLog find $(FAKEROOT) -type d | xargs chmod 0755 find $(FAKEROOT) -type d | xargs chmod ug-s cd $(DATAROOT) && tar -czf ../SOURCES/$@ * $(DOCDIR): mkdir -p $@ $(DOCDIR)/ChangeLog: ../changelog $(DOCDIR) cp $< $@ rpm_prepare: clean mkdir -p SOURCES SPECS BUILD BUILDROOT RPMS SRPMS cp lad-mdsd.spec SPECS sed -i '/^Name:/c Name: $(PACKAGE)' SPECS/lad-mdsd.spec sed -i '/^Version:/c Version: $(VERSION)' SPECS/lad-mdsd.spec sed -i '/^Release:/c Release: $(LABEL)' SPECS/lad-mdsd.spec .PHONY: clean install-clean install-deps clean: install-clean -rm -rf SOURCES SPECS BUILD BUILDROOT RPMS SRPMS install-clean: -rm -rf $(DATAROOT) install-deps: install-clean mkdir -p $(MDSD_BIN_DIR) install -m 755 $(MDSD_BUILT_BIN) $(MDSD_BIN_DIR)/mdsd ================================================ FILE: Diagnostic/mdsd/mdscommands/BinaryWriter.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __BINARYWRITER__HH__ #define __BINARYWRITER__HH__ #include #include #include #include #include "MdsException.hh" namespace mdsd { namespace details { typedef uint8_t byte; // A helper class which allows write data in binary format to the bytes buffer. 
class BinaryWriter { template class BinaryWriterFunctions { public: static void Write(BinaryWriter& writer, T value); static void Write(BinaryWriter& writer, size_t position, T value); }; template class BinaryWriterFunctions { public: static void Write(BinaryWriter& writer, T value) { writer.Write(reinterpret_cast(&value), sizeof(T)); } static void Write(BinaryWriter& writer, size_t position, T value) { writer.Write(position, reinterpret_cast(&value), sizeof(T)); } }; template class BinaryWriterFunctions { static void Write(BinaryWriter& writer, T value); static void Write(BinaryWriter& writer, size_t position, T value); }; public: // Initializes a BinaryWriter object specifying the buffer to be used. BinaryWriter(std::vector& buffer) : m_buffer(buffer) {} // Gets the current size of the buffer. std::size_t GetBufferSize() const { return m_buffer.size(); } // Writes binary data to the specified position of the buffer, extending it if required. void Write(size_t position, const byte* source, size_t sourceSize) { if (!source) { throw MDSEXCEPTION("Unexpected NULL for source pointer."); } if (position + sourceSize > m_buffer.size()) { m_buffer.resize(position + sourceSize); } memcpy(m_buffer.data() + position, source, sourceSize); } // Writes binary data to the end of the buffer, extending it. void Write(const byte* source, size_t sourceSize) { if (!source) { throw MDSEXCEPTION("Unexpected NULL for source pointer."); } Write(m_buffer.size(), source, sourceSize); } // Writes value of the primitive type to the end of the buffer in binary format. template void Write(T value) { BinaryWriterFunctions::value>::Write(*this, value); } // Writes value of the primitive type to the specified position of the buffer in binary format. template void Write(size_t position, T value) { BinaryWriterFunctions::value>::Write(*this, position, value); } // Writes string value to the end of the buffer. 
void Write(const std::string & value) { Write(reinterpret_cast(value.c_str()), value.size()); } // Writes an integer value to the end of the buffer in base-128 format. void WriteInt32AsBase128(int value) { WriteInt64AsBase128(value); } // Writes an int64 value to the end of the buffer in base-128 format. void WriteInt64AsBase128(int64_t value) { bool negative = value < 0; long t = static_cast(negative ? -value : value); bool first = true; do { byte b; if (first) { b = (byte)(t & 0x3f); t >>= 6; if (negative) { b = (byte)(b | 0x40); } first = false; } else { b = (byte)(t & 0x7f); t >>= 7; } if (t > 0) { b |= 0x80; } Write(&b, sizeof(b)); } while (t > 0); } // Writes an unsigned integer value to the end of the buffer in base-128 format. void WriteUInt32AsBase128(unsigned int value) { WriteUInt64AsBase128(value); } // Writes an unsigned long value to the end of the buffer in base-128 format. void WriteUInt64AsBase128(uint64_t value) { uint64_t t = value; do { byte b = (byte)(t & 0x7f); t >>= 7; if (t > 0) { b |= 0x80; } Write(&b, sizeof(b)); } while (t > 0); } // Clears the buffer. void Reset() { m_buffer.clear(); } private: std::vector& m_buffer; }; } // namespace details } // namespace mdsd #endif // __BINARYWRITER__HH__ ================================================ FILE: Diagnostic/mdsd/mdscommands/BodyOnlyXmlParser.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include #include #include #include #include #include "BodyOnlyXmlParser.hh" #include "MdsException.hh" using namespace mdsd::details; void BodyOnlyXmlParser::ParseFile(std::string xmlFilePath) { m_xmlFilePath = std::move(xmlFilePath); std::ifstream infile{m_xmlFilePath}; if (!infile) { std::ostringstream strm; strm << "Failed to open file '" << m_xmlFilePath << "'."; throw MDSEXCEPTION(strm.str()); } std::string line; while(std::getline(infile, line)) { ParseChunk(line); } if (!infile.eof()) { std::ostringstream strm; strm << "Failed to parse file '" << m_xmlFilePath << "': "; if (infile.bad()) { strm << "Corrupted stream."; } else if (infile.fail()) { strm << "IO operation failed."; } else { strm << "std::getline() returned 0 for unknown reason."; } throw MDSEXCEPTION(strm.str()); } } void BodyOnlyXmlParser::OnCharacters(const std::string& chars) { bool isEmptyOrWhiteSpace = std::all_of(chars.cbegin(), chars.cend(), ::isspace); if (!isEmptyOrWhiteSpace) { m_body.append(chars); } } ================================================ FILE: Diagnostic/mdsd/mdscommands/BodyOnlyXmlParser.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __BODYONLYXMLPARSER__HH__ #define __BODYONLYXMLPARSER__HH__ #include #include #include "SaxParserBase.hh" namespace mdsd { namespace details { /// /// This is a simple XML parser. It will parse the XML body section only. /// The XML attributes are not parsed. 
/// class BodyOnlyXmlParser : public SaxParserBase { public: BodyOnlyXmlParser() = default; ~BodyOnlyXmlParser() = default; /// Parse given xml file virtual void ParseFile(std::string xmlFilePath); std::string&& MoveBody() { return std::move(m_body); } virtual std::string GetFilePath() const { return m_xmlFilePath; } private: void OnStartElement(const std::string& name, const AttributeMap& attributes) override { m_body.clear(); } void OnEndElement(const std::string& name) override {} void OnCharacters(const std::string& chars) override; void OnCDataBlock(const std::string& text) override { m_body.append(text); } private: std::string m_xmlFilePath; std::string m_body; }; } // namespace details } // namespace mdsd #endif // __BODYONLYXMLPARSER__HH__ ================================================ FILE: Diagnostic/mdsd/mdscommands/CMakeLists.txt ================================================ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") include_directories( ${CASABLANCA_INCLUDE_DIRS} ${STORAGE_INCLUDE_DIRS} /usr/include/libxml2 ${CMAKE_SOURCE_DIR}/mdsd ${CMAKE_SOURCE_DIR}/mdsdlog ${CMAKE_SOURCE_DIR}/mdsdutil ) set(SOURCES BodyOnlyXmlParser.cc CmdListXmlParser.cc CmdXmlCommon.cc CmdXmlElement.cc CmdXmlParser.cc ConfigUpdateCmd.cc DirectoryIter.cc EventData.cc EventEntry.cc EventHubCmd.cc EventHubPublisher.cc EventHubType.cc EventHubUploader.cc EventHubUploaderId.cc EventHubUploaderMgr.cc EventPersistMgr.cc MdsBlobReader.cc MdsException.cc PersistFiles.cc PublisherStatus.cc ${CMAKE_SOURCE_DIR}/mdsd/SaxParserBase.cc ) # Disable warning from CPPREST set_source_files_properties(PersistFiles.cc PROPERTIES COMPILE_FLAGS -Wno-sign-compare) # Disable warnings from azure storage API. 
set_source_files_properties( MdsBlobReader.cc EventHubCmd.cc PROPERTIES COMPILE_FLAGS "-Wno-unused-value -Wno-reorder" ) add_library(${CMD_LIB_NAME} STATIC ${SOURCES}) install(TARGETS ${CMD_LIB_NAME} ARCHIVE DESTINATION ${CMAKE_BINARY_DIR}/release/lib ) ================================================ FILE: Diagnostic/mdsd/mdscommands/CmdListXmlParser.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include #include #include #include "CmdListXmlParser.hh" #include "MdsException.hh" #include "CmdXmlElement.hh" using namespace mdsd::details; void CmdListXmlParser::OnEndElement(const std::string& name) { switch(Name2ElementType(name)) { case ElementType::Verb: m_verb = MoveBody(); break; case ElementType::Parameter: m_paramList.emplace_back(MoveBody()); break; case ElementType::Command: if (std::all_of(m_verb.cbegin(), m_verb.cend(), ::isspace)) { std::ostringstream strm; strm << "Invalid data in XML file '" << GetFilePath() << "': 'Verb' cannot be empty or whitespace."; throw MDSEXCEPTION(strm.str()); } if (0 == m_paramList.size()) { std::ostringstream strm; strm << "Invalid data in XML file '" << GetFilePath() << "': no Parameter value is found."; throw MDSEXCEPTION(strm.str()); } m_cmdParamMap[m_verb].emplace_back(m_paramList); m_paramList.clear(); break; default: break; } } ================================================ FILE: Diagnostic/mdsd/mdscommands/CmdListXmlParser.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __CMDLISTXMLPARSER__HH__ #define __CMDLISTXMLPARSER__HH__ #include #include #include "BodyOnlyXmlParser.hh" namespace mdsd { namespace details { /// /// Commands XML parser. It will parse .... /// For reference, check commands.xsd. /// class CmdListXmlParser : public BodyOnlyXmlParser { public: /// map key: Verb name. 
map value: list of parameter-list. using CmdParamsType = std::unordered_map>>; CmdListXmlParser() = default; ~CmdListXmlParser() = default; CmdParamsType GetCmdParams() const { return m_cmdParamMap; } private: void OnEndElement(const std::string& name) override; private: CmdParamsType m_cmdParamMap; // store all verb names and all parameters. std::string m_verb; // store current verb name in the parser. std::vector m_paramList; // store current parameter list in the parser. }; } // namespace details } // namespace mdsd #endif // __CMDLISTXMLPARSER__HH__ ================================================ FILE: Diagnostic/mdsd/mdscommands/CmdXmlCommon.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include #include "CmdXmlCommon.hh" #include "MdsException.hh" namespace mdsd { std::string CmdXmlCommon::s_rootContainerName = "mam"; namespace details { void ValidateCmdBlobParamsList( const std::vector>& paramsList, const std::string & verbName, size_t totalParams ) { if (0 == paramsList.size()) { std::ostringstream strm; strm << "No Command Parameter is found for Verb '" << verbName << "'."; throw MDSEXCEPTION(strm.str()); } for (const auto & v : paramsList) { if (totalParams != v.size()) { std::ostringstream strm; strm << "Invalid number of Command (verb=" << verbName << ") parameters: expected=" << totalParams << "; actual=" << v.size() << "."; throw MDSEXCEPTION(strm.str()); } } } } // namespace details } // namespace mdsd ================================================ FILE: Diagnostic/mdsd/mdscommands/CmdXmlCommon.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #ifndef __CMDXMLCOMMON_HH__ #define __CMDXMLCOMMON_HH__ #include #include namespace mdsd { class CmdXmlCommon { public: static std::string GetRootContainerName() { return s_rootContainerName; } static void SetRootContainerName(std::string name) { s_rootContainerName = std::move(name); } private: static std::string s_rootContainerName; }; namespace details { void ValidateCmdBlobParamsList( const std::vector>& paramsList, const std::string & verbName, size_t totalParams ); } // namespace details } // namespace mdsd #endif // __CMDXMLCOMMON_HH__ ================================================ FILE: Diagnostic/mdsd/mdscommands/CmdXmlElement.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "CmdXmlElement.hh" #include using namespace mdsd::details; static std::unordered_map & GetCmdElementTypeMap() { static auto xmltable = new std::unordered_map( { { "Verb", ElementType::Verb }, { "Parameter", ElementType::Parameter }, { "Command", ElementType::Command } }); return *xmltable; } ElementType mdsd::details::Name2ElementType(const std::string& name) { auto xmltable = GetCmdElementTypeMap(); auto iter = xmltable.find(name); if (iter != xmltable.end()) { return iter->second; } return ElementType::Unknown; } ================================================ FILE: Diagnostic/mdsd/mdscommands/CmdXmlElement.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #ifndef __CMDXMLELEMENT__HH__ #define __CMDXMLELEMENT__HH__ #include namespace mdsd { namespace details { enum class ElementType { Unknown, Verb, Parameter, Command }; ElementType Name2ElementType(const std::string& name); } // namespace details } // namespace mdsd #endif // __CMDXMLELEMENT__HH__ ================================================ FILE: Diagnostic/mdsd/mdscommands/CmdXmlParser.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "CmdXmlParser.hh" #include "CmdXmlElement.hh" using namespace mdsd::details; void CmdXmlParser::OnEndElement(const std::string& name) { switch(Name2ElementType(name)) { case ElementType::Verb: m_verb = MoveBody(); break; case ElementType::Parameter: m_paramList.emplace_back(MoveBody()); break; default: break; } } ================================================ FILE: Diagnostic/mdsd/mdscommands/CmdXmlParser.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __CMDXMLPARSER__HH__ #define __CMDXMLPARSER__HH__ #include #include "BodyOnlyXmlParser.hh" namespace mdsd { namespace details { /// /// MDS Command XML parser. It will parse one ... /// For reference, check commands.xsd. /// class CmdXmlParser : public BodyOnlyXmlParser { public: CmdXmlParser() = default; ~CmdXmlParser() = default; std::string GetVerb() const { return m_verb; } std::vector GetParamList() const { return m_paramList; } private: void OnEndElement(const std::string& name) override; private: std::string m_verb; // The value of 'Verb' std::vector m_paramList; // a list of the parameters defined for the Verb. 
}; } // namespace details } // namespace mdsd #endif // __CMDXMLPARSER__HH__ ================================================ FILE: Diagnostic/mdsd/mdscommands/ConfigUpdateCmd.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include #include #include #include #include "ConfigUpdateCmd.hh" #include "MdsBlobReader.hh" #include "CmdListXmlParser.hh" #include "CmdXmlCommon.hh" #include "MdsException.hh" #include "Trace.hh" #include "Logger.hh" #include "Crypto.hh" using namespace mdsd; using namespace mdsd::details; uint64_t ConfigUpdateCmd::s_lastTimestamp = 0; Crypto::MD5Hash ConfigUpdateCmd::s_lastMd5Sum; std::string ConfigUpdateCmd::s_cmdFileName = "MACommandCu.xml"; ConfigUpdateCmd::ConfigUpdateCmd( const std::string& rootContainerSas, const std::string& eventNameSpace, const std::string& tenantName, const std::string& roleName, const std::string& instanceName) : m_rootContainerSas(rootContainerSas) , m_configXmlPersistentFlag(true) // Just to avoid IDE/compiler warning { Trace trace(Trace::MdsCmd, "ConfigUpdateCmd::ConfigUpdateCmd"); if (rootContainerSas.empty()) { throw MDSEXCEPTION("ConfigUpdate blob root container cannot be empty."); } if (eventNameSpace.empty()) { throw MDSEXCEPTION("ConfigUpdate MDS namespace cannot be empty."); } // Check the validity of tenantName, roleName & instanceName. // 1. if tenantName is empty, then both roleName & instanceName must be empty if (tenantName.empty() && !(roleName.empty() && instanceName.empty())) { throw MDSEXCEPTION("Non-empty role name or instance name when tenant name is empty."); } // 2. if roleName is empty, then instance name must be empty if (roleName.empty() && !instanceName.empty()) { throw MDSEXCEPTION("Non-empty instanceName given when roleName is empty."); } // Construct the list of all possible cmd xml paths in xstore. 
// E.g., "TuxTest/myTestTenant/role1/instance1/MACommandCu.xml", // "TuxTest/myTestTenant/role1/MACommandCu.xml", // "TuxTest/myTestTenant/MACommandCu.xml" and // "TuxTest/MACommandCu.xml" std::string upToNameSpace = eventNameSpace + "/"; std::string upToTenantName = upToNameSpace + tenantName + "/"; std::string upToRoleName = upToTenantName + roleName + "/"; m_cmdXmlPathsXstore.reserve(4); // Maximum 4 paths to try if (!instanceName.empty()) { m_cmdXmlPathsXstore.push_back(upToRoleName + instanceName + "/" + s_cmdFileName); } if (!roleName.empty()) { m_cmdXmlPathsXstore.push_back(upToRoleName + s_cmdFileName); } if (!tenantName.empty()) { m_cmdXmlPathsXstore.push_back(upToTenantName + s_cmdFileName); } // Namespace/MACommandCu.xml should be always added m_cmdXmlPathsXstore.push_back(upToNameSpace + s_cmdFileName); TRACEINFO(trace, "ConfigUpdateCmd::ConfigUpdateCmd(), namespace = \"" << eventNameSpace << "\", tenantName = \"" << tenantName << "\", roleName = \"" << roleName << "\", instanceName = \"" << instanceName << "\", resulting cmd xml path in xstore (longest one only) = \"" << m_cmdXmlPathsXstore.front() << '"'); } // Helper for parsing config update cmd xml static bool ParseConfigUpdateCmdXml( std::string&& xmlDoc, bool& configXmlPersistentFlag, Crypto::MD5Hash& configXmlMD5Sum, std::string& configXmlPathXstore) { Trace trace(Trace::MdsCmd, "ParseConfigUpdateCmdXml"); if (xmlDoc.empty()) { trace.NOTE("No ConfigUpdate cmd XML data to parse. 
Abort parser."); return false; } configXmlPersistentFlag = false; configXmlPathXstore.clear(); CmdListXmlParser parser; parser.Parse(xmlDoc); auto paramTable = parser.GetCmdParams(); if (0 == paramTable.size()) { throw MDSEXCEPTION("No Command Parameter is found in ConfigUpdate cmd XML."); } // UpdateConfig cmd xml example: // // // UpdateConfig // // TRUE // 65db3091d1b6ba83c7dba7a9a1a984ce // ConfigArchive/65db3091d1b6ba83c7dba7a9a1a984ce/TuxTestVer7v0.xml // // const std::string CfgUpdateCmdVerb = "UpdateConfig"; const auto NPARAMS = 3; const auto PersistentFlagIndex = 0; const auto ConfigXmlMD5SumIndex = 1; const auto ConfigXmlXstorePathIndex = 2; auto cfgUpdateParamsList = paramTable[CfgUpdateCmdVerb]; ValidateCmdBlobParamsList(cfgUpdateParamsList, CfgUpdateCmdVerb, NPARAMS); // Now extract the parameters // But check if there are more than one UpdateConfig commands in the cmd xml. // In that case, log a warning and use the last one. if (cfgUpdateParamsList.size() > 1) { std::ostringstream msg; msg << "More than one UpdateConfig commands given in the cmd XML" << " (there were " << cfgUpdateParamsList.size() << "). Only the last one will be used."; Logger::LogWarn(msg); } const auto& params = cfgUpdateParamsList.back(); configXmlPersistentFlag = params[PersistentFlagIndex] == "TRUE"; configXmlMD5Sum = Crypto::MD5Hash::from_hash(params[ConfigXmlMD5SumIndex]); configXmlPathXstore = std::move(params[ConfigXmlXstorePathIndex]); TRACEINFO(trace, "MDS config update cmd xml blob parsed. 
persist flag = " << configXmlPersistentFlag << ", config xml md5sum = " << configXmlMD5Sum.to_string() << ", config xml xstore path = " << configXmlPathXstore); return true; } pplx::task ConfigUpdateCmd::StartAsyncDownloadOfNewConfig() { Trace trace(Trace::MdsCmd, "ConfigUpdateCmd::StartAsyncDownloadOfNewConfig"); // Helper struct type to hold a cml blob path and its LMT struct LmtLookupDataT { const std::string* m_cmdXmlPath; uint64_t m_lmt; LmtLookupDataT(const std::string& cmdXmlPath, uint64_t lmt) : m_cmdXmlPath(&cmdXmlPath) , m_lmt(lmt) {} // Just for containers LmtLookupDataT() : m_cmdXmlPath(nullptr), m_lmt(0) {} bool operator<(const LmtLookupDataT& rhs) const { return m_lmt < rhs.m_lmt; } }; std::vector> lmtTasks; // Parallel LMT lookup tasks // Async/parallel LMT retrieval for (size_t i = 0; i < m_cmdXmlPathsXstore.size(); i++) { lmtTasks.push_back(pplx::task([=]() { MdsBlobReader blobReader(m_rootContainerSas, m_cmdXmlPathsXstore[i]); // Get the blob's LMT along with the blob's path (asynchronously) auto asyncLmtLookupTask = blobReader.GetLastModifiedTimeStampAsync( MdsBlobReader::DoNothingBlobNotFoundExHandler); // We don't want to log non-existing blob here, as that could be frequent and persistent return asyncLmtLookupTask.then([=](uint64_t lmt) { return LmtLookupDataT(m_cmdXmlPathsXstore[i], lmt); }); })); } // Specify what to do when all parallel tasks are completed return pplx::when_all(lmtTasks.begin(), lmtTasks.end()).then([=](std::vector lmtResults) -> pplx::task { Trace trace(Trace::MdsCmd, "ConfigUpdateCmd::StartAsyncDownloadOfNewConfig when_all().then() lambda"); // Find latest LMT path auto maxLmtResult = std::max_element(lmtResults.begin(), lmtResults.end()); auto latestLmt = maxLmtResult->m_lmt; auto latestLmtCmdXmlPath = *maxLmtResult->m_cmdXmlPath; TRACEINFO(trace, "Latest LMT from all candidate cmd blob paths (# paths: " << m_cmdXmlPathsXstore.size() << ", longest path: " << m_cmdXmlPathsXstore.front() << ", latest LMT path: " << 
latestLmtCmdXmlPath << ") = " << latestLmt << " (0 means no cmd blob found), " << ", s_lastTimestamp = " << s_lastTimestamp); return GetCmdXmlAsync(latestLmt, latestLmtCmdXmlPath); }).then([](bool result) { return result; }); } pplx::task ConfigUpdateCmd::GetCmdXmlAsync(uint64_t blobLmt, std::string cmdXmlPathXstore) { Trace trace(Trace::MdsCmd, "ConfigUpdateCmd::GetCmdXmlAsync"); pplx::task returnFalseTask([]() { return false; }); if (blobLmt == 0) // No cmd blob found. Nothing to do. { TRACEINFO(trace, "No cmd blob was passed (blobLmt = 0). Nothing to do."); return returnFalseTask; } if (blobLmt <= s_lastTimestamp) // No new cmd blob found. Nothing to do. { TRACEINFO(trace, "No new cmd blob was passed (passed blobLmt = " << blobLmt << ", s_lastTimestamp = " << s_lastTimestamp << '"'); return returnFalseTask; } // Get/check the cmd blob's content MdsBlobReader cmdXmlBlobReader(m_rootContainerSas, cmdXmlPathXstore); auto asyncCmdXmlReadTask = cmdXmlBlobReader.ReadBlobToStringAsync(); return asyncCmdXmlReadTask.then([blobLmt,this](std::string cmdXmlString) -> pplx::task { return ProcessCmdXmlAsync(blobLmt, std::move(cmdXmlString)); }); } pplx::task ConfigUpdateCmd::ProcessCmdXmlAsync(uint64_t blobLmt, std::string cmdXmlString) { Trace trace(Trace::MdsCmd, "ConfigUpdateCmd::ProcessCmdXmlAsync"); TRACEINFO(trace, "Cmd XML Blob content=\"" << cmdXmlString << '"'); pplx::task returnFalseTask([]() { return false; }); if (cmdXmlString.empty()) // Cmd blob content is empty. Nothing to do. { return returnFalseTask; } bool configXmlPersistentFlag = false; Crypto::MD5Hash configXmlMD5Sum; std::string configXmlPathXstore; std::string genevaIssueMsg = "[Geneva has generated an invalid configuration update command--See the description outside the bracket. 
Please report this via the 'Contact Us' button on the Geneva Monitoring portal] "; try { if (!ParseConfigUpdateCmdXml(std::move(cmdXmlString), configXmlPersistentFlag, configXmlMD5Sum, configXmlPathXstore)) { return returnFalseTask; } } catch (const MdsException& e) { std::ostringstream msg; msg << genevaIssueMsg << "ConfigUpdate cmd XML parse failed (no UpdateConfig verb or invalid XML format): " << e.what(); Logger::LogError(msg); return returnFalseTask; } // Validate the retrieved ConfigUpdate cmd params if (configXmlPathXstore.empty()) { Logger::LogError(genevaIssueMsg + "ConfigUpdate cmd's config xml xstore path param cannot be empty."); return returnFalseTask; } TRACEINFO(trace, "Cmd XML parsed successfully. ConfigXml xstore path = " << configXmlPathXstore << ", MD5 sum = " << configXmlMD5Sum.to_string() << ", persistent flag = " << configXmlPersistentFlag); // Check if the md5 is the same as the last downloaded one, and return if so. if (configXmlMD5Sum == s_lastMd5Sum) { TRACEINFO(trace, "MD5 sum given in the cmd XML" << " is equal to the last downloaded one. 
Skipping this one."); return returnFalseTask; } // Now, download config XML from Xstore (asynchronously) MdsBlobReader blobReader(m_rootContainerSas, configXmlPathXstore); auto cfgXmlAsyncReadTask = blobReader.ReadBlobToStringAsync(); return cfgXmlAsyncReadTask.then([=](std::string configXml) -> pplx::task { return GetCfgXmlAsync(std::move(configXml), configXmlMD5Sum, configXmlPathXstore, configXmlPersistentFlag, blobLmt); }); } pplx::task ConfigUpdateCmd::GetCfgXmlAsync( std::string && configXml, const Crypto::MD5Hash & configXmlMD5Sum, const std::string & configXmlPathXstore, bool configXmlPersistentFlag, uint64_t cmdBlobLmt) { Trace trace(Trace::MdsCmd, "ConfigUpdateCmd::GetCfgXmlAsync"); TRACEINFO(trace, "Downloaded mdsd cfg xml: \"" << configXml << '"'); pplx::task returnFalseTask([]() { return false; }); if (configXml.empty()) { Logger::LogError("Downloaded mdsd cfg xml is empty!"); return returnFalseTask; } // Check if md5 sum matches the passed md5sum param auto computedMD5Sum = Crypto::MD5HashString(configXml); if (configXmlMD5Sum != computedMD5Sum) { std::ostringstream msg; msg << "MD5 sum mismatch! Calculated = " << computedMD5Sum.to_string() << ", Given in cmd XML = " << configXmlMD5Sum.to_string(); Logger::LogError(msg); return returnFalseTask; } // Now update the relevant member variables m_configXmlPathXstore = configXmlPathXstore; m_configXmlString = std::move(configXml); m_configXmlMD5Sum = std::move(computedMD5Sum); m_configXmlPersistentFlag = configXmlPersistentFlag; s_lastMd5Sum = computedMD5Sum; s_lastTimestamp = cmdBlobLmt; return pplx::task([](){ return true; }); } ================================================ FILE: Diagnostic/mdsd/mdscommands/ConfigUpdateCmd.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #ifndef __CONFIGUPDATECMD_HH__ #define __CONFIGUPDATECMD_HH__ #include #include #include "Crypto.hh" namespace mdsd { /// /// This class implements functions to handle ConfigUpdate command xml files. /// This includes download xml file, parse xml file, and get data from xml. /// class ConfigUpdateCmd { public: /// /// Create the object that'll handle a ConfigUpdate command xml file. /// The sas key for the root container /// where the command xml file locates. /// Event namespace (e.g., TuxTest). Can't be empty. /// Tenant name. Optional /// Role name. Optional /// Instance name. Optional /// ConfigUpdateCmd( const std::string& rootContainerSas, const std::string& eventNameSpace, const std::string& tenantName, const std::string& roleName, const std::string& instanceName); ~ConfigUpdateCmd() {} ConfigUpdateCmd(const ConfigUpdateCmd & other) = default; ConfigUpdateCmd(ConfigUpdateCmd&& other) = default; ConfigUpdateCmd& operator=(const ConfigUpdateCmd& other) = default; ConfigUpdateCmd& operator=(ConfigUpdateCmd&& other) = default; /// /// Initiate an async download of a new config. Returns a task whose result /// is true iff a new config was successfully downloaded (and corresponding /// member variables are correctly updated). /// pplx::task StartAsyncDownloadOfNewConfig(); /// /// Get the config XML string downloaded from XStore /// std::string GetConfigXmlString() const { return m_configXmlString; } /// /// Get the config XML string's MD5 sum /// Crypto::MD5Hash GetConfigXmlMD5Sum() const { return m_configXmlMD5Sum; } /// /// Initialize with existing MD5Hash (e.g. from the mdsd command line config). /// static void Initialize(const Crypto::MD5Hash& md5) { s_lastMd5Sum = md5; } private: std::string m_rootContainerSas; std::string m_configXmlString; // Member variable where downloaded mdsd config xml will be stored std::vector m_cmdXmlPathsXstore; // List of all XStore paths to search for a cmd xml blob. 
// e.g., "TuxTest/myTestTenant/role1/instance1/MACommandCu.xml", // "TuxTest/myTestTenant/role1/MACommandCu.xml", // "TuxTest/myTestTenant/MACommandCu.xml" // Function to asynchronously start downloading a cmd xml blob given as the param. // The task then continues to the ProcessCmdXmlAsync task if a cmd xml is downloaded correctly. // Returns the continuation task whose completion will give us the result of cmd blob downloading/processing. pplx::task GetCmdXmlAsync(uint64_t blobLmt, std::string cmdXmlPathXstore); // Async cmd XML processing task // The task then continues to the GetCfgXmlAsync task if a cmd xml is parsed correctly. pplx::task ProcessCmdXmlAsync(uint64_t blobLmt, std::string cmdXmlString); // Async cfg XML downloading task pplx::task GetCfgXmlAsync( std::string && configXml, const Crypto::MD5Hash & configXmlMD5Sum, const std::string & configXmlPathXstore, bool configXmlPersistentFlag, uint64_t blobLmt); // Extracted UpdateConfig cmd params std::string m_configXmlPathXstore; // e.g., "ConfigArchive/65db3091d1b6ba83c7dba7a9a1a984ce/TuxTestVer7v0.xml" Crypto::MD5Hash m_configXmlMD5Sum; // e.g., "65db3091d1b6ba83c7dba7a9a1a984ce" bool m_configXmlPersistentFlag; // May not be needed at all for us, but just saving it anyway // Things to remember for update logic // Updated with timestamp of the last successful XML cfg blob to compare with the new XML cfg blob static uint64_t s_lastTimestamp; // Updated with MD5 hash of the last successful mdsd config blob's MD5 sum to compare with the new XML cfg blob static Crypto::MD5Hash s_lastMd5Sum; // Fixed constants static std::string s_cmdFileName; // Currently "MACommandCu.xml" }; } // namespace mdsd #endif // __CONFIGUPDATECMD_HH__ ================================================ FILE: Diagnostic/mdsd/mdscommands/DirectoryIter.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include #include #include extern "C" { #include #include } #include "DirectoryIter.hh" #include "MdsException.hh" #include "MdsCmdLogger.hh" using namespace mdsd::details; DirectoryIter::DirectoryIter(): m_dirp(nullptr), m_result(nullptr) { memset(&m_ent, 0, sizeof(m_ent)); } DirectoryIter::DirectoryIter( const std::string & dirname): m_dirname(dirname), m_dirp(nullptr), m_result(nullptr) { m_dirp = opendir(dirname.c_str()); if (!m_dirp) { std::error_code ec(errno, std::system_category()); std::ostringstream strm; strm << "Failed to open directory '" << dirname << "'; Reason: " << ec.message(); throw MDSEXCEPTION(strm.str()); } MoveToNextValid(); } DirectoryIter::~DirectoryIter() { if (m_dirp) { closedir(m_dirp); } } void DirectoryIter::MoveToNext() { if (!m_dirp) { return; } auto rtn = readdir_r(m_dirp, &m_ent, &m_result); if (rtn) { std::ostringstream strm; strm << "Error: in directory iteration, readdir_r() failed with error code=" << rtn; MdsCmdLogError(strm); } if (!m_result) { memset(&m_ent, 0, sizeof(m_ent)); closedir(m_dirp); m_dirp = nullptr; m_result = nullptr; } } void DirectoryIter::MoveToNextValid() { while(true) { MoveToNext(); if (!m_dirp) { break; } std::string curdir{m_ent.d_name}; if ("." != curdir && ".." != curdir) { break; } } } DirectoryIter& DirectoryIter::operator++() { MoveToNextValid(); return *this; } std::string DirectoryIter::operator*() const { if (m_ent.d_name[0]) { return m_dirname + "/" + m_ent.d_name; } else { return std::string(); } } bool mdsd::details::operator==( const DirectoryIter& x, const DirectoryIter& y ) { return (x.m_dirp == y.m_dirp && x.m_result == y.m_result && strncmp(x.m_ent.d_name, y.m_ent.d_name, sizeof(x.m_ent.d_name)) == 0); } ================================================ FILE: Diagnostic/mdsd/mdscommands/DirectoryIter.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #ifndef __DIRECTORYITER__HH__ #define __DIRECTORYITER__HH__ #include extern "C" { #include } namespace mdsd { namespace details { /// /// Iterator each entry in the directory, including sub-directories. /// It ignores "." and "..". /// class DirectoryIter { public: /// A directory iterator pointing to nothing DirectoryIter(); /// A directory iterator for given dir DirectoryIter(const std::string & dirname); ~DirectoryIter(); /// There is no safe way to copy 'DIR*'. Make class movable, not copyable. DirectoryIter(const DirectoryIter& other) = delete; DirectoryIter(DirectoryIter&& other) = default; DirectoryIter& operator=(const DirectoryIter& other) = delete; DirectoryIter& operator=(DirectoryIter&& other) = default; /// Pre-increment operator. Move to next entry in the directory. DirectoryIter& operator++(); /// Return current item name (filename or dir name) std::string operator*() const; /// Return whether 2 iter points to the same thing friend bool operator==(const DirectoryIter& x, const DirectoryIter& y); /// Return whether 2 iter points to different things friend bool operator!=(const DirectoryIter& x, const DirectoryIter& y) { return !(x==y); } private: void MoveToNext(); void MoveToNextValid(); private: std::string m_dirname; DIR* m_dirp; struct dirent m_ent; struct dirent * m_result; }; bool operator==(const DirectoryIter& x, const DirectoryIter& y); } // namespace details } // namespace mdsd #endif // __DIRECTORYITER__HH__ ================================================ FILE: Diagnostic/mdsd/mdscommands/EventData.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include #include "EventData.hh" #include "MdsException.hh" using namespace mdsd; static std::string GetStringFromOutput( const bond::OutputBuffer & output ) { std::vector blist; output.GetBuffers(blist); size_t totalLen = 0; for (const auto & b : blist) { totalLen += b.length(); } std::string resultStr; resultStr.reserve(totalLen); for (const auto & b : blist) { resultStr.append(b.content(), b.length()); } return resultStr; } std::string EventDataT::Serialize() const { if (m_data.empty()) { throw MDSEXCEPTION("EventData serialization failed: data cannot be empty."); } bond::OutputBuffer output; bond::SimpleBinaryWriter writer(output); writer.Write(m_data); writer.Write(static_cast(m_table.size())); for (const auto & it : m_table) { writer.Write(it.first); writer.Write(it.second); } return GetStringFromOutput(output); } EventDataT EventDataT::Deserialize( const std::string & datastr ) { return Deserialize(datastr.c_str(), datastr.size()); } EventDataT EventDataT::Deserialize( const char* buf, size_t bufSize ) { if (!buf) { throw MDSEXCEPTION("EventData deserialization failed: input buf cannot be NULL."); } EventDataT dataObj; try { bond::blob b; b.assign(buf, bufSize); bond::SimpleBinaryReader reader(b); std::string datastr; reader.Read(datastr); dataObj.SetData(std::move(datastr)); size_t tblSize = 0; reader.Read(tblSize); for (size_t i = 0; i < tblSize; i++) { std::string k, v; reader.Read(k); reader.Read(v); dataObj.AddProperty(std::move(k), std::move(v)); } } catch(std::exception& ex) { throw MDSEXCEPTION(std::string("EventData deserialization failed: ") + ex.what()); } return dataObj; } ================================================ FILE: Diagnostic/mdsd/mdscommands/EventData.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #ifndef __EVENTDATA_HH__ #define __EVENTDATA_HH__ #include #include namespace mdsd { /// The EventDataT has 2 parts: a key-value pair table of properties and /// actual data string. class EventDataT { public: using EventPropertyT = std::unordered_map; EventDataT() = default; ~EventDataT() = default; bool empty() const { return m_data.empty() && m_table.empty(); } std::string GetData() const { return m_data; } void SetData(const std::string & data) { m_data = data; } void SetData(std::string && data) { m_data = std::move(data); } // Specialization for all integral types template typename std::enable_if::value, void>::type AddProperty(std::string name, T value) { m_table[std::move(name)] = std::to_string(value); } void AddProperty(std::string name, std::string value) { m_table[std::move(name)] = std::move(value); } // /// Get properties object which is [key,value] table. /// const EventPropertyT & Properties() const { return m_table; } std::string Serialize() const; static EventDataT Deserialize(const std::string & datastr); /// /// Deserialize a char array and return EventData object. /// The memory of the char array must be valid in this function. /// static EventDataT Deserialize(const char* buf, size_t bufSize); /// /// The max size of EventHub data to support. /// static size_t GetMaxSize() { return 256*1024; } private: EventPropertyT m_table; // {key,value} property table std::string m_data; // actual message data }; } // namespace mdsd #endif // __EVENTDATA_HH__ ================================================ FILE: Diagnostic/mdsd/mdscommands/EventEntry.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "EventEntry.hh"

// Process-wide monotonically increasing counter used to hand out per-entry ids.
// NOTE(review): the atomic's template argument was stripped by extraction;
// <uint64_t> matches the header's s_counter/m_id usage -- confirm.
std::atomic<uint64_t> mdsd::details::EventEntry::s_counter{0};

================================================
FILE: Diagnostic/mdsd/mdscommands/EventEntry.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef __EVENTENTRY_HH__
#define __EVENTENTRY_HH__

// NOTE(review): include targets were stripped by extraction; restored.
#include <atomic>
#include <cstdint>
#include <ctime>

#include "EventData.hh"

namespace mdsd
{
namespace details
{

/// <summary>
/// EventEntry class include data sent to EventHub for each upload,
/// plus some metadata about the event data.
/// </summary>
class EventEntry
{
public:
    EventEntry(const EventDataT & data) : m_rawData(data)
    {
        s_counter++;
        m_id = s_counter;
    }

    EventEntry(EventDataT && data) : m_rawData(std::move(data))
    {
        s_counter++;
        m_id = s_counter;
    }

    ~EventEntry() {}

    EventEntry(const EventEntry& other) = default;
    EventEntry(EventEntry&& other) = default;
    EventEntry& operator=(const EventEntry& other) = default;
    EventEntry& operator=(EventEntry&& other) = default;

    /// Do exponential backoff for next retry
    void BackOff()
    {
        auto delta = m_nextSendTimet - m_firstSendTimet;
        if (0 == delta) {
            m_nextSendTimet++;
        }
        else {
            m_nextSendTimet = m_firstSendTimet + delta*2 + 1;
        }
    }

    bool IsNeverSent() const { return (0 == m_firstSendTimet); }

    void SetSendTime()
    {
        auto now = GetNow();
        m_firstSendTimet = now;
        m_nextSendTimet = now;
    }

    ///
    /// Get number of seconds since the data was first uploaded.
    /// Return -1 if the data is never uploaded before.
    ///
    int32_t GetAgeInSeconds() const
    {
        if (0 == m_firstSendTimet) {
            return -1;
        }
        return (GetNow() - m_firstSendTimet);
    }

    EventDataT GetData() const { return m_rawData; }

    /// Get some ID for the event, for tracing purpose only.
    /// no need to be unique.
    uint64_t GetId() const { return m_id; }

    /// Is it now the time to re-upload the data?
    bool IsTimeToRetry() const { return (GetNow() >= m_nextSendTimet); }

    bool IsInPersistence() const { return m_inPersistence; }
    void SetPersistence() { m_inPersistence = true; }

private:
    time_t GetNow() const { return time(nullptr); }

private:
    // The minimum time to upload when getting a next chance.
    // If the current time is less than this value, data won't be uploaded.
    time_t m_nextSendTimet = 0;
    time_t m_firstSendTimet = 0;  // The first time to upload the data.
    EventDataT m_rawData;         // The raw data uploaded to Event Hub.
    static std::atomic<uint64_t> s_counter;
    uint64_t m_id = 0;            // A ID for the entry. For tracing purpose only.
    bool m_inPersistence = false; // Is the item added to persistence manager?
};

} // namespace details
} // namespace mdsd

#endif // __EVENTENTRY_HH__

================================================
FILE: Diagnostic/mdsd/mdscommands/EventHubCmd.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

// NOTE(review): include target stripped by extraction; usleep() needs <unistd.h>.
#include <unistd.h>

#include "EventHubCmd.hh"
#include "MdsBlobReader.hh"
#include "CmdListXmlParser.hh"
#include "CmdXmlCommon.hh"
#include "MdsException.hh"
#include "Trace.hh"
#include "Logger.hh"

using namespace mdsd;
using namespace mdsd::details;

std::string EventHubCmd::s_parentContainerName = "mdssubscriptions";

std::ostream&
operator<<(std::ostream& str, const EhCmdXmlItems & cmd)
{
    // for security reason, only dump part of SAS key.
    str << "SAS key: " << cmd.sas.substr(0, 20)
        << "..., MDS Endpoint ID: " << cmd.endpoint
        << ", Mapped Moniker: " << cmd.moniker;
    return str;
}

EventHubCmd::EventHubCmd(
    std::string eventNameSpace,
    int eventVersion,
    std::string rootContainerSas
    ) :
    m_blobNameSuffix(std::move(eventNameSpace)),
    m_rootContainerSas(std::move(rootContainerSas)),
    m_noticeXmlItemsTable(new EhCmdXmlItemsTable_t()),
    m_pubXmlItemsTable(new EhCmdXmlItemsTable_t())
{
    if (m_blobNameSuffix.empty()) {
        throw MDSEXCEPTION("Event Hub MDS namespace cannot be empty.");
    }
    if (m_rootContainerSas.empty()) {
        throw MDSEXCEPTION("Event Hub blob root container cannot be empty.");
    }
    // Final suffix is e.g. "TuxTestVer7v0.xml".
    m_blobNameSuffix.append("Ver");
    m_blobNameSuffix.append(std::to_string(eventVersion));
    m_blobNameSuffix.append("v0.xml");
}

void
EventHubCmd::ProcessCmdXml()
{
    Trace trace(Trace::MdsCmd, "EventHubCmd::ProcessCmdXml");
    // The MACommandPub.xml contains both notice and publish EH event info.
    ProcessBlob(GetBlobName("MACommandPub"));
}

void
EventHubCmd::ProcessBlob(
    std::string&& blobName
    )
{
    Trace trace(Trace::MdsCmd, "EventHubCmd::ProcessBlob");

    MdsBlobReader blobReader(m_rootContainerSas, std::move(blobName), s_parentContainerName);
    std::string blobData;
    const int ntimes = 5;
    // Because typically EventHubCmd XML blob should be OK to read, if empty data is returned,
    // retry to avoid any possible storage API failures.
    for (int i = 0; i < ntimes; i++) {
        blobData = std::move(blobReader.ReadBlobToString());
        if (!blobData.empty() || (ntimes-1) == i) {
            break;
        }
        TRACEINFO(trace, "No EventHubCmd XML is found. Retry index=" << (i+1));
        // Exponential backoff: 0.1s, 0.2s, 0.4s, ...
        usleep(100*1000*(1<<i));
    }

    // NOTE(review): everything from here through the notice-verb loop in
    // ParseCmdXml() was destroyed in the extracted text (the '<' in the
    // backoff shift above swallowed it). The code below is reconstructed by
    // symmetry with the surviving publisher-verb section and MUST be
    // verified against source control.
    if (blobData.empty()) {
        Logger::LogInfo("No EventHubCmd XML blob found; nothing to parse.");
        return;
    }
    ParseCmdXml(std::move(blobData));
}

void
EventHubCmd::ParseCmdXml(
    std::string&& xmlDoc
    )
{
    Trace trace(Trace::MdsCmd, "EventHubCmd::ParseCmdXml");

    // NOTE(review): reconstructed -- parser entry point and accessor names,
    // and the notice-verb constants (NoticeVerb, NPARAMSNotice, *IndexNotice),
    // are presumed to come from CmdListXmlParser.hh / CmdXmlCommon.hh.
    // TODO confirm against source control.
    CmdListXmlParser parser;
    parser.ParseDoc(std::move(xmlDoc));
    auto paramTable = parser.GetCmdParams();

    auto noticeParamsList = paramTable[NoticeVerb];
    ValidateCmdBlobParamsList(noticeParamsList, NoticeVerb, NPARAMSNotice);
    TRACEINFO(trace, "EventHub dump verb " << NoticeVerb << ":");
    for (const auto & v : noticeParamsList) {
        EhCmdXmlItems xmlItems { v[SASIndexNotice], v[MdsEndpointIdIndexNotice], v[MdsMonikerIndexNotice] };
        m_noticeXmlItemsTable->emplace(v[EventNameIndexNotice], xmlItems);
        TRACEINFO(trace, v[EventNameIndexNotice] << "'s " << xmlItems);
    }

    // Older version of MA may not have PublisherVerb
    auto pubParamsList = paramTable[PublisherVerb];
    if (0 == pubParamsList.size()) {
        Logger::LogInfo("No " + PublisherVerb + " is found.");
        return;
    }
    ValidateCmdBlobParamsList(pubParamsList, PublisherVerb, NPARAMSPub);
    TRACEINFO(trace, "EventHub dump verb " << PublisherVerb << ":");
    for (const auto & v : pubParamsList) {
        EhCmdXmlItems xmlItems { v[SASIndexPub], v[MdsEndpointIdIndexPub], v[MdsMonikerIndexPub] };
        m_pubXmlItemsTable->emplace(v[EventNameIndexPub], xmlItems);
        TRACEINFO(trace, v[EventNameIndexPub] << "'s " << xmlItems);
    }
}

================================================
FILE: Diagnostic/mdsd/mdscommands/EventHubCmd.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef __EVENTHUBCMD__HH__
#define __EVENTHUBCMD__HH__

// NOTE(review): include targets and template arguments were stripped by
// extraction; restored from what the declarations require.
#include <string>
#include <memory>
#include <unordered_map>
#include <ostream>

namespace mdsd
{

// Encapsulating type for EH cmd XML items
struct EhCmdXmlItems
{
    std::string sas;       // SAS key
    std::string endpoint;  // MDS endpoint ID (e.g., "Test", "Prod", "Stage", ...)
    std::string moniker;   // The mapped storage moniker (may be different from config file account moniker)
};

/// <summary>
/// This class implements functions to handle Event Hub Commands xml files.
/// This includes download xml file, parse xml file, and get data from xml.
/// </summary>
class EventHubCmd
{
public:
    using EhCmdXmlItemsTable_t = std::unordered_map<std::string, EhCmdXmlItems>;

    /// <summary>
    /// Create the object that'll handle Event Hub command xml file.
    /// <param name="eventNameSpace">event name space</param>
    /// <param name="eventVersion">event version</param>
    /// <param name="rootContainerSas">the sas key for the root container
    /// where the command xml file locates.</param>
    /// </summary>
    EventHubCmd(std::string eventNameSpace,
                int eventVersion,
                std::string rootContainerSas);

    ~EventHubCmd() {}

    EventHubCmd(const EventHubCmd & other) = default;
    EventHubCmd(EventHubCmd&& other) = default;
    EventHubCmd& operator=(const EventHubCmd& other) = default;
    EventHubCmd& operator=(EventHubCmd&& other) = default;

    /// <summary>
    /// Process the Event Hub command XML to extract SASKey and other info.
    /// </summary>
    void ProcessCmdXml();

    /// <summary>
    /// Get Event Hub SAS Keys and return it in table.
    /// table: key=EventName; value: EH cmd XML items (currently SAS and MDS endpoint ID)
    /// </summary>
    std::shared_ptr<EhCmdXmlItemsTable_t> GetNoticeXmlItemsTable() const
    {
        return m_noticeXmlItemsTable;
    }

    std::shared_ptr<EhCmdXmlItemsTable_t> GetPublisherXmlItemsTable() const
    {
        return m_pubXmlItemsTable;
    }

    static void SetParentContainerName(std::string name)
    {
        s_parentContainerName = std::move(name);
    }

private:
    std::string GetBlobName(std::string baseName) { return baseName.append(m_blobNameSuffix); }
    void ProcessBlob(std::string&& blobName);
    void ParseCmdXml(std::string&& xmlDoc);

private:
    std::string m_blobNameSuffix;
    std::string m_rootContainerSas;

    // key = EventName; value: EH cmd XML items (currently SAS and MDS endpoint ID)
    std::shared_ptr<EhCmdXmlItemsTable_t> m_noticeXmlItemsTable;
    std::shared_ptr<EhCmdXmlItemsTable_t> m_pubXmlItemsTable;

    static std::string s_parentContainerName;
};

} // namespace mdsd

std::ostream& operator<<(std::ostream& str, const mdsd::EhCmdXmlItems & cmd);

#endif // __EVENTHUBCMD__HH__

================================================
FILE: Diagnostic/mdsd/mdscommands/EventHubPublisher.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include #include #include #include #include "BinaryWriter.hh" #include "EventHubPublisher.hh" #include "MdsCmdLogger.hh" #include "Trace.hh" #include "PublisherStatus.hh" #include "MdsException.hh" using namespace mdsd::details; using namespace web::http; using namespace web::http::client; static std::vector SerializeData( const std::string & text ) { std::vector v; BinaryWriter writer(v); writer.Write(text); return v; } static bool DisableWeakSslCiphers( const std::string & url, web::http::client::native_handle handle ) { const std::string https = "https:"; if (url.size() <= https.size()) { return true; } bool isHttps = (0 == strncasecmp(url.c_str(), https.c_str(), https.size())); if (!isHttps) { return true; } bool resultOK = true; boost::asio::ssl::stream* streamobj = static_cast* >(handle); if (streamobj) { SSL* ssl = streamobj->native_handle(); if (ssl) { const int isOK = 1; const std::string cipherList = "HIGH:!DSS:!RC4:!aNULL@STRENGTH"; if (::SSL_set_cipher_list(ssl, cipherList.c_str()) != isOK) { MdsCmdLogError("Error: failed to disable weak ciphers: " + cipherList + "; URL: " + url); resultOK = false; } } } return resultOK; } EventHubPublisher::EventHubPublisher( const std::string & hostUrl, const std::string & eventHubUrl, const std::string & sasToken) : m_hostUrl(hostUrl), m_eventHubUrl(eventHubUrl), m_sasToken(sasToken), m_httpclient(nullptr), m_resetHttpClient(false) { } // The actual data sent to EventHub is a serialized version of EventDataT::GetData(). // However, because EventDataT::GetData() is std::string, and serialization doesn't // change the size of std::string, use the std::string's size to do validation. static void ValidateData( const mdsd::EventDataT & data ) { if (data.GetData().size() > mdsd::EventDataT::GetMaxSize()) { std::ostringstream strm; strm << "EventHub data is too big: max=" << mdsd::EventDataT::GetMaxSize() << " B; input=" << data.GetData().size() << " B. 
Drop it."; throw mdsd::TooBigEventHubDataException(strm.str()); } } http_request EventHubPublisher::CreateRequest( const EventDataT & data ) { ValidateData(data); auto serializedData = SerializeData(data.GetData()); http_request req; req.set_request_uri(m_eventHubUrl); req.set_method(methods::POST); req.headers().add("Authorization", m_sasToken); req.headers().add("Content-Type", "application/atom+xml;type=entry;charset=utf-8"); req.set_body(serializedData); for (const auto & it : data.Properties()) { req.headers().add(it.first, it.second); } return req; } void EventHubPublisher::ResetClient() { Trace trace(Trace::MdsCmd, "EventHubPublisher::ResetClient"); if (m_httpclient) { trace.NOTE("Http client will be reset due to previous failure."); m_httpclient.reset(); m_resetHttpClient = false; } auto lambda = [this](web::http::client::native_handle handle)->void { (void) DisableWeakSslCiphers(m_hostUrl, handle); }; http_client_config httpClientConfig; httpClientConfig.set_timeout(std::chrono::seconds(30)); // http request timeout value httpClientConfig.set_nativehandle_options(lambda); m_httpclient = std::move(std::unique_ptr(new http_client(m_hostUrl, httpClientConfig))); } bool EventHubPublisher::Publish( const EventDataT& data ) { if (data.empty()) { MdsCmdLogWarn("Empty data is passed to publisher. Drop it."); return true; } try { if (!m_httpclient || m_resetHttpClient) { ResetClient(); } auto postRequest = CreateRequest(data); auto httpResponse = m_httpclient->request(postRequest).get(); return HandleServerResponse(httpResponse, false); } catch(const mdsd::TooBigEventHubDataException & ex) { MdsCmdLogWarn(ex.what()); return true; } catch(const std::exception & ex) { MdsCmdLogError("Error: EH publish to " + m_eventHubUrl + " failed: " + ex.what()); } catch(...) 
{ MdsCmdLogError("Error: EH publish to " + m_eventHubUrl +" has unknown exception."); } m_resetHttpClient = true; return false; } pplx::task EventHubPublisher::PublishAsync( const EventDataT& data ) { if (data.empty()) { MdsCmdLogWarn("Empty data is passed to async publisher. Drop it."); return pplx::task_from_result(true); } try { if (!m_httpclient || m_resetHttpClient) { ResetClient(); } auto postRequest = CreateRequest(data); auto shThis = shared_from_this(); return m_httpclient->request(postRequest) .then([shThis](pplx::task responseTask) { return shThis->HandleServerResponseAsync(responseTask); }); } catch(const mdsd::TooBigEventHubDataException & ex) { MdsCmdLogWarn(ex.what()); return pplx::task_from_result(true); } catch(const std::exception & ex) { MdsCmdLogError("Error: EH async publish to " + m_eventHubUrl + " failed: " + ex.what()); } m_resetHttpClient = true; return pplx::task_from_result(false); } bool EventHubPublisher::HandleServerResponseAsync( pplx::task responseTask ) { try { return HandleServerResponse(responseTask.get(), true); } catch(const std::exception & e) { MdsCmdLogError("Error: EH async publish to " + m_eventHubUrl + " failed with http response: " + e.what()); } m_resetHttpClient = true; return false; } bool EventHubPublisher::HandleServerResponse( const http_response & response, bool isFromAsync ) { Trace trace(Trace::MdsCmd, "EventHubPublisher::HandleServerResponse"); PublisherStatus pubStatus = PublisherStatus::Idle; auto statusCode = response.status_code(); TRACEINFO(trace, "Http response status_code=" << statusCode << "; Reason='" << response.reason_phrase() << "'"); const int HttpStatusThrottled = 429; std::string errDetails; switch(statusCode) { case status_codes::Created: // 201. According to MSDN, 201 means success. 
case status_codes::OK: pubStatus = PublisherStatus::PublicationSucceeded; break; case status_codes::BadRequest: pubStatus = PublisherStatus::PublicationFailedWithBadRequest; break; case status_codes::Unauthorized: case status_codes::Forbidden: pubStatus = PublisherStatus::PublicationFailedWithAuthError; errDetails += " SAS: '" + m_sasToken + "'"; break; case status_codes::ServiceUnavailable: pubStatus = PublisherStatus::PublicationFailedServerBusy; m_resetHttpClient = true; break; case HttpStatusThrottled: pubStatus = PublisherStatus::PublicationFailedThrottled; break; default: pubStatus = PublisherStatus::PublicationFailedWithUnknownReason; break; } if (PublisherStatus::PublicationSucceeded != pubStatus) { std::ostringstream strm; strm << "Error: EH publish to " << m_eventHubUrl << errDetails << " failed with status=" << pubStatus << std::boolalpha << ". isAsync=" << isFromAsync; MdsCmdLogError(strm); } else { TRACEINFO(trace, "publication succeeded. isAsync=" << std::boolalpha << isFromAsync); } return (PublisherStatus::PublicationSucceeded == pubStatus); } ================================================ FILE: Diagnostic/mdsd/mdscommands/EventHubPublisher.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __EVENTHUBPUBLISHER__HH__ #define __EVENTHUBPUBLISHER__HH__ #include #include #include #include #include #include "EventData.hh" namespace mdsd { namespace details { /// /// This class implements functions to publish data to EventHub /// service using https. 
/// class EventHubPublisher : public std::enable_shared_from_this { public: static std::shared_ptr create( const std::string & hostUrl, const std::string & eventHubUrl, const std::string & sasToken ) { return std::shared_ptr(new EventHubPublisher(hostUrl, eventHubUrl, sasToken)); } virtual ~EventHubPublisher() {} EventHubPublisher(const EventHubPublisher &) = delete; EventHubPublisher(EventHubPublisher&&) = default; EventHubPublisher& operator=(EventHubPublisher&) = delete; EventHubPublisher& operator=(EventHubPublisher&&) = default; /// /// Publish the data to Event Hub service synchronously. /// Return true if success, false if any error. /// If input data is empty, drop it and return true. /// virtual bool Publish(const EventDataT & data); /// /// Publish the data to Event Hub service asynchronously. /// Return true if success, false if any error. /// If input data is empty, drop it and return true. /// virtual pplx::task PublishAsync(const EventDataT & data); /// /// Create http request for EventHub data uploading. /// Throw exception if any error for the input data. /// web::http::http_request CreateRequest(const EventDataT & data); protected: EventHubPublisher( const std::string & hostUrl, const std::string & eventHubUrl, const std::string & sasToken); private: void ResetClient(); bool HandleServerResponse(const web::http::http_response & response, bool isFromAsync); bool HandleServerResponseAsync(pplx::task responseTask); private: std::string m_hostUrl; // Event Hub host URL std::string m_eventHubUrl; // Event Hub service URL std::string m_sasToken; // Event Hub SAS token std::unique_ptr m_httpclient; bool m_resetHttpClient; // if true, reset the http client. }; } // namespace details } // namespace mdsd #endif // __EVENTHUBPUBLISHER__HH__ ================================================ FILE: Diagnostic/mdsd/mdscommands/EventHubType.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT license. #include "EventHubType.hh" #include #include static std::map & GetType2NameMap() { static auto m = new std::map ( { { mdsd::EventHubType::Notice, "EventNotice" }, { mdsd::EventHubType::Publish, "EventPublish" } }); return *m; } std::string mdsd::EventHubTypeToStr(EventHubType type) { auto m = GetType2NameMap(); auto iter = m.find(type); if (iter != m.end()) { return iter->second; } return "unknown"; } static std::map & GetName2TypeMap() { static auto m = new std::map( { { "EventNotice", mdsd::EventHubType::Notice }, { "EventPublish", mdsd::EventHubType::Publish } }); return *m; } mdsd::EventHubType mdsd::EventHubTypeFromStr(const std::string & s) { auto m = GetName2TypeMap(); auto iter = m.find(s); if (iter != m.end()) { return iter->second; } throw std::runtime_error("Invalid EventHubType name: " + s); } ================================================ FILE: Diagnostic/mdsd/mdscommands/EventHubType.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __EVENTHUBTYPE_HH_ #define __EVENTHUBTYPE_HH_ #include namespace mdsd { enum class EventHubType { Notice, Publish }; std::string EventHubTypeToStr(EventHubType type); EventHubType EventHubTypeFromStr(const std::string & s); } // namespace mdsd #endif // __EVENTHUBTYPE_HH_ ================================================ FILE: Diagnostic/mdsd/mdscommands/EventHubUploader.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include #include #include #include extern "C" { #include #include } #include #include #include #include "EventHubUploader.hh" #include "MdsException.hh" #include "MdsCmdLogger.hh" #include "Trace.hh" #include "Logger.hh" #include "EventEntry.hh" #include "EventPersistMgr.hh" #include "EventHubPublisher.hh" #include "Utility.hh" using namespace mdsd; using namespace mdsd::details; class UploadInterruptionException {}; EventHubUploader::EventHubUploader( const std::string & persistDir, int32_t persistResendSeconds, int32_t memoryTimeoutSeconds, int32_t maxPersistSeconds ) : m_publisher(nullptr), m_memoryTimeoutSeconds(memoryTimeoutSeconds), m_stopSenderMode(0), m_persistResendSeconds(persistResendSeconds), m_persistResendTimer(crossplat::threadpool::shared_instance().service()), m_persistDir(persistDir), m_pmgr(EventPersistMgr::create(persistDir, maxPersistSeconds)) { } EventHubUploader::~EventHubUploader() { WaitForFinish(); } void EventHubUploader::WaitForFinish( int32_t maxMilliSeconds ) { try { Trace trace(Trace::MdsCmd, "EventHubUploader::WaitForFinish"); if (m_isFinished) { TRACEINFO(trace, "function is already called. abort."); return; } m_isFinished = true; WaitForSenderTask(maxMilliSeconds); if (m_senderTask.valid()) { m_senderTask.get(); } m_persistResendTimer.cancel(); } catch(std::exception& ex) { MdsCmdLogError("Error: EventHubUploader::WaitForFinish failed: " + std::string(ex.what())); } catch(...) 
// Tail of EventHubUploader::WaitForFinish() — the function begins above this chunk.
    {
        MdsCmdLogError("Error: EventHubUploader::WaitForFinish failed with unknown exception");
    }
}

// Set the EventHub SAS key and start the uploader if it is not started yet.
// A new publisher is created only when the key actually changed; the sender
// task and the persist-resend timer are created exactly once (std::call_once).
// NOTE: not thread-safe by design (see header).
void
EventHubUploader::SetSasAndStart(
    const std::string & eventHubSas
    )
{
    Trace trace(Trace::MdsCmd, "EventHubUploader::SetSasAndStart");

    if (eventHubSas.empty()) {
        MdsCmdLogError("Error: EventHubUploader::SetSasAndStart: unexpected empty EventHub SasKey");
        return;
    }

    if (m_ehSasKey != eventHubSas) {
        std::string hostUrl, eventHubUrl, sasToken;
        ParseEventHubSas(eventHubSas, hostUrl, eventHubUrl, sasToken);
        m_publisher = EventHubPublisher::create(hostUrl, eventHubUrl, sasToken);

        // Because the senderTask requires EH publisher object, so
        // create the task and timer only when EH publisher object is ready.
        // This only needs to be called once.
        std::call_once(m_initOnceFlag, &EventHubUploader::Init, this);

        m_ehSasKey = eventHubSas;
    }
}

// One-time initialization: launch the async sender task and arm the
// persisted-events resend timer.
void
EventHubUploader::Init()
{
    m_senderTask = std::async(std::launch::async, &EventHubUploader::Upload, this);

    m_persistResendTimer.expires_from_now(boost::posix_time::seconds(m_persistResendSeconds));
    m_persistResendTimer.async_wait(boost::bind(&EventHubUploader::ResendPersistEvents, this,
        boost::asio::placeholders::error));
}

// Copying overload: clone the data and delegate to the move overload.
void
EventHubUploader::AddData(
    const EventDataT & data
    )
{
    if (data.empty()) {
        return;
    }
    EventDataT dataCopy{data};
    AddData(std::move(dataCopy));
}

// Enqueue one event for upload and wake the sender task.
void
EventHubUploader::AddData(
    EventDataT && data
    )
{
    if (data.empty()) {
        return;
    }
    EventEntryT item(new EventEntry(std::move(data)));

    std::lock_guard lk(m_qmutex);
    m_uploadQueue.emplace(std::move(item));
    m_qcv.notify_all();
}

// Ask the sender task to stop.
// milliSeconds == -1: wait until the queue is fully drained.
// otherwise: request a graceful stop, wait at most milliSeconds, then force
// an immediate stop and report how many items were left unsent.
void
EventHubUploader::WaitForSenderTask(
    int32_t milliSeconds
    )
{
    Trace trace(Trace::MdsCmd, "EventHubUploader::WaitForSenderTask");

    if (m_stopSenderMode > 0) {
        // A stop was already requested.
        return;
    }
    if (!m_senderTask.valid()) {
        // Sender task was never started (Init() not run).
        return;
    }

    TRACEINFO(trace, "Notify sender task to stop ...");

    // Because condition variable (CV)'s checking for predicate and waiting
    // is not atomic, to avoid lost notification, the operations that'll
    // affect predicate results before CV notify() should be protected by
    // the same mutex for CV wait().
    if (-1 == milliSeconds) {
        std::unique_lock lck(m_qmutex);
        m_stopSenderMode = StopTaskUntilDoneMode;
        m_qcv.notify_all();
        lck.unlock();
        m_senderTask.wait();
    }
    else {
        m_stopSenderMode = StopTaskUntilDoneMode;
        m_senderTask.wait_for(std::chrono::milliseconds(milliSeconds));

        std::unique_lock lck(m_qmutex);
        auto queueSize = m_uploadQueue.size();
        m_stopSenderMode = StopTaskNowMode;
        m_qcv.notify_all();
        lck.unlock();

        TRACEINFO(trace, "Number of Items in upload queue: " << queueSize );
    }
}

// Sender-task main loop: pop items off the in-memory queue and process them
// until a stop is requested. An immediate stop (StopTaskNowMode) aborts via
// UploadInterruptionException thrown at the interruption points.
void
EventHubUploader::Upload()
{
    Trace trace(Trace::MdsCmd, "EventHubUploader::Upload");
    try {
        while(StopTaskNowMode != m_stopSenderMode) {
            std::unique_lock lk(m_qmutex);
            m_qcv.wait(lk, [this] { return (m_stopSenderMode || !m_uploadQueue.empty()); });
            if (m_uploadQueue.empty()) {
                // Woken by a stop request with nothing left to send.
                break;
            }
            UploadInterruptionPoint();
            EventEntryT item(std::move(m_uploadQueue.front()));
            m_uploadQueue.pop();
            lk.unlock();

            UploadInterruptionPoint();
            // item could be re-queued based on process result.
            ProcessData(std::move(item));
            UploadInterruptionPoint();
        }
    }
    catch(UploadInterruptionException&) {
        TRACEINFO(trace, "Upload() is interrupted.");
    }
}

// Try to publish one item. On failure, persist it on first failure only,
// apply backoff, and re-queue it for a later retry. Items older than
// m_memoryTimeoutSeconds are dropped from the in-memory retry path
// (a persisted copy, if any, is handled by the resend timer).
void
EventHubUploader::ProcessData(
    EventEntryT item
    )
{
    Trace trace(Trace::MdsCmd, "EventHubUploader::ProcessData");

    auto itemAge = item->GetAgeInSeconds();
    std::string itemTag = "Item (";
    itemTag += std::to_string(item->GetId());
    itemTag += ")";

    if (itemAge > m_memoryTimeoutSeconds) {
        TRACEINFO(trace, itemTag << " age (" << itemAge << " s) > retry timeout(" <<
            m_memoryTimeoutSeconds << " s). Stop retry.");
        return;
    }

    if (!item->IsTimeToRetry()) {
        // Still inside its backoff window: push it back and move on.
        std::lock_guard lk(m_qmutex);
        m_uploadQueue.emplace(std::move(item));
        return;
    }

    UploadInterruptionPoint();
    if(m_publisher->Publish(item->GetData())) {
        m_nUpSuccess++;
        return;
    }

    UploadInterruptionPoint();
    if (item->IsNeverSent()) {
        item->SetSendTime();
    }
    m_nUpFail++;

    // if persist write failed, no backoff. retry as soon as possible.
    bool persistOK = true;
    if (!item->IsInPersistence()) {
        trace.NOTE(itemTag + " upload failed. Add to persist and requeue.");
        persistOK = m_pmgr->Add(item->GetData());
        if (!persistOK) {
            m_npFail++;
            MdsCmdLogError("Error: EventHubUploader data processor failed to add " + itemTag + " to persist mgr.");
        }
        else {
            item->SetPersistence();
        }
    }
    else {
        trace.NOTE(itemTag + " failed again. requeue.");
    }

    if (persistOK) {
        trace.NOTE("Backoff " + itemTag);
        item->BackOff();
    }

    UploadInterruptionPoint();
    std::lock_guard lk(m_qmutex);
    m_uploadQueue.emplace(std::move(item));
}

// input sasKey format: https://tuxtestsb.servicebus.windows.net/Raw?sr=SR&sig=SIG&se=1455131008&skn=writer'
// outputs:
// - hostUrl: https://tuxtestsb.servicebus.windows.net
// - eventHubUrl: https://tuxtestsb.servicebus.windows.net/Raw/messages
// - sasToken: SharedAccessSignature sr=SR&sig=SIG&se=1455131008&skn=writer
void
EventHubUploader::ParseEventHubSas(
    const std::string & eventHubSas,
    std::string & hostUrl,
    std::string & eventHubUrl,
    std::string & sasToken
    )
{
    Trace trace(Trace::MdsCmd, "EventHubUploader::ParseEventHubSas");

    std::string prefix{"https://"};
    auto prefixLen = prefix.size();
    // compare() returns non-zero when eventHubSas does NOT start with "https://".
    if (eventHubSas.compare(0, prefixLen, prefix)) {
        std::ostringstream strm;
        strm << "Invalid Event Hub SAS. SAS is expected to started with '" << prefix << "'";
        throw MDSEXCEPTION(strm.str());
    }

    auto hostPos = eventHubSas.find_first_of('/', prefixLen);
    hostUrl = eventHubSas.substr(0, hostPos);

    auto eventNamePos = eventHubSas.find_first_of('?', hostUrl.size());
    eventHubUrl = eventHubSas.substr(0, eventNamePos) + "/messages";

    // The SAS query string may be XML-attribute-quoted; unquote before use.
    auto tmpSasToken = eventHubSas.substr(eventNamePos+1);
    sasToken = MdsdUtil::UnquoteXmlAttribute(tmpSasToken);
    sasToken = "SharedAccessSignature " + sasToken;
}

// Timer callback: kick off an async upload of all persisted (previously
// failed) events, then re-arm the timer unless a stop was requested
// (m_stopSenderMode != 0) or the timer was cancelled.
void
EventHubUploader::ResendPersistEvents(
    const boost::system::error_code& error
    )
{
    Trace trace(Trace::MdsCmd, "EventHubUploader::ResendPersistEvents");

    if (boost::asio::error::operation_aborted == error) {
        trace.NOTE("Previous timer cancelled.");
        return;
    }

    if (!m_pmgr->UploadAllAsync(m_publisher)) {
        MdsCmdLogError(std::string("Error: EventHubUploader failed to start async upload. Retry in ")
            + std::to_string(m_persistResendSeconds) + " seconds.");
    }

    if (0 == m_stopSenderMode) {
        m_persistResendTimer.expires_from_now(boost::posix_time::seconds(m_persistResendSeconds));
        m_persistResendTimer.async_wait(boost::bind(&EventHubUploader::ResendPersistEvents, this,
            boost::asio::placeholders::error));
    }
}

// Abort the sender task (by throwing) when an immediate stop was requested.
void
EventHubUploader::UploadInterruptionPoint()
{
    if (StopTaskNowMode == m_stopSenderMode) {
        throw UploadInterruptionException();
    }
}

================================================
FILE: Diagnostic/mdsd/mdscommands/EventHubUploader.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef __EVENTHUBUPLOADER__HH__
#define __EVENTHUBUPLOADER__HH__

// NOTE(review): the targets of the bare #include directives below, and the
// template arguments throughout this header, were lost when the file was
// extracted (angle-bracketed text stripped). Restore from original source.
#include
#include
#include
#include
#include
#include
#include

extern "C" {
#include
}

#include
#include "EventData.hh"

namespace boost { namespace system { class error_code; } }

namespace mdsd { namespace details
{
    class EventEntry;
    class EventPersistMgr;
    class EventHubPublisher;
} }

namespace mdsd
{

/// This class implements the functions to upload data to Event Hub service.
class EventHubUploader
{
    using EventEntryT = std::unique_ptr;

public:
    /// Construct an uploader object.
    /// persistDir: directory fullpath where failed events are persisted.
    /// persistResendSeconds: how often to resend failed, persisted events.
    /// memoryTimeoutSeconds: max time to keep data in memory after first failure.
    /// maxPersistSeconds: max time to persist failed data.
    EventHubUploader(const std::string & persistDir,
        int32_t persistResendSeconds = 3600,
        int32_t memoryTimeoutSeconds = 3600,
        int32_t maxPersistSeconds = 604800 // 7-days
        );

    ~EventHubUploader();

    /// This class uses 'mutex', which is not movable, not copyable.
    /// So make this class as not movable, not copyable.
    EventHubUploader(const EventHubUploader& other) = delete;
    EventHubUploader(EventHubUploader&& other) = delete;
    EventHubUploader& operator=(const EventHubUploader& other) = delete;
    EventHubUploader& operator=(EventHubUploader&& other) = delete;

    /// Set Event Hub SAS Key and start the uploader if not started yet.
    /// When autokey is used, the SAS Key is changed every N hours. This API
    /// will create a new instance of EventHubPublisher. So it should be called only
    /// when SasKey is changed.
    /// NOTE: This API is not thread-safe.
    void SetSasAndStart(const std::string & eventHubSas);

    /// Add data to Event Hub service.
    void AddData(const EventDataT & data);
    void AddData(EventDataT && data);

    /// Wait for given time for all data to be uploaded.
    /// Return until all data are uploaded or timed out.
    /// maxMilliSeconds == -1 means forever.
    /// NOTE: this function is not designed for thread-safe. In mdsd, it should
    /// be called sequentially on given EventHubUploader object.
    void WaitForFinish(int32_t maxMilliSeconds = -1);

    /// Get number of success uploads.
    size_t GetNumUploadSuccess() const { return m_nUpSuccess; }

    /// Get number of failed uploads.
    size_t GetNumUploadFail() const { return m_nUpFail; }

    /// Get number of failed persistence
    size_t GetNumPersistFail() const { return m_npFail; }

    std::string GetPersistDir() const { return m_persistDir; }

private:
    void WaitForSenderTask(int32_t maxMilliSeconds);
    void ParseEventHubSas(const std::string & eventHubSas,
        std::string& hostUrl, std::string& eventHubUrl, std::string& sasToken);
    void Init();
    void ProcessData(EventEntryT data);
    void Upload();
    void ResendPersistEvents(const boost::system::error_code& error);
    void UploadInterruptionPoint();

private:
    std::shared_ptr m_publisher;        // EventHub publisher; recreated on SAS change.
    std::string m_ehSasKey;             // SASKey for EventHub service
    size_t m_nUpSuccess = 0;            // number of upload success
    size_t m_nUpFail = 0;               // number of upload failure
    size_t m_npFail = 0;                // number of persist mgr failure
    int32_t m_memoryTimeoutSeconds;     // Max time to keep data in memory after first failure.
    std::queue m_uploadQueue;           // To store all events in memory.
    std::mutex m_qmutex;                // For queue/cv synchronization.
    std::condition_variable m_qcv;      // For queue synchronization.

    static const int StopTaskNowMode = 1;       // To stop the sender task immediately.
    static const int StopTaskUntilDoneMode = 2; // To stop the sender task when all data are processed.
    std::atomic m_stopSenderMode;       // A flag on when to stop the sender task.

    std::future m_senderTask;           // Task to send data to Event Hub service from memory queue.

    int32_t m_persistResendSeconds = 0; // How often to resend persisted, failed data.
    boost::asio::deadline_timer m_persistResendTimer; // Persisted data resend timer.
    std::string m_persistDir;           // EventHub data persist dir
    std::shared_ptr m_pmgr;             // Event data persistence manager.
    std::once_flag m_initOnceFlag;      // Once flag to initialize this uploader object.
    bool m_isFinished = false;          // Whether EH uploading operation is finished
};

} // namespace mdsd

#endif // __EVENTHUBUPLOADER__HH__

================================================
FILE: Diagnostic/mdsd/mdscommands/EventHubUploaderId.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "EventHubUploaderId.hh"
// NOTE(review): bare #include targets stripped during extraction.
#include
#include
#include
#include
#include

using namespace mdsd;

// Build an uploader id from its three parts.
// Throws std::invalid_argument when moniker or eventname is empty.
EventHubUploaderId::EventHubUploaderId(
    EventHubType ehtype,
    const std::string & moniker,
    const std::string & eventname
    ) :
    m_ehtype(ehtype),
    m_moniker(moniker),
    m_eventname(eventname)
{
    if (m_moniker.empty()) {
        throw std::invalid_argument("EventHubUploaderId: invalid empty moniker for event '" + m_eventname + "'");
    }
    if (m_eventname.empty()) {
        throw std::invalid_argument("EventHubUploaderId: invalid empty eventname for moniker '" + m_moniker + "'");
    }
}

// Parse a space-separated id string "<eventname> <moniker> <ehtype>"
// (the format produced by operator std::string() in the header).
// Throws std::runtime_error when the token count is not exactly 3.
EventHubUploaderId::EventHubUploaderId(const std::string & idstr)
{
    std::vector fields;
    boost::algorithm::split(fields, idstr, boost::is_any_of(" "), boost::token_compress_on);
    constexpr size_t nExpected = 3;
    if (nExpected != fields.size()) {
        std::ostringstream strm;
        strm << "Invalid EHUploaderId '" << idstr << "' in number of tokens: expected=" << nExpected
             << "; actual=" << fields.size();
        throw std::runtime_error(strm.str());
    }
    m_eventname = fields[0];
    m_moniker = fields[1];
    m_ehtype = EventHubTypeFromStr(fields[2]);
}

================================================
FILE: Diagnostic/mdsd/mdscommands/EventHubUploaderId.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _EVENTHUBUPLOADERID_HH_
#define _EVENTHUBUPLOADERID_HH_

// NOTE(review): bare #include target stripped during extraction.
#include
#include "EventHubType.hh"

namespace mdsd
{

// Identity of one EventHub uploader: (type, moniker, eventname).
// Convertible to the space-separated string form used as the uploader map key.
struct EventHubUploaderId
{
    EventHubType m_ehtype;
    std::string m_moniker;
    std::string m_eventname;

    EventHubUploaderId(EventHubType ehtype, const std::string & moniker, const std::string & eventname);
    EventHubUploaderId(const std::string & idstr);

    operator std::string() const
    {
        // put the bits that change more frequently at the front
        return (m_eventname + " " + m_moniker + " " + EventHubTypeToStr(m_ehtype));
    }
};

} // namespace mdsd

// Stream an uploader id using its string conversion.
inline std::ostream&
operator<<(
    std::ostream& os,
    const mdsd::EventHubUploaderId& id
    )
{
    os << static_cast(id);
    return os;
}

#endif // _EVENTHUBUPLOADERID_HH_

================================================
FILE: Diagnostic/mdsd/mdscommands/EventHubUploaderMgr.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "EventHubUploaderMgr.hh"
#include "EventHubUploaderId.hh"
#include "Utility.hh"
#include "Trace.hh"
#include "Logger.hh"
// NOTE(review): bare #include targets stripped during extraction.
#include
#include
#include

using namespace mdsd;
using namespace mdsd::details;

// Singleton accessor (function-local static).
EventHubUploaderMgr&
EventHubUploaderMgr::GetInstance()
{
    // Because EventHubUploader's destructor will use pplx threadpool tasks, make sure
    // the static threadpool is created first. First created will be last destroyed.
    crossplat::threadpool::shared_instance();
    static EventHubUploaderMgr s_instance;
    return s_instance;
}

// Validate and record the top-level persist directory.
// Returns false (after logging) when the directory is not rwx-accessible
// by the current user.
bool
EventHubUploaderMgr::SetTopLevelPersistDir(
    const std::string& persistDirTopLevel
    )
{
    try {
        MdsdUtil::ValidateDirRWXByUser(persistDirTopLevel);
    }
    catch(std::exception& ex) {
        Logger::LogError("Error: failed to access directory '" + persistDirTopLevel + "'. Reason: " + ex.what());
        return false;
    }
    m_persistDirTopLevel = persistDirTopLevel;
    return true;
}

// Create (if needed) <top>/<ehtype>/<moniker>/<eventname> and return its path.
// Throws std::runtime_error when the top-level dir was never set.
std::string
EventHubUploaderMgr::CreateAndGetPersistDir(
    EventHubType ehtype,
    const std::string& moniker,
    const std::string& eventname
    )
{
    if (m_persistDirTopLevel.empty()) {
        throw std::runtime_error("Root directory path string for persisting EventHub messages is empty");
    }
    std::string persistDirPath = m_persistDirTopLevel;
    persistDirPath += "/" + EventHubTypeToStr(ehtype);
    MdsdUtil::CreateDirIfNotExists(persistDirPath, 01755);
    persistDirPath += "/" + moniker;
    MdsdUtil::CreateDirIfNotExists(persistDirPath, 01755);
    persistDirPath += "/" + eventname;
    MdsdUtil::CreateDirIfNotExists(persistDirPath, 01755);
    return persistDirPath;
}

// Look up an uploader by its id string; returns nullptr when not found.
EventHubUploader*
EventHubUploaderMgr::GetUploader(
    const std::string & uploaderId
    )
{
    // support multiple reader threads
    boost::shared_lock lk(m_mapMutex);
    auto findResult = m_ehUploaders.find(uploaderId);
    if (findResult == m_ehUploaders.end()) {
        return nullptr;
    }
    return findResult->second.get();
}

// Compute the (moniker, eventname) pairs present in the new config map but
// missing from m_ehUploaders (i.e. uploaders that must be created).
// This API assumes m_mapMutex shared lock is already held.
std::set>
EventHubUploaderMgr::GetNewItemSet(
    EventHubType ehtype,
    const std::unordered_map> & eventMonikerMap
    )
{
    Trace trace(Trace::MdsCmd, "EventHubUploaderMgr::GetNewItemSet");
    std::set> newItemSet;
    for (const auto & item : eventMonikerMap) {
        auto & eventname = item.first;
        auto & monikers = item.second;
        for (const auto & moniker: monikers) {
            auto findResult = m_ehUploaders.find(EventHubUploaderId(ehtype, moniker, eventname));
            if (findResult == m_ehUploaders.end()) {
                newItemSet.insert(std::make_pair(moniker, eventname));
            }
            else {
                TRACEINFO(trace, "Found existing EventHubUploader for moniker=" << moniker
                    << ", event=" << eventname);
            }
        }
    }
    return newItemSet;
}

// This API assumes m_mapMutex shared lock is already held.
// Compute the (moniker, eventname) pairs currently registered for this
// ehtype that are no longer present in the new config map (uploaders that
// must be removed). Caller holds the m_mapMutex shared lock.
std::set>
EventHubUploaderMgr::GetDroppedItemSet(
    EventHubType ehtype,
    const std::unordered_map> & eventMonikerMap
    )
{
    Trace trace(Trace::MdsCmd, "EventHubUploaderMgr::GetDroppedItemSet");
    std::set> droppedItemSet;
    for (const auto & item : m_ehUploaders) {
        EventHubUploaderId ehid(item.first);
        if (ehid.m_ehtype != ehtype) {
            continue;
        }
        auto iter = eventMonikerMap.find(ehid.m_eventname);
        if (iter == eventMonikerMap.end()) {
            TRACEINFO(trace, "Event '" << ehid.m_eventname << "' is dropped in MdsdConfig.");
            droppedItemSet.insert(std::make_pair(ehid.m_moniker, ehid.m_eventname));
        }
        else {
            auto & monikers = iter->second;
            // NOTE(review): this inserts the pair as soon as ANY configured
            // moniker differs from ehid.m_moniker, even if another moniker in
            // the list matches it — it looks like the intent is to drop only
            // when NO moniker matches. Confirm before changing.
            for (const auto & moniker: monikers) {
                if (moniker != ehid.m_moniker) {
                    TRACEINFO(trace, "Event " << ehid.m_eventname << "'s moniker '" << ehid.m_moniker
                        << "' is dropped in MdsdConfig.");
                    droppedItemSet.insert(std::make_pair(ehid.m_moniker, ehid.m_eventname));
                }
            }
        }
    }
    return droppedItemSet;
}

// Reconcile the uploader map with the (eventname -> monikers) config map:
// create uploaders (and their persist dirs) for new combinations, erase
// uploaders for dropped ones. Errors are logged, not thrown.
void
EventHubUploaderMgr::CreateUploaders(
    EventHubType ehtype,
    const std::unordered_map> & eventMonikerMap
    )
{
    Trace trace(Trace::MdsCmd, "EventHubUploaderMgr::CreateUploaders");

    if (m_persistDirTopLevel.empty()) {
        Logger::LogError("Error: EventHub persist directory shouldn't be empty.");
        return;
    }

    try {
        // This function could be called in multi-threads, or signal handler. use lock to protect.
        boost::upgrade_lock slock(m_mapMutex);

        auto newItemSet = GetNewItemSet(ehtype, eventMonikerMap);
        auto droppedItemSet = GetDroppedItemSet(ehtype, eventMonikerMap);

        // Do exclusive lock on the EH uploader map
        if (!newItemSet.empty() || !droppedItemSet.empty()) {
            boost::upgrade_to_unique_lock< boost::shared_mutex > uniqueLock(slock);

            for (const auto & item : newItemSet) {
                auto & moniker = item.first;
                auto & eventname = item.second;
                EventHubUploaderId uploaderId(ehtype, moniker, eventname);
                auto persistDir = CreateAndGetPersistDir(ehtype, moniker, eventname);
                EhUploader_t newUploader(new EventHubUploader(persistDir));
                m_ehUploaders[uploaderId] = std::move(newUploader);
                TRACEINFO(trace, "Created EventHubUploader for moniker=" << moniker << ", event=" << eventname);
            }

            for (const auto & item: droppedItemSet) {
                auto & moniker = item.first;
                auto & eventname = item.second;
                m_ehUploaders.erase(EventHubUploaderId(ehtype, moniker, eventname));
                TRACEINFO(trace, "Removed EventHubUploader for moniker=" << moniker << ", event=" << eventname);
            }
        }
    }
    catch(std::exception& ex) {
        Logger::LogError("Error: failed to create EventHub uploaders. Reason: " + std::string(ex.what()));
    }
}

// Forward a SAS key to the uploader identified by uploaderId.
// Returns false when the uploader is not defined in mdsd xml or starting it
// failed; throws std::invalid_argument on an empty key.
bool
EventHubUploaderMgr::SetSasAndStart(
    const EventHubUploaderId& uploaderId,
    const std::string & ehSas
    )
{
    const std::string funcname = "EventHubUploaderMgr::SetSasAndStart";
    Trace trace(Trace::MdsCmd, funcname);

    if (ehSas.empty()) {
        throw std::invalid_argument(funcname + ": unexpected empty SasKey");
    }

    try {
        auto uploaderObj = GetUploader(uploaderId);
        if (!uploaderObj) {
            TRACEINFO(trace, "Cannot find uploader " << uploaderId << "'. Mdsd xml doesn't define it.");
            return false;
        }
        else {
            TRACEINFO(trace, "SetSasAndStart for " << uploaderId);
            uploaderObj->SetSasAndStart(ehSas);
            return true;
        }
    }
    catch(std::exception& ex) {
        Logger::LogError("Error: EventHubUploaderMgr::SetSasAndStart() failed. Reason: " + std::string(ex.what()));
        return false;
    }
}

// Route one event to its uploader's in-memory queue.
// Oversized or unroutable messages are dropped (returns false); an empty
// message throws std::invalid_argument.
bool
EventHubUploaderMgr::AddMessageToUpload(
    const EventHubUploaderId& uploaderId,
    EventDataT&& eventData
    )
{
    const std::string funcname = "EventHubUploaderMgr::AddMessageToUpload";
    Trace trace(Trace::Bond, funcname);

    if (eventData.empty()) {
        throw std::invalid_argument(funcname + ": unexpected empty EventHub data");
    }

    // The actual data sent to EventHub is a serialized version of EventDataT::GetData().
    // However, because EventDataT::GetData() is std::string, and serialization doesn't
    // change the size of std::string, use the std::string's size to do validation.
    if (eventData.GetData().size() > EventDataT::GetMaxSize()) {
        TRACEWARN(trace, "Data size(" << eventData.GetData().size()
            << ") exceeds max supported size(" << EventDataT::GetMaxSize() << "). Drop it.");
        return false;
    }

    auto uploaderObj = GetUploader(uploaderId);
    if (!uploaderObj) {
        std::ostringstream oss;
        oss << "Error: " << funcname << " cannot find uploader '" << uploaderId << "'.";
        Logger::LogError(oss.str());
        return false;
    }
    uploaderObj->AddData(std::move(eventData));
    TRACEINFO(trace, "Msg added to EventHubUploader, persistDir: " + uploaderObj->GetPersistDir());
    return true;
}

// Sequentially wait (up to maxMilliSeconds each) for every uploader to drain.
void
EventHubUploaderMgr::WaitForFinish(
    int32_t maxMilliSeconds
    )
{
    Trace trace(Trace::MdsCmd, "EventHubUploaderMgr::WaitForFinish");
    for (auto & iter : m_ehUploaders) {
        iter.second->WaitForFinish(maxMilliSeconds);
    }
}

// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdscommands/EventHubUploaderMgr.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _EVENTHUBUPLOADERMGR_HH_
#define _EVENTHUBUPLOADERMGR_HH_

#include "EventHubUploader.hh"
#include "EventHubType.hh"
// NOTE(review): the targets of the bare #include directives below, and the
// template arguments in this header, were stripped during extraction.
#include
#include
#include
#include
#include
#include
#include
#include

namespace mdsd
{

struct EventHubUploaderId;

// Using the singleton pattern
class EventHubUploaderMgr
{
public:
    static EventHubUploaderMgr& GetInstance();

    // Singleton: neither copyable nor movable.
    EventHubUploaderMgr(const EventHubUploaderMgr &) = delete;
    EventHubUploaderMgr(EventHubUploaderMgr&&) = delete;
    EventHubUploaderMgr& operator=(const EventHubUploaderMgr&) = delete;
    EventHubUploaderMgr& operator=(EventHubUploaderMgr&&) = delete;

    bool SetTopLevelPersistDir(const std::string& persistDirTopLevel);

    /// Create EventHub uploaders for different EventHubType, moniker, eventname.
    /// eventMonikerMap key: eventname; value: monikernames.
    void CreateUploaders(EventHubType ehtype,
        const std::unordered_map> & eventMonikerMap);

    /// Set SAS Key for given EventHub uploader identified by an id string.
    /// Return true if the SAS key is set; return false otherwise.
    bool SetSasAndStart(const EventHubUploaderId& uploaderId, const std::string & ehSas);

    /// Add an EventHub data item to EventHub data uploader identified by an id string.
    /// Return true if data is added to uploader; return false otherwise.
    bool AddMessageToUpload(const EventHubUploaderId& uploaderId, EventDataT&& eventData);

    size_t GetNumUploaders() const { return m_ehUploaders.size(); }

    /// Wait for given time for all data to be uploaded.
    /// Return until all data are uploaded or timed out.
    /// maxMilliSeconds=-1 means forever.
    void WaitForFinish(int32_t maxMilliSeconds);

private:
    EventHubUploaderMgr() {}
    ~EventHubUploaderMgr() {}

    // Top-level directory for persisting EventHub messages.
    // There'll be a subdirectory for each accountmoniker/eventname combination.
    std::string m_persistDirTopLevel;   // e.g., "/var/mdsd"

    // Collection of all EHUploader objects
    typedef std::unique_ptr EhUploader_t;
    std::map m_ehUploaders;

    // multiple readers single writer locks for EH uploaders map.
    // NOTE: C++14 has std::shared_timed_mutex that can do the same thing. But it is not
    // available until GCC5.0.
    boost::shared_mutex m_mapMutex;

    std::string CreateAndGetPersistDir(EventHubType ehtype,
        const std::string& moniker, const std::string& eventname);

    EventHubUploader* GetUploader(const std::string & uploaderId);

    std::set> GetNewItemSet(
        EventHubType ehtype,
        const std::unordered_map> & eventMonikerMap);

    std::set> GetDroppedItemSet(
        EventHubType ehtype,
        const std::unordered_map> & eventMonikerMap);
};

} // namespace mdsd

#endif // _EVENTHUBUPLOADERMGR_HH_

================================================
FILE: Diagnostic/mdsd/mdscommands/EventPersistMgr.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
// NOTE(review): bare #include targets stripped during extraction.
#include

extern "C" {
#include
}

#include "EventPersistMgr.hh"
#include "MdsCmdLogger.hh"
#include "Trace.hh"
#include "PersistFiles.hh"
#include "EventHubPublisher.hh"
#include "Utility.hh"

using namespace mdsd::details;

// Bind the manager to its persist directory; items older than maxKeepSeconds
// become eligible for deletion instead of re-upload.
EventPersistMgr::EventPersistMgr(
    const std::string & persistDir,
    int32_t maxKeepSeconds
    ) :
    m_dirname(persistDir),
    m_persist(new PersistFiles(persistDir)),
    m_maxKeepSeconds(maxKeepSeconds),
    m_nFileProcessed{0}
{
}

// Out-of-line dtor: required because m_persist is a unique_ptr to a type
// that is only forward-declared in the header.
EventPersistMgr::~EventPersistMgr()
{
}

// Persist one event. Empty data is a no-op reported as success; any
// persistence exception is logged and reported as failure.
bool
EventPersistMgr::Add(
    const EventDataT & data
    )
{
    if (data.empty()) {
        return true;
    }
    try {
        return m_persist->Add(data);
    }
    catch(std::exception & ex) {
        MdsCmdLogError(std::string("Error: adding data to persistence hit exception: ") + ex.what());
    }
    return false;
}

// Number of persisted items currently on disk.
size_t
EventPersistMgr::GetNumItems() const
{
    return m_persist->GetNumItems();
}

// Synchronously walk every persisted item: delete items past max keep time,
// publish the rest, and delete each item whose publish succeeds.
// Returns true only when no publication error occurred.
bool
EventPersistMgr::UploadAllSync(
    std::shared_ptr publisher
    ) const
{
    Trace trace(Trace::MdsCmd, "EventPersistMgr::UploadAllSync");

    if (!publisher) {
        MdsCmdLogError("Error: EventPersistMgr::UploadAllSync(): unexpected NULL for publisher object.");
        return false;
    }

    int nPubErrs = 0;
    auto endIter = m_persist->cend();
    for (auto iter = m_persist->cbegin(); iter != endIter; ++iter) {
        auto item = *iter;
        auto ageInSeconds = m_persist->GetAgeInSeconds(item);
        assert(ageInSeconds >= 0);
        if (ageInSeconds >= m_maxKeepSeconds) {
            // Too old to be worth retrying: drop it.
            m_persist->Remove(item);
        }
        else {
            try {
                auto itemdata = m_persist->Get(item);
                if (publisher->Publish(itemdata)) {
                    m_persist->Remove(item);
                }
                else {
                    nPubErrs++;
                }
            }
            catch(std::exception & ex) {
                MdsCmdLogError(std::string("Error: EventPersistMgr UploadAllSync() hits exception: ") + ex.what());
                nPubErrs++;
            }
            usleep(100000); // throttle so retries don't flood the Azure service.
        }
    }

    if (nPubErrs) {
        std::ostringstream strm;
        strm << "Error: EventPersistMgr UploadAllSync() hit " << nPubErrs << " publication errors.";
        MdsCmdLogError(strm);
    }
    return (0 == nPubErrs);
}

// Check whether an I/O error is retryable.
// NOTE: this list may need to be adjusted based on actual errors found in the future.
// They are obtained from 'man 2 open', 'man 2 read', 'man 2 close'.
static inline bool
IsFileIOErrorRetryable(int errcode)
{
    switch(errcode) {
    // These indicate a permanent problem with the path/file itself:
    // retrying the same operation cannot succeed.
    case EACCES:
    case EISDIR:
    case ELOOP:
    case ENAMETOOLONG:
    case ENOTDIR:
    case EOVERFLOW:
    case EIO:
        return false;
    default:
        return true;
    }
    return true; // unreachable; kept for defensiveness
}

/// A convenient helper function to loop asynchronously until a condition is met.
/// NOTE: These functions are from CPPREST sample code.
pplx::task
_do_while_iteration(std::function(void)> func)
{
    pplx::task_completion_event ev;
    func().then([=](bool guard) { ev.set(guard); });
    return pplx::create_task(ev);
}

// Recursive step: keep iterating while the previous iteration returned true.
pplx::task
_do_while_impl(std::function(void)> func)
{
    return _do_while_iteration(func).then([=](bool guard) -> pplx::task
    {
        if(guard) {
            return ::_do_while_impl(func);
        }
        else {
            return pplx::task_from_result(false);
        }
    });
}

// Public entry: run func repeatedly (as chained tasks) until it yields false.
pplx::task
do_while(std::function(void)> func)
{
    return _do_while_impl(func).then([](bool){});
}

// Collect the names of all persisted files still worth uploading; files past
// max keep time are deleted (async) instead of being returned.
std::shared_ptr>
EventPersistMgr::GetAllFiles() const
{
    auto fqueue = std::make_shared>();
    auto endIter = m_persist->cend();
    for (auto iter = m_persist->cbegin(); iter != endIter; ++iter) {
        auto item = *iter;
        auto ageInSeconds = m_persist->GetAgeInSeconds(item);
        assert(ageInSeconds >= 0);
        if (ageInSeconds >= m_maxKeepSeconds) {
            m_persist->RemoveAsync(item);
        }
        else {
            fqueue->push(item);
        }
    }
    return fqueue;
}

// Move up to batchSize items from fullList into a new batch queue, or hand
// back fullList itself when it already fits within one batch.
static std::shared_ptr>
CreateBatch(
    std::shared_ptr> fullList,
    size_t batchSize
    )
{
    if (fullList->size() <= batchSize) {
        return fullList;
    }
    auto batch = std::make_shared>();
    for (size_t i = 0; i < batchSize; i++) {
        if (fullList->empty()) {
            break;
        }
        batch->push(fullList->front());
        fullList->pop();
    }
    return batch;
}

// Observe a finished task and log (instead of propagate) any exception.
static void
HandlePrevTaskFailure(
    pplx::task previous_task,
    const std::string & testname
    )
{
    try {
        previous_task.wait();
    }
    catch(const std::exception& ex) {
        MdsCmdLogError(testname + " has exception: " + ex.what());
    }
    catch(...)
    {
        MdsCmdLogError(testname + " has unknown exception.");
    }
}

// Calculate how many batches to use and each batch's size
// based on total items to process and max open file resource limit.
//
// Make sure maxBatches is used.
//
// The result is that totalItems can be divided into n batches, such that
// the first nExtraOne batches have batchSize+1 items, the rest
// (nbatches-nExtraOne) has batchSize items.
// e.g. totalItems=7, maxBatches=5, we want to have (2,2,1,1,1), where
// nbatches=5, batchSize=1, nExtraOne=2.
static void
CalcBatchInfo(
    size_t totalItems,
    size_t& nbatches,
    size_t& batchSize,
    size_t& nExtraOne
    )
{
    Trace trace(Trace::MdsCmd, "EventPersistMgr::CalcBatchInfo");

    auto fdLimit = MdsdUtil::GetNumFileResourceSoftLimit();
    if (0 == fdLimit) {
        // max open file is unlimited, each batch processes one file.
        nbatches = totalItems;
        batchSize = 1;
        nExtraOne = 0;
    }
    else {
        // max batches: 10% of max open files.
        // so that we won't run out of open files.
        size_t maxBatches = fdLimit / 10;
        nbatches = std::min(totalItems, maxBatches);
        batchSize = totalItems / nbatches;
        nExtraOne = totalItems % nbatches;
    }
    assert((nbatches*batchSize+nExtraOne) == totalItems);
    TRACEINFO(trace, "total=" << totalItems << "; nbatches=" << nbatches
        << "; batchSize=" << batchSize << "; nExtraOne=" << nExtraOne);
}

// Fire-and-forget async re-upload of every persisted file, fanned out into
// batches sized so concurrent batches stay within the open-file soft limit.
bool
EventPersistMgr::UploadAllAsync(
    std::shared_ptr publisher
    ) const
{
    Trace trace(Trace::MdsCmd, "EventPersistMgr::UploadAllAsync");

    if (!publisher) {
        MdsCmdLogError("Error: EventPersistMgr::UploadAllAsync(): unexpected NULL for publisher object.");
        return false;
    }

    auto allFileList = GetAllFiles();
    if (allFileList->empty()) {
        return true;
    }

    auto nFilesToProcess = allFileList->size();
    size_t nbatches = 0;
    size_t batchSize = 0;
    size_t nExtraOne = 0;
    CalcBatchInfo(nFilesToProcess, nbatches, batchSize, nExtraOne);

    // Capture a shared_ptr to this so detached tasks keep the mgr alive.
    auto shThis = shared_from_this();
    size_t nFilesInBatch = 0;
    for (size_t i = 0; i < nbatches; i++) {
        // The first nExtraOne batches carry one extra item each.
        auto nItems = (i < nExtraOne)? (batchSize+1) : batchSize;
        auto batch = CreateBatch(allFileList, nItems);
        nFilesInBatch += batch->size();
        pplx::task([shThis, publisher, batch]()
        {
            shThis->UploadFileBatch(publisher, batch);
        });
    }
    assert(nFilesInBatch == nFilesToProcess);
    return true;
}

// This function will process a list of files by using
// one open file handle only. It uses the async task idiom 'do_while'
// to process these files in an async task loop.
void
EventPersistMgr::UploadFileBatch(
    std::shared_ptr publisher,
    std::shared_ptr> flist
    ) const
{
    if (flist->empty()) {
        return;
    }
    auto shThis = shared_from_this();
    ::do_while([shThis, flist, publisher]()
    {
        if (flist->empty()) {
            return pplx::task_from_result(false);
        }
        auto fileItem = flist->front();
        flist->pop();
        return shThis->UploadOneFile(publisher, fileItem);
    })
    .then([](pplx::task previous_task)
    {
        HandlePrevTaskFailure(previous_task, "UploadFileBatch");
    });
}

// Read one persisted file asynchronously, publish its contents, then account
// for it; read failures are classified in HandleReadTaskFailure.
pplx::task
EventPersistMgr::UploadOneFile(
    std::shared_ptr publisher,
    const std::string & filePath
    ) const
{
    auto shThis = shared_from_this();
    return m_persist->GetAsync(filePath)
    .then([publisher, shThis, filePath](const EventDataT & fileData)
    {
        shThis->ProcessFileData(publisher, filePath, fileData);
    })
    .then([shThis, filePath](pplx::task previous_task)
    {
        shThis->m_nFileProcessed++;
        shThis->HandleReadTaskFailure(previous_task, filePath);
        return true;
    });
}

// Publish one file's data asynchronously; remove the file on publish success,
// log on failure. Empty data is a no-op.
void
EventPersistMgr::ProcessFileData(
    std::shared_ptr publisher,
    const std::string & item,
    const EventDataT & itemdata
    ) const
{
    if (itemdata.empty()) {
        return;
    }
    auto shThis = shared_from_this();
    publisher->PublishAsync(itemdata)
    .then([publisher, shThis, item](bool publishOK)
    {
        if (publishOK) {
            shThis->m_persist->RemoveAsync(item)
            .then([item](bool removeOK)
            {
                if (!removeOK) {
                    MdsCmdLogError("Error: EventPersistMgr::ProcessFileData failed to remove file "
                        + MdsdUtil::GetFileBasename(item));
                }
            });
        }
        else {
            MdsCmdLogError("Error: EventPersistMgr::ProcessFileData failed to upload file "
                + MdsdUtil::GetFileBasename(item));
        }
    })
    .then([item](pplx::task previous_task)
    {
        // Observe any exception from the publish/remove chain so it cannot
        // terminate the process.
        try {
            previous_task.wait();
        }
        catch(const std::exception& ex) {
            MdsCmdLogError("Error: failed to publish EH file " + MdsdUtil::GetFileBasename(item)
                + ". Exception: " + std::string(ex.what()));
        }
        catch(...) {
            MdsCmdLogError("Error: failed to publish EH file " + MdsdUtil::GetFileBasename(item)
                + " with unknown exception.");
        }
    });
}

// Classify a failed read task: retryable I/O errors are only logged (the
// file stays on disk for the next resend pass); non-retryable ones also
// delete the file so it is not retried forever.
void
EventPersistMgr::HandleReadTaskFailure(
    pplx::task readTask,
    const std::string & item
    ) const
{
    try {
        readTask.wait();
    }
    catch(const std::system_error & ex) {
        auto ec = ex.code().value();
        if (IsFileIOErrorRetryable(ec)) {
            MdsCmdLogWarn("Warning: failed to publish EH file " + MdsdUtil::GetFileBasename(item)
                + ". Exception: " + std::string(ex.what()) + ". Retry next time.");
        }
        else {
            MdsCmdLogError("Error: failed to publish EH file " + MdsdUtil::GetFileBasename(item)
                + ". Exception: " + std::string(ex.what()) + ". Remove file.");
            m_persist->RemoveAsync(item);
        }
    }
    catch(const std::exception& ex) {
        // To be conservative: for exception without details, retry them later.
        MdsCmdLogError("Error: failed to publish EH file " + MdsdUtil::GetFileBasename(item)
            + ". Exception: " + std::string(ex.what()) + ". Retry next time.");
    }
    catch(...) {
        MdsCmdLogError("Error: failed to publish EH file " + MdsdUtil::GetFileBasename(item)
            + " with unknown exception.");
    }
}

// vim: sw=4 expandtab :

================================================
FILE: Diagnostic/mdsd/mdscommands/EventPersistMgr.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef __EVENTPERSISTMGR__HH__
#define __EVENTPERSISTMGR__HH__

// NOTE(review): bare #include targets stripped during extraction.
#include
#include
#include
#include
#include

#include "EventData.hh"
#include

namespace mdsd { namespace details {

class PersistFiles;
class EventHubPublisher;

/// This class implements the functionality to persist events
/// that are failed to be sent to Event Hub.
/// It will save given event to persistence and do regular retry on them.
/// When retry succeeds, the event will be removed from persistence.
/// An event has a max persistence time. After that time, it will
/// be removed from persistence.
class EventPersistMgr : public std::enable_shared_from_this
{
    /// Construct a new object.
    /// persistDir: directory name to persist data.
    /// maxKeepSeconds: max seconds to keep the data. After this time,
    /// it could be removed at any time.
    /// (private: instances must be built via create() so that
    /// shared_from_this() is always valid)
    EventPersistMgr(const std::string & persistDir, int32_t maxKeepSeconds);

public:
    /// Factory: the only way to construct an EventPersistMgr.
    static std::shared_ptr create(
        const std::string & persistDir,
        int32_t maxKeepSeconds)
    {
        return std::shared_ptr(new EventPersistMgr(persistDir, maxKeepSeconds));
    }

    /// NOTE: because this class defines a unique_ptr with forward-declared type,
    /// the destructor must be implemented in the *cc file.
    ~EventPersistMgr();

    // movable but not copyable
    EventPersistMgr(const EventPersistMgr& other) = delete;
    EventPersistMgr(EventPersistMgr&& other) = default;
    EventPersistMgr& operator=(const EventPersistMgr& other) = delete;
    EventPersistMgr& operator=(EventPersistMgr&& other) = default;

    /// Save given data as persistence object.
    /// Return true if success, false if any error.
    /// If data is empty, return true and do nothing.
    bool Add(const EventDataT & data);

    /// Return number of files on the disk
    size_t GetNumItems() const;

    /// Return number of files read and processed from persist dir.
    /// This doesn't include files deleted when they are too old to keep.
    size_t GetNumFileProcessed() const { return m_nFileProcessed; }

    /// Go through each persistence object: if it is too old (beyond
    /// max keep time), it will be removed. if it is not too old, it
    /// will be uploaded. If the upload succeeds, it will be removed.
    /// If upload fails, do nothing to it.
    /// Return true if success, false if any error.
    bool UploadAllSync(std::shared_ptr publisher) const;

    /// Upload all events asynchronously. This is a "fire and forget"
    /// function. It doesn't wait for the async tasks to finish.
    /// Upload failure will be logged but won't be show in this function
    /// return status.
    /// Return true if success, false if any error.
    bool UploadAllAsync(std::shared_ptr publisher) const;

private:
    /// Process the data read from file, including publishing the data to EventHub.
    /// If data are empty, do nothing.
    void ProcessFileData(std::shared_ptr publisher,
        const std::string & item, const EventDataT & itemdata) const;

    /// Handle any GetAsync() task failures.
    void HandleReadTaskFailure(pplx::task readTask, const std::string & item) const;

    /// Return the names of the file to be uploaded. The files that are too old to upload
    /// will be removed from disk.
    std::shared_ptr> GetAllFiles() const;

    pplx::task UploadOneFile(std::shared_ptr publisher,
        const std::string & filePath) const;

    void UploadFileBatch(std::shared_ptr publisher,
        std::shared_ptr> flist) const;

private:
    std::string m_dirname;                // Persistence directory full path.
    std::unique_ptr m_persist;            // The persist mgr persists the data to files.
    int32_t m_maxKeepSeconds;             // max seconds to keep the data.
    mutable std::atomic m_nFileProcessed; // number of files read from persistence dir.
};

} // namespace details
} // namespace mdsd

#endif // __EVENTPERSISTMGR__HH__

================================================
FILE: Diagnostic/mdsd/mdscommands/MdsBlobReader.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
// NOTE(review): the header names of these includes were stripped by text
// extraction; tokens kept as found.
#include
#include
#include
#include
#include
#include
#include
#include "MdsBlobReader.hh"
#include "Trace.hh"

extern "C" {
#include
}

using namespace azure::storage;
using namespace mdsd::details;

/// Construct a reader for one blob under the given storage container URI.
/// Throws MdsException when storageUri is empty, or when a parent path is
/// given but the blob name is empty.
MdsBlobReader::MdsBlobReader(
    std::string storageUri,
    std::string blobName,
    std::string parentPath
    ) :
    m_storageUri(std::move(storageUri)),
    m_blobName(std::move(blobName)),
    m_parentPath(std::move(parentPath))
{
    if (m_storageUri.empty()) {
        throw MDSEXCEPTION("Storage URI cannot be empty.");
    }
    if (!m_parentPath.empty() && m_blobName.empty()) {
        throw MDSEXCEPTION("Blob name cannot be empty when inside a container.");
    }
}

/// Get exception message from exception_ptr.
/// Returns an empty string for a null pointer or a payload that is not
/// derived from std::exception.
static std::string GetEptrMsg(std::exception_ptr eptr) // passing by value is ok
{
    try {
        if (eptr) {
            std::rethrow_exception(eptr);
        }
    }
    catch(const std::exception& e) {
        return e.what();
    }
    return std::string();
}

/// Log a storage exception, including its HTTP status code, any extended
/// error info and any inner exception. Logging only; nothing is rethrown.
static void HandleStorageException(
    const storage_exception& ex
    )
{
    auto result = ex.result();
    auto httpcode = result.http_status_code();

    std::ostringstream strm;
    strm << "Error: storage exception in reading MDS blob: "
         << "Http status code=" << httpcode << "; "
         << "Message: " << ex.what() << ". ";

    auto err = result.extended_error();
    if (!err.message().empty()) {
        strm << "Extended info: " << err.message() << ". ";
    }

    auto innerEx = GetEptrMsg(ex.inner_exception());
    if (!innerEx.empty()) {
        strm << "Inner exception: " << innerEx << ".";
    }
    MdsCmdLogError(strm);
}

/// Build an operation_context carrying the given client request id so the
/// request can be correlated with server-side logs.
static operation_context CreateOperationContext(const std::string& reqId)
{
    operation_context op;
    op.set_client_request_id(reqId);
    return op;
}

/// Default blob request options with an exponential retry policy attached.
static blob_request_options BlobRequestOptionsWithRetry()
{
    auto requestOpt = blob_request_options();
    exponential_retry_policy retryPolicy;
    requestOpt.set_retry_policy(retryPolicy);
    return requestOpt;
}

/// Resolve the cloud_blob object for m_blobName (optionally under
/// m_parentPath) and verify it exists.
/// Throws BlobNotFoundException when the directory or the blob is missing.
cloud_blob
MdsBlobReader::GetBlob() const
{
    Trace trace(Trace::MdsCmd, "MdsBlobReader::GetBlob");

    web::http::uri webUri = {m_storageUri};
    storage_uri uriObj = {webUri};
    cloud_blob blob;
    cloud_blob_container containerObj(uriObj);

    if (m_parentPath.empty()) {
        blob = containerObj.get_blob_reference(m_blobName);
    }
    else {
        auto dirObj = containerObj.get_directory_reference(m_parentPath);
        if (!dirObj.is_valid()) {
            std::ostringstream strm;
            strm << "Failed to get container directory '" << m_parentPath << "'.";
            throw BlobNotFoundException(strm.str());
        }
        blob = dirObj.get_blob_reference(m_blobName);
    }

    // Fresh request id per existence check for server-side correlation.
    auto requestId = utility::uuid_to_string(utility::new_uuid());
    auto op = CreateOperationContext(requestId);
    if (!blob.exists(BlobRequestOptionsWithRetry(), op)) {
        std::ostringstream strm;
        strm << "Failed to find blob '" << m_blobName
             << "' in parent path '" << m_parentPath << "'."
             << "Request id: " << requestId << ".";
        throw BlobNotFoundException(strm.str());
    }
    return blob;
}

/// Download the blob's content to the given local file.
/// Errors are logged, not rethrown; a missing blob logs only a warning.
void
MdsBlobReader::ReadBlobToFile(
    const std::string & filepath
    ) const
{
    if (filepath.empty()) {
        throw MDSEXCEPTION("Filepath name to save blob data cannot be empty.");
    }
    std::string requestId;
    try {
        auto blob = GetBlob();
        requestId = utility::uuid_to_string(utility::new_uuid());
        auto op = CreateOperationContext(requestId);
        blob.download_to_file(filepath, access_condition(), BlobRequestOptionsWithRetry(), op);
    }
    catch(const storage_exception & ex) {
        HandleStorageException(ex);
        if (!requestId.empty()) {
            MdsCmdLogError("Request id: " + requestId);
        }
    }
    catch(const BlobNotFoundException& ex) {
        MdsCmdLogWarn("Specified blob " + m_blobName + " is not found: " + ex.what());
    }
}

/// Read the whole blob into a string.
/// Returns the content, or an empty string on any error (errors are logged).
std::string
MdsBlobReader::ReadBlobToString() const
{
    Trace trace(Trace::MdsCmd, "MdsBlobReader::ReadBlobToString");

    std::string requestId;
    try {
        auto blob = GetBlob();
        requestId = utility::uuid_to_string(utility::new_uuid());
        auto op = CreateOperationContext(requestId);
        auto streamObj = blob.open_read(access_condition(), BlobRequestOptionsWithRetry(), op);
        concurrency::streams::container_buffer cbuf;
        // Synchronous wait: .get() blocks until the stream is drained.
        streamObj.read_to_end(cbuf).get();
        streamObj.close();
        return cbuf.collection();
    }
    catch(const storage_exception & ex) {
        HandleStorageException(ex);
        if (!requestId.empty()) {
            MdsCmdLogError("Request id: " + requestId);
        }
    }
    catch(const BlobNotFoundException & ex) {
        MdsCmdLogWarn("Specified blob " + m_blobName + " is not found: " + ex.what());
    }
    return std::string();
}

/// Start an async read of the blob into a string.
/// The returned task yields the content, or an empty string on any error.
/// NOTE: GetBlob() itself (existence check) still runs synchronously here.
pplx::task
MdsBlobReader::ReadBlobToStringAsync() const
{
    Trace trace(Trace::MdsCmd, "MdsBlobReader::ReadBlobToStringAsync");

    std::string requestId;
    try {
        auto blob = GetBlob();
        requestId = utility::uuid_to_string(utility::new_uuid());
        auto op = CreateOperationContext(requestId);
        auto asyncReadTask = blob.open_read_async(access_condition(), BlobRequestOptionsWithRetry(), op);
        // Capture by value: requestId must outlive this stack frame.
        return asyncReadTask.then([=](concurrency::streams::istream streamObj) {
            try {
                concurrency::streams::container_buffer cbuf;
                streamObj.read_to_end(cbuf).get();
                streamObj.close();
                return cbuf.collection();
            }
            catch (const storage_exception& ex) {
                HandleStorageException(ex);
                if (!requestId.empty()) {
                    MdsCmdLogError("Request id: " + requestId);
                }
            }
            // NOTE(review): stream is not closed on the exception path —
            // confirm whether istream's destructor releases it.
            return std::string();
        });
    }
    catch(const storage_exception & ex) {
        HandleStorageException(ex);
        if (!requestId.empty()) {
            MdsCmdLogError("Request id: " + requestId);
        }
    }
    catch(const BlobNotFoundException & ex) {
        MdsCmdLogWarn("Specified blob " + m_blobName + " is not found: " + ex.what());
    }
    return pplx::task([](){ return std::string(); });
}

/// Return the blob's last-modified timestamp (via download_attributes),
/// or 0 on any error. BlobNotFoundException is routed to the supplied handler.
uint64_t
MdsBlobReader::GetLastModifiedTimeStamp(
    std::function blobNotFoundExHandler) const
{
    uint64_t lastModifiedTimeStamp = 0;
    std::string requestId;
    try {
        auto blob = GetBlob();
        requestId = utility::uuid_to_string(utility::new_uuid());
        auto op = CreateOperationContext(requestId);
        blob.download_attributes(access_condition(), BlobRequestOptionsWithRetry(), op);
        lastModifiedTimeStamp = blob.properties().last_modified().to_interval();
    }
    catch(const storage_exception & ex) {
        HandleStorageException(ex);
        if (!requestId.empty()) {
            MdsCmdLogError("Request id: " + requestId);
        }
    }
    catch(const BlobNotFoundException & ex) {
        blobNotFoundExHandler(this, ex);
    }
    return lastModifiedTimeStamp;
}

/// Async variant of GetLastModifiedTimeStamp(); the returned task yields
/// the timestamp, or 0 when any error occurred during setup.
pplx::task
MdsBlobReader::GetLastModifiedTimeStampAsync(
    std::function blobNotFoundExHandler) const
{
    uint64_t lastModifiedTimeStamp = 0;
    std::string requestId;
    try {
        auto blob = GetBlob();
        requestId = utility::uuid_to_string(utility::new_uuid());
        auto op = CreateOperationContext(requestId);
        auto asyncAttrDownloadTask = blob.download_attributes_async(access_condition(), BlobRequestOptionsWithRetry(), op);
        // 'blob' is copied into the lambda so properties() stays valid.
        return asyncAttrDownloadTask.then([=]() {
            return blob.properties().last_modified().to_interval();
        });
    }
    catch(const storage_exception & ex) {
        HandleStorageException(ex);
        if (!requestId.empty()) {
            MdsCmdLogError("Request id: " + requestId);
        }
    }
    catch(const BlobNotFoundException & ex) {
        blobNotFoundExHandler(this, ex);
    }
    return pplx::task([=]() {
        return lastModifiedTimeStamp; // = 0
    });
}

================================================
FILE: Diagnostic/mdsd/mdscommands/MdsBlobReader.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef __MDSBLOBREADER__HH__
#define __MDSBLOBREADER__HH__

#include
#include
#include
#include
#include "MdsException.hh"
#include "MdsCmdLogger.hh"

namespace azure { namespace storage { class cloud_blob; } } // namespace azure

namespace mdsd { namespace details {

/// Implement a class to read blob from azure storage related to MDS.
class MdsBlobReader
{
public:
    /// Construct a new blob reader.
    /// storageUri: the absolute URI to the blob root container
    /// blobName:   blob name
    /// parentPath: the given blob's parent container name
    MdsBlobReader(std::string storageUri, std::string blobName = "", std::string parentPath = "");
    ~MdsBlobReader() {}

    // Fully copyable and movable.
    MdsBlobReader(const MdsBlobReader& other) = default;
    MdsBlobReader(MdsBlobReader&& other) = default;
    MdsBlobReader& operator=(const MdsBlobReader& other) = default;
    MdsBlobReader& operator=(MdsBlobReader&& other) = default;

    /// Read current blob object to a given file.
    void ReadBlobToFile(const std::string & filepath) const;

    /// Read current blob object to a string.
    /// Return the blob content, or empty string if any error.
    std::string ReadBlobToString() const;

    /// Start async reading of current blob object to a string.
    /// Return the task whose result will be the string.
    pplx::task ReadBlobToStringAsync() const;

    /// Returns the read blob's LMT (# seconds since epoch).
    /// 0 will be returned if blob doesn't exist
    /// or if any exception is thrown (e.g., storage exception)
    uint64_t GetLastModifiedTimeStamp(
        std::function blobNotFoundExHandler) const;

    /// Start async reading of blob's LMT (# seconds since epoch).
/// Return the task whose result will be the the blob's LMT. /// pplx::task GetLastModifiedTimeStampAsync( std::function blobNotFoundExHandler) const; // Typical BlobNotFoundException handlers provided here static void DoNothingBlobNotFoundExHandler(const MdsBlobReader*, const BlobNotFoundException&) {} static void LogWarnBlobNotFoundExHandler(const MdsBlobReader*, const BlobNotFoundException& ex) { MdsCmdLogWarn("Specified blob is not found: " + std::string(ex.what())); } private: /// /// Get current blob object. /// azure::storage::cloud_blob GetBlob() const; private: std::string m_storageUri; std::string m_blobName; std::string m_parentPath; }; } // namespace details } // namespace mdsd #endif // __MDSBLOBREADER__HH__ ================================================ FILE: Diagnostic/mdsd/mdscommands/MdsCmdLogger.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __MDSCMDLOGGER__HH__ #define __MDSCMDLOGGER__HH__ #include "Logger.hh" namespace mdsd { namespace details { inline void MdsCmdLogError(const std::string & msg) { Logger::LogError("MDSCMD " + msg); } inline void MdsCmdLogError(const std::ostringstream& strm) { MdsCmdLogError(strm.str()); } inline void MdsCmdLogWarn(const std::string & msg) { Logger::LogWarn("MDSCMD " + msg); } inline void MdsCmdLogWarn(const std::ostringstream& strm) { MdsCmdLogWarn(strm.str()); } } // namespace details } // namespace mdsd #endif // __MDSCMDLOGGER__HH__ ================================================ FILE: Diagnostic/mdsd/mdscommands/MdsException.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include
#include "MdsException.hh"

using namespace mdsd;

/// Return the path component after the last '/', or the whole path if there
/// is no '/'.
static std::string GetFileBasename(
    const std::string & filepath
    )
{
    auto p = filepath.find_last_of('/');
    if (p == std::string::npos) {
        return filepath;
    }
    return filepath.substr(p+1);
}

/// Build the message as "<file basename>:<line> <message>".
/// filename may be null (e.g. from TooBigEventHubDataException), in which
/// case only the message text is stored.
MdsException::MdsException(
    const char* filename,
    int lineno,
    const std::string & message) :
    std::exception()
{
    std::ostringstream strm;
    if (filename) {
        strm << GetFileBasename(filename) << ":" << lineno << " ";
    }
    strm << message;
    m_msg = strm.str();
}

/// Same as above for a C-string message; a null message leaves m_msg empty.
MdsException::MdsException(
    const char* filename,
    int lineno,
    const char* message) :
    std::exception()
{
    if (message) {
        std::ostringstream strm;
        if (filename) {
            strm << GetFileBasename(filename) << ":" << lineno << " ";
        }
        strm << message;
        m_msg = strm.str();
    }
}

================================================
FILE: Diagnostic/mdsd/mdscommands/MdsException.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef __MDSEXCEPTION__HH__
#define __MDSEXCEPTION__HH__

#include
#include

/// Construct an MdsException tagged with the throw site's file and line.
#define MDSEXCEPTION(message) \
    mdsd::MdsException(__FILE__, __LINE__, message) /**/

namespace mdsd {

/// Base exception for mdscommands; what() returns "<file>:<line> <message>".
class MdsException : public std::exception
{
private:
    std::string m_msg;  // Pre-formatted message returned by what().

public:
    MdsException(const char* filename, int lineno, const std::string & message);
    MdsException(const char* filename, int lineno, const char* message);

    virtual const char * what() const noexcept { return m_msg.c_str(); }
};

/// Thrown when a requested blob or its parent directory cannot be found.
class BlobNotFoundException : public std::exception
{
private:
    std::string m_msg;

public:
    BlobNotFoundException(std::string message) noexcept :
        std::exception(),
        m_msg(std::move(message))
    {}

    virtual const char * what() const noexcept { return m_msg.c_str(); }
};

/// Thrown when event data exceeds the EventHub size limit; carries no
/// file/line info (nullptr, 0).
class TooBigEventHubDataException : public MdsException
{
public:
    TooBigEventHubDataException(const std::string & msg) : MdsException(nullptr, 0, msg) {}
    TooBigEventHubDataException(const char* msg) : MdsException(nullptr, 0, msg) {}
};

}
#endif // __MDSEXCEPTION__HH__

================================================
FILE: Diagnostic/mdsd/mdscommands/PersistFiles.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include
#include
#include
#include
#include
#include
#include

extern "C" {
#include
#include
#include
}

#include "PersistFiles.hh"
#include "MdsCmdLogger.hh"
#include "Trace.hh"
#include "MdsException.hh"
#include "Utility.hh"

using namespace mdsd;
using namespace mdsd::details;

// If the filepath exists and it is a dir, return true;
// otherwise, return false.
static bool IsDirExists(
    const std::string& filepath
    )
{
    struct stat sb;
    auto rtn = stat(filepath.c_str(), &sb);
    // NOTE(review): sb.st_mode is read even when stat() failed; benign
    // because the result is ANDed with (0 == rtn), but sb is then
    // technically uninitialized.
    mode_t mode = sb.st_mode;
    return (0 == rtn && S_ISDIR(mode));
}

/// Construct a persister rooted at 'dirname'.
/// m_fileTemplate holds "<dirname>/" followed by room for the mkstemp
/// suffix ("XXXXXX") plus the terminating NUL (hence size + suffix + 2).
/// Throws MdsException when the directory does not exist.
PersistFiles::PersistFiles(
    const std::string & dirname
    ) :
    m_dirname(dirname),
    m_suffix("XXXXXX"),
    m_fileTemplate(new char[dirname.size()+m_suffix.size()+2])
{
    if (!IsDirExists(m_dirname)) {
        throw MDSEXCEPTION(std::string("Failed to find directory '") + m_dirname + "'.");
    }
    // Pre-fill the directory prefix; the suffix part is rewritten per call.
    snprintf(m_fileTemplate.get(), dirname.size()+2, "%s/", dirname.c_str());
}

/// Create a unique file in the persist dir via mkstemp().
/// Returns an open file descriptor, or -1 on error (logged).
/// NOTE(review): mutates the shared m_fileTemplate buffer from a const
/// method — presumably not called concurrently; confirm thread-safety.
int PersistFiles::CreateUniqueFile() const
{
    // reset template for mkstemp
    auto offset = m_dirname.size()+1;
    auto sz = m_suffix.size() + 1;
    snprintf(m_fileTemplate.get()+offset, sz, "%s", m_suffix.c_str());

    int fd = mkstemp(m_fileTemplate.get());
    if (-1 == fd) {
        auto errnum = errno;
        std::error_code ec(errnum, std::system_category());
        std::ostringstream strm;
        strm << "Error: creating unique persist file with mkstemp() failed. errno="
             << errnum << "; Reason: " << ec.message();
        MdsCmdLogError(strm);
    }
    return fd;
}

/// Serialize 'data' into a new unique file.
/// Returns true on success (or empty data), false on any error.
bool PersistFiles::Add(
    const EventDataT& data
    ) const
{
    if (data.empty()) {
        return true;
    }

    auto fd = CreateUniqueFile();
    if (fd < 0) {
        return false;
    }
    // RAII: the descriptor is closed on every return path.
    MdsdUtil::FdCloser fdCloser(fd);

    bool resultOK = true;
    auto datastr = data.Serialize();
    // NOTE(review): a short (partial) write is not detected, only write()
    // returning -1 is treated as failure.
    if (-1 == write(fd, datastr.c_str(), datastr.size())) {
        std::error_code ec(errno, std::system_category());
        MdsCmdLogError("Error: write() to persist file failed. Reason: " + ec.message());
        resultOK = false;
    }
    return resultOK;
}

/// Read and deserialize the whole file at 'filepath'.
/// Throws MdsException on an empty path or when the file cannot be opened.
EventDataT PersistFiles::Get(
    const std::string& filepath
    ) const
{
    if (filepath.empty()) {
        throw MDSEXCEPTION("Empty string is used for file path parameter.");
    }

    std::ifstream fin(filepath);
    if (!fin) {
        throw MDSEXCEPTION("Failed to open file '" + filepath + "'.");
    }

    // Size the buffer by seeking to the end first.
    // NOTE(review): tellg() returns -1 on failure, which would wrap to a
    // huge size_t here — confirm the stream is always seekable.
    fin.seekg(0, fin.end);
    size_t fsize = fin.tellg();
    fin.seekg(0, fin.beg);

    std::vector buf(fsize);
    fin.read(buf.data(), fsize);
    fin.close();

    return EventDataT::Deserialize(buf.data(), fsize);
}

/// Remove the file at 'filepath'.
/// Returns true on success (or empty path), false on error (logged).
bool PersistFiles::Remove(
    const std::string& filepath
    ) const
{
    if (filepath.empty()) {
        return true;
    }
    if (remove(filepath.c_str())) {
        std::error_code ec(errno, std::system_category());
        MdsCmdLogError("Error: failed to remove persist file '" + filepath +
                       "'. Reason: " + ec.message());
        return false;
    }
    return true;
}

/// Remove 'filepath' asynchronously. The task yields true on success;
/// any exception from Remove() is logged and yields false.
pplx::task PersistFiles::RemoveAsync(
    const std::string& filepath
    ) const
{
    Trace trace(Trace::MdsCmd, "PersistFiles::RemoveAsync");

    if (filepath.empty()) {
        return pplx::task_from_result(true);
    }

    return pplx::task([=]() -> bool {
        return Remove(filepath);
    })
    .then([](pplx::task previous_task) {
        try {
            return previous_task.get();
        }
        catch(std::exception& ex) {
            MdsCmdLogError("PersistFiles::RemoveAsync failed with " + std::string(ex.what()));
        }
        catch(...) {
            MdsCmdLogError("PersistFiles::RemoveAsync failed with unknown exception.");
        }
        return false;
    });
}

/// Age in seconds since the file's last modification (st_mtime),
/// or -1 when the file cannot be stat'ed (logged).
int32_t PersistFiles::GetAgeInSeconds(
    const std::string & filepath
    ) const
{
    struct stat sb;
    auto rtn = stat(filepath.c_str(), &sb);
    if (rtn) {
        std::error_code ec(errno, std::system_category());
        MdsCmdLogError("Error: failed to locate persist file '" + filepath +
                       "'. Reason: " + ec.message());
        return -1;
    }
    auto now = time(nullptr);
    return static_cast(now - sb.st_mtime);
}

/// Begin-iterator over the entries of the persist directory.
PersistFiles::const_iterator PersistFiles::cbegin() const
{
    DirectoryIter diter{m_dirname};
    return diter;
}

/// End (default-constructed) directory iterator.
PersistFiles::const_iterator PersistFiles::cend() const
{
    DirectoryIter diter;
    return diter;
}

/// Count the entries by walking the directory once.
size_t PersistFiles::GetNumItems() const
{
    size_t count = 0;
    auto endIter = cend();
    for (auto iter = cbegin(); iter != endIter; ++iter) {
        count++;
    }
    return count;
}

/// Asynchronously read and deserialize 'filepath'.
/// The task yields the data, or an empty EventDataT on any error (logged).
pplx::task PersistFiles::GetAsync(
    const std::string & filepath
    ) const
{
    if (filepath.empty()) {
        MdsCmdLogError("Error: GetAsync: unexpected empty filepath.");
        return pplx::task_from_result(EventDataT());
    }

    return concurrency::streams::file_stream::open_istream(filepath)
        .then([filepath](concurrency::streams::basic_istream inFile) {
            if (!inFile.is_open()) {
                MdsCmdLogError("Error: PersistFiles failed to open file '" + filepath + "'.");
                return pplx::task_from_result(EventDataT());
            }
            else {
                concurrency::streams::container_buffer buf;
                return inFile.read_to_end(buf)
                    .then([inFile, filepath, buf](size_t bytesRead) {
                        inFile.close();
                        if (bytesRead > 0) {
                            return pplx::task_from_result(EventDataT::Deserialize(buf.collection()));
                        }
                        MdsCmdLogError("Error: no data is read from '" + filepath + "', unexpected empty file.");
                        return pplx::task_from_result(EventDataT());
                    });
            }
        });
}

================================================
FILE: Diagnostic/mdsd/mdscommands/PersistFiles.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef __PERSISTFILES__HH__
#define __PERSISTFILES__HH__

#include
#include
#include "DirectoryIter.hh"
#include "EventData.hh"

namespace mdsd { namespace details {

/// Persists serialized events as one file each in a given directory.
class PersistFiles
{
public:
    typedef DirectoryIter const_iterator;

    /// Constructor. It will persist files to given directory.
    /// Throw MdsException if it fails to access the directory.
/// PersistFiles(const std::string & dirname); virtual ~PersistFiles() {} /// /// Add given data to a new, unique file. /// Return true if success, false if any error. /// If 'data' is empty, return true and do nothing. /// bool Add(const EventDataT& data) const; /// /// Get the content of the file given filepath. /// Return file content or throw exception if any error. /// EventDataT Get(const std::string& filepath) const; /// /// Get the content of the file asynchronously given filepath. /// Return the task for file content, or task for empty string if any error. /// pplx::task GetAsync(const std::string& filepath) const; /// /// Remove a filepath. /// Return true if success, false if any error. /// bool Remove(const std::string & filepath) const; /// /// Remove a filepath asynchronously. /// Return true if success, false if any error. /// pplx::task RemoveAsync(const std::string & filepath) const; /// /// Get a file's last modification time. /// If the file doesn't exit, return -1. /// int32_t GetAgeInSeconds(const std::string & filepath) const; const_iterator cbegin() const; const_iterator cend() const; /// /// Get number of items in persist. /// size_t GetNumItems() const; private: /// /// Create a unique file. Return an open file descriptor, or -1 if any error. /// int CreateUniqueFile() const; private: std::string m_dirname; std::string m_suffix; std::unique_ptr m_fileTemplate; }; } // namespace details } // namespace mdsd #endif // __PERSISTFILES__HH__ ================================================ FILE: Diagnostic/mdsd/mdscommands/PublisherStatus.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include #include #include "PublisherStatus.hh" using namespace mdsd::details; // To prevent static initialization order fiasco. 
// see https://isocpp.org/wiki/faq/ctors#static-init-order-on-first-use static std::map & GetPublisherStatusMap() { static auto enumMap = new std::map( { { PublisherStatus::Idle, "Idle" }, { PublisherStatus::PublicationSucceeded, "PublicationSucceeded" }, { PublisherStatus::PublicationFailedWithUnknownReason, "PublicationFailedWithUnknownReason" }, { PublisherStatus::PublicationFailedWithBadRequest, "PublicationFailedWithBadRequest" }, { PublisherStatus::PublicationFailedWithAuthError, "PublicationFailedWithAuthError" }, { PublisherStatus::PublicationFailedServerBusy, "PublicationFailedServerBusy" }, { PublisherStatus::PublicationFailedThrottled, "PublicationFailedThrottled" } }); return *enumMap; } std::ostream& operator<<( std::ostream& os, PublisherStatus status ) { auto enumMap = GetPublisherStatusMap(); auto iter = enumMap.find(status); if (iter == enumMap.end()) { os << "Unknown PublisherStatus"; } else { os << iter->second; } return os; } ================================================ FILE: Diagnostic/mdsd/mdscommands/PublisherStatus.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __PUBLISHERSTATUS_HH__ #define __PUBLISHERSTATUS_HH__ #include namespace mdsd { namespace details { enum class PublisherStatus { /// /// Object has not started any work. /// Idle, /// /// The last publication attempt succeeded. /// PublicationSucceeded, /// /// The last publication attempt failed. /// PublicationFailedWithUnknownReason, /// /// The last publication attempt failed with bad request error. /// PublicationFailedWithBadRequest, /// /// The last publication attempt failed with auth error. /// PublicationFailedWithAuthError, /// /// The last publication attempt failed because server is busy, need to retry later. /// PublicationFailedServerBusy, /// /// The last publication attempt failed because of throttled, need to retry later. 
    PublicationFailedThrottled
};

} // namespace details
} // namespace mdsd

/// Stream insertion for PublisherStatus (defined in PublisherStatus.cc).
std::ostream& operator<<(
    std::ostream& os,
    mdsd::details::PublisherStatus status
    );

#endif // __PUBLISHERSTATUS_HH__

================================================
FILE: Diagnostic/mdsd/mdscommands/commands.xsd
================================================


================================================
FILE: Diagnostic/mdsd/mdsd/Batch.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "Batch.hh"
#include "MdsdConfig.hh"
#include "Credentials.hh"
#include "MdsEntityName.hh"
#include "CanonicalEntity.hh"
#include "IMdsSink.hh"
#include "Trace.hh"
#include "Logger.hh"
#include "Utility.hh"
#include

using std::string;

/// Construct a batch for one (target, creds) destination, flushing on the
/// given query interval. Creates the sink and validates access up front
/// (either step may throw — see BatchSet::GetBatch, which catches).
Batch::Batch(MdsdConfig* config, const MdsEntityName& target, const Credentials* creds, int interval)
    : _config(config), _batchQIBase(0), _interval(interval),
      _sink(IMdsSink::CreateSink(config, target, creds)), _dirty(false)
{
    Trace trace(Trace::Batching, "Batch constructor");
    if (trace.IsActive()) {
        std::ostringstream msg;
        msg << "Created batch " << this << " (eventName " << target << " QI " << interval << ")";
        trace.NOTE(msg.str());
    }
    _sink->ValidateAccess();
}

/// Add a row; if the row's rounded timestamp starts a new query interval,
/// the pending batch is flushed first.
void Batch::AddRow(const CanonicalEntity & row)
{
    Trace trace(Trace::Batching, "Batch::AddRow");
    if (trace.IsActive()) {
        std::ostringstream msg;
        msg << "Batch " << this << " add CE " << row;
        trace.NOTE(msg.str());
    }

    MdsTime qibase = row.GetPreciseTimeStamp().Round(_interval);

    std::lock_guard lock(_mutex);

    // If the Query Interval Base has changed, then flush the batch.
    if (qibase != _batchQIBase) {
        if (trace.IsActive()) {
            std::ostringstream msg;
            msg << "Query Interval Base changed from " << _batchQIBase << " to " << qibase;
            trace.NOTE(msg.str());
        }
        _sink->Flush();
        _batchQIBase = qibase;
    }
    _sink->AddRow(row, qibase);     // May cause flush...
    MarkTime();
}

// Add a row to a batch destined for SchemasTable in some storage acct. This is a helper function
// for building these rows correctly.
void Batch::AddSchemaRow(const MdsEntityName &target, const string &hash, const string &schema)
{
    Trace trace(Trace::Batching, "Batch::AddSchemaRow");

    CanonicalEntity row(3);
    row.AddColumn("PhysicalTableName", target.Basename());
    row.AddColumn("MD5Hash", hash);
    row.AddColumn("Schema", schema);

    AddRow(row);
}

/// Push all batched rows to the sink and mark the batch clean.
/// NOTE(review): IsClean() is checked before the lock is taken — presumably
/// acceptable because _mutex is recursive and a stale read only skips a
/// no-op flush; confirm.
void Batch::Flush()
{
    Trace trace(Trace::Batching, "Batch::Flush");
    if (IsClean())
        return;

    if (trace.IsActive()) {
        std::ostringstream msg;
        msg << "Batch " << this;
        trace.NOTE(msg.str());
    }
    std::lock_guard lock(_mutex);
    MarkFlushed();
    _sink->Flush();
}

/// Flush any pending rows, then destroy the owned sink.
Batch::~Batch()
{
    Trace trace(Trace::Batching, "Batch::~Batch");
    if (IsDirty())
        Flush();
    delete _sink;
}

bool Batch::HasStaleData() const
{
    Trace trace(Trace::Batching, "Batch::HasStaleData");
    // I want data to not linger past the end of the *next* QI. If the QI size is 5 minutes and data is
    // written at 00:01:00, that data becomes stale at 00:10:00.
    if (IsClean())
        return false;

    MdsTime trigger = (MdsTime::Now() - _interval).Round(_interval);
    if (trace.IsActive()) {
        std::ostringstream msg;
        msg << "_lastAction=" << _lastAction << " _interval=" << _interval << " trigger=" << trigger;
        trace.NOTE(msg.str());
    }
    return (_lastAction < trigger);
}

/// Debug representation: address plus QI base, interval, and sink pointer.
std::ostream&
operator<<(std::ostream& os, const Batch& batch)
{
    os << &batch << " (QIBase " << batch._batchQIBase << ", Interval " << batch._interval
       << ", Sink " << batch._sink << ")";
    return os;
}

/// Return the (cached or newly created) Batch for this (basename, creds)
/// pair, or nullptr when construction fails (error is logged).
Batch*
BatchSet::GetBatch(const MdsEntityName &target, int interval)
{
    Trace trace(Trace::Batching, "BatchSet::GetBatch");

    auto creds = target.GetCredentials();
    key_t key = std::make_pair(target.Basename(), creds);
    std::ostringstream keystring;
    if (trace.IsActive()) {
        keystring << "<" << target.Basename() << ", 0x" << creds << ">";
    }

    std::lock_guard lock(_mutex);   // Lock held until this function returns

    std::map::iterator iter = _map.find(key);
    if (iter != _map.end()) {
        trace.NOTE("Found batch for " + keystring.str());
        return iter->second;
    }

    trace.NOTE("Creating batch for " + keystring.str());
    std::ostringstream msg;
    // Bug 3532559: Batch constructor can fail if XTableSink constructor fails while
    // creating an XTableRequest. So wrap the constructor in a try/catch block.
    try {
        Batch *batch = new Batch(_config, target, creds, interval);
        if (trace.IsActive()) {
            std::ostringstream msg;
            msg << "New batch " << *batch;
            trace.NOTE(msg.str());
        }
        _map[key] = batch;
        return batch;
    }
    catch (const std::exception& e) {
        msg << "GetBatch(" << target << ") failed to create new batch for "
            << keystring.str() << ": " << e.what();
    }
    catch (...) {
        msg << "GetBatch(" << target << ") caught unknown exception";
    }
    // If we got here, we caught an exception and already created the error message
    Logger::LogError(msg.str());
    trace.NOTE(msg.str());
    return nullptr;
}

/// Flush every dirty batch in the set.
void
BatchSet::Flush()
{
    Trace trace(Trace::Batching, "BatchSet::Flush");
    // Walk the _map and flush all the dirty Batches
    for (const auto &iter : _map) {
        if (iter.second->IsDirty()) {
            iter.second->Flush();
        }
    }
}

/// Flush only the batches whose data has outlived its query interval.
void
BatchSet::FlushIfStale()
{
    Trace trace(Trace::Batching, "BatchSet::FlushIfStale");
    // Walk the _map and flush all the Batches
    for (const auto &item : _map) {
        if (item.second->HasStaleData()) {
            item.second->Flush();
        }
    }
}

BatchSet::~BatchSet()
{
    Trace trace(Trace::Batching, "BatchSet::~BatchSet");
    // Walk the _map and delete all the Batches; deleting them will Flush() them first
    for (auto &iter : _map) {
        delete iter.second;
    }
}

// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/Batch.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _BATCH_HH_
#define _BATCH_HH_

#define MAX_BATCH_SIZE 100

#include
#include
#include
#include
#include
#include
#include
#include "MdsTime.hh"
#include "MdsEntityName.hh"

class MdsdConfig;
class MdsValue;
class Credentials;
class CanonicalEntity;
class IMdsSink;

/// Accumulates rows for one destination and flushes them per query interval.
class Batch
{
    friend std::ostream& operator<<(std::ostream& os, const Batch& batch);

public:
    /// Force all batched entities into MDS and leave the batch empty
    void Flush();

    /// Add a row to the batch. May trigger a flush which may or may not flush this row.
    /// row: must be complete (includes all columns). Contents
    /// are copied elsewhere by the sink; caller can reuse the object if desired.
void AddRow(const CanonicalEntity &row); /// Add a row to a batch of entries destined for some SchemasTable void AddSchemaRow(const MdsEntityName &target, const std::string &hash, const std::string &schema); // True if the batch might have rows from a prior query interval bool HasStaleData() const; // { return (_lastAction < (_batchQIBase + _interval)); } ~Batch(); private: Batch(MdsdConfig* config, const MdsEntityName &target, const Credentials* creds, int interval); Batch(); // No void constructor Batch(const Batch&); // No copy constructor Batch& operator=(const Batch&); // Can't assign /// Update the _lastAction time. void MarkTime() { _lastAction.Touch(); _dirty = true; } void MarkFlushed() { _lastAction = MdsTime::Max(); _dirty = false; } bool IsDirty() const { return _dirty; } bool IsClean() const { return (! _dirty); } MdsdConfig *_config; MdsTime _lastAction; // Used to find lingering batches MdsTime _batchQIBase; // The Query Interval base timestamp for the current batch int _interval; // The width of the interval (in seconds) IMdsSink* _sink; std::recursive_mutex _mutex; bool _dirty; // "Dirty" bit; set if any AddRow was called since last flush friend class BatchSet; }; std::ostream& operator<<(std::ostream& os, const Batch& batch); class BatchSet { public: BatchSet(MdsdConfig* c) : _config(c) {} ~BatchSet(); // Get pointer to a Batch object for this table // The metadata for the destination for the batch's data // The "query interval" for the batch, i.e. 
how often it gets flushed Batch* GetBatch(const MdsEntityName &target, int interval); void Flush(); void FlushIfStale(); private: using key_t = std::pair; BatchSet(const BatchSet&); // No copy constructor BatchSet& operator=(const BatchSet&); // No copying std::map _map; MdsdConfig* _config; std::mutex _mutex; // Just covers the BatchSet object, not any of the Batches in the set }; #endif // _BATCH_HH_ // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/CMakeLists.txt ================================================ SET(CMAKE_SKIP_BUILD_RPATH FALSE) SET(CMAKE_BUILD_WITH_INSTALL_RPATH FALSE) # Reset rpath vars for static executable SET(CMAKE_INSTALL_RPATH "${OMI_LIB_PATH}") SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) set(LINKER_FLAGS "${LINKER_FLAGS} -static-libgcc -static-libstdc++") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${LINKER_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread") set(LDFLAGS "-rdynamic") set(LDFLAGS "${LDFLAGS} -Wl,--wrap=memcpy") # To force using memcpy@GLIBC_2.2.5 (for old distro versions) set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${LDFLAGS}") if(NOT BUILD_NUMBER) execute_process( COMMAND date +%s OUTPUT_VARIABLE BUILD_NUMBER) endif(NOT BUILD_NUMBER) message("Build number: ${BUILD_NUMBER}") add_definitions(-DBUILD_NUMBER=${BUILD_NUMBER}) include_directories( /usr/include/libxml2 /usr/local/include ${OMI_INCLUDE_DIRS} ${CASABLANCA_INCLUDE_DIRS} ${STORAGE_INCLUDE_DIRS} ${CMAKE_SOURCE_DIR}/mdsd ${CMAKE_SOURCE_DIR}/mdsdinput ${CMAKE_SOURCE_DIR}/mdsdlog ${CMAKE_SOURCE_DIR}/mdsdutil ${CMAKE_SOURCE_DIR}/mdscommands ${CMAKE_SOURCE_DIR}/mdsdcfg ) # include(/usr/local/lib/bond/bond.cmake) link_directories( ${OMI_LIB_PATH} ) if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") # Some dependency library has no static clang lib, so use shared ones. 
set(XML_LIB xml++-2.6${LIBSUFFIX}) set(GLIBMM_LIB glibmm-2.4${LIBSUFFIX}) set(SIGC_LIB sigc-2.0${LIBSUFFIX}) set(BOOST_LIBS boost_log${LIBSUFFIX} boost_iostreams${LIBSUFFIX} boost_regex${LIBSUFFIX} boost_thread${LIBSUFFIX} boost_system${LIBSUFFIX} ) else() # For gcc, use static libs set(XML_LIB /usr/lib/libxml++-2.6.a) set(GLIBMM_LIB /usr/lib/x86_64-linux-gnu/libglibmm-2.4.a) set(SIGC_LIB /usr/lib/x86_64-linux-gnu/libsigc-2.0.a) set(BOOST_LIBS /usr/lib/x86_64-linux-gnu/libboost_log.a /usr/lib/x86_64-linux-gnu/libboost_iostreams.a /usr/lib/x86_64-linux-gnu/libboost_regex.a /usr/lib/x86_64-linux-gnu/libboost_thread.a /usr/lib/x86_64-linux-gnu/libboost_system.a ) endif() set(COMM_LIBS micxx${LIBSUFFIX} omiclient${LIBSUFFIX} rt # Required not to use clock_gettime@GLIBC_2.17 /usr/lib/x86_64-linux-gnu/bond/libbond${LIBSUFFIX}.a ${LINKSTDLIB} ${STORAGE_LIBRARIES} ${CASABLANCA_LIBRARIES} ${XML_LIB} ${GLIBMM_LIB} /usr/lib/x86_64-linux-gnu/libglib-2.0.a ${SIGC_LIB} /usr/lib/x86_64-linux-gnu/libpcre.a /usr/lib/x86_64-linux-gnu/libuuid.a /usr/lib/x86_64-linux-gnu/libxml2.a /usr/lib/x86_64-linux-gnu/libz.a /usr/lib/x86_64-linux-gnu/liblzma.a ${BOOST_LIBS} /usr/local/lib/libssl.a /usr/local/lib/libcrypto.a dl ) set(SOURCES Batch.cc CanonicalEntity.cc CfgContext.cc CfgCtxAccounts.cc CfgCtxDerived.cc CfgCtxEnvelope.cc CfgCtxError.cc CfgCtxEtw.cc CfgCtxEventAnnotations.cc CfgCtxEvents.cc CfgCtxExtensions.cc CfgCtxHeartBeats.cc CfgCtxImports.cc CfgCtxManagement.cc CfgCtxMdsdEvents.cc CfgCtxMonMgmt.cc CfgCtxOMI.cc CfgCtxParser.cc CfgCtxRoot.cc CfgCtxSvcBusAccts.cc CfgCtxSchemas.cc CfgCtxSources.cc cJSON.c CmdLineConverter.cc ConfigParser.cc Constants.cc Credentials.cc cryptutil.cc DaemonConf.cc DerivedEvent.cc Engine.cc EtwEvent.cc EventJSON.cc ExtensionMgmt.cc FileSink.cc IMdsSink.cc ITask.cc LADQuery.cc Listener.cc LocalSink.cc MdsdConfig.cc MdsdMetrics.cc mdsd.cc MdsEntityName.cc MdsSchemaMetadata.cc MdsValue.cc Memcheck.cc OMIQuery.cc OmiTask.cc Pipeline.cc PipeStages.cc 
Priority.cc ProtocolHandlerBase.cc ProtocolHandlerBond.cc ProtocolHandlerJSON.cc ProtocolListener.cc ProtocolListenerBond.cc ProtocolListenerDynamicJSON.cc ProtocolListenerJSON.cc ProtocolListenerMgr.cc ProtocolListenerTcpJSON.cc RowIndex.cc SaxParserBase.cc SchemaCache.cc Signals.c StoreType.cc StreamListener.cc Subscription.cc TableSchema.cc TermHandler.cc Version.cc XJsonBlobBlockCountsMgr.cc XJsonBlobRequest.cc XJsonBlobSink.cc XTableConst.cc XTableHelper.cc XTableRequest.cc XTableSink.cc ) # To set source file specific compile flags, do # set_source_files_properties( PROPERTIES COMPILE_FLAGS ) # example: # set_source_files_properties(Pipeline.cc PROPERTIES COMPILE_FLAGS -Wno-sign-compare) # Disable warnings from azure storage API. set_source_files_properties( XJsonBlobBlockCountsMgr.cc XJsonBlobRequest.cc XJsonBlobSink.cc XTableHelper.cc XTableRequest.cc XTableSink.cc PROPERTIES COMPILE_FLAGS "-Wno-unused-value -Wno-reorder -Wno-sign-compare" ) set(WRAPPERS_FOR_OLD_GLIBC_SOURCES wrap_memcpy.c fdelt_chk.c ) add_executable( mdsd ${SOURCES} ${WRAPPERS_FOR_OLD_GLIBC_SOURCES} ) target_link_libraries( mdsd ${CMD_LIB_NAME} ${INPUT_LIB_NAME} ${UTIL_LIB_NAME} ${LOG_LIB_NAME} ${MDSDCFG_LIB_NAME} ${COMM_LIBS} ) install(TARGETS mdsd RUNTIME DESTINATION ${CMAKE_BINARY_DIR}/release/bin ) ================================================ FILE: Diagnostic/mdsd/mdsd/CanonicalEntity.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "CanonicalEntity.hh" #include "MdsSchemaMetadata.hh" //#include #include "Utility.hh" #include "Engine.hh" // To get the mdsd config #include "MdsdConfig.hh" #include using std::string; using std::make_pair; // Clone the src entity. Usually used when some operation plans to add columns to a reduced-size // "master" entity, so reserve some extra space just in case. 
CanonicalEntity::CanonicalEntity(const CanonicalEntity& src) : _timestamp(src._timestamp), _pkey(src._pkey), _rkey(src._rkey), _schemaId(src._schemaId), _srctype(src._srctype) { _entity.reserve(2 + src._entity.size()); // std::for_each(src._entity.cbegin(), src._entity.cend(), [this](const col_t& col){this->CopyAddColumn(col);}); for (const auto & col : src._entity) { AddColumn(col.first, new MdsValue(*(col.second))); } } CanonicalEntity::~CanonicalEntity() { for (col_t col : _entity) { if (col.second) { delete col.second; } } } // AddColumn "owns" the MdsValue* once it's passed in. We can keep it, or move from it and destroy it. void CanonicalEntity::AddColumn(const std::string name, MdsValue* val) { if (name == "PartitionKey") { _pkey = std::move(*(val->strval)); delete val; } else if (name == "RowKey") { _rkey = std::move(*(val->strval)); delete val; } else { _entity.push_back(std::make_pair(name, val)); } } // Add column only if the column name isn't a MetaData column. void CanonicalEntity::AddColumnIgnoreMetaData(const std::string name, MdsValue* val) { if (MdsSchemaMetadata::MetadataColumns.count(name)) { delete val; } else { _entity.push_back(std::make_pair(name, val)); } } MdsValue* CanonicalEntity::Find(const std::string &name) const { for (auto iter : _entity) { if (iter.first == name) { return iter.second; } } return nullptr; } std::ostream& operator<<(std::ostream& os, const CanonicalEntity& ce) { int count = ce._entity.size(); os << "(" << count << " columns, time " << ce.GetPreciseTimeStamp() << ", _pKey "; if (ce._pkey.empty()) { os << "{empty}"; } else { os << ce._pkey; } os << ", _rkey "; if (ce._pkey.empty()) { os << "{empty}"; } else { os << ce._rkey; } os << ", ["; for (auto iter : ce._entity) { os << iter.first << "="; if (iter.second) { os << *(iter.second); } else { os << ""; } if (--count) { os << ", "; } } os << "])"; return os; } std::string CanonicalEntity::GetJsonRow( const std::string& timeGrain, const std::string& tenant, const 
std::string& role, const std::string& roleInstance) const { const std::string& resourceId = Engine::GetEngine()->GetConfig()->GetResourceId(); if (resourceId.empty()) { throw std::runtime_error("Empty resourceId (OboDirectPartitionField) when a JSON event is requested"); } // Check if this row is for metric or for log. // A metric event must include "CounterName" and "Last" columns. // Its timeGrain shouldn't be empty. bool counterNameExists = false, lastExists = false; for (auto item : _entity) { if (item.first == "CounterName") { counterNameExists = true; } else if (item.first == "Last") { lastExists = true; } } bool isMetricRow = counterNameExists && lastExists && !timeGrain.empty(); return isMetricRow ? GetJsonRowForMetric(resourceId, timeGrain, tenant, role, roleInstance) : GetJsonRowForLog(resourceId); } /* Example return Json string: { "time" : "2016-12-21T01:06:04.9067290Z", "resourceId": "/subscriptions/xxx-xxx-xxx-xxx/resourceGroups/myrg/providers/Microsoft.Compute/VirtualMachines/myvm", "properties" : { "Column1Name": "Column1Value", "Column2Name": "Column2Value", "ColumnNName": "ColumnNValue" }, "category": "user", "level": "info", "operationName": "some_name_depending_on_detected_event_type" } */ std::string CanonicalEntity::GetJsonRowForLog(const std::string& resourceId) const { std::ostringstream oss; oss << "{ \"time\" : \"" << GetPreciseTimeStamp() << "\",\n" " \"resourceId\" : \"" << resourceId << "\",\n" " \"properties\" : {\n"; bool first = true; std::string category = "\"Unknown\"", level = "\"Unknown\"", operationName = "\"Unknown\""; for (auto iter : _entity) { if (first) { first = false; } else { oss << ",\n"; } if (iter.second) { oss << " \"" << iter.first << "\" : " << iter.second->ToJsonSerializedString(); // We consider this event to be from syslog if there's a field named "Facility". // Set the related Azure Monitor required fields (category, level, operationName) accordingly. 
if (iter.first == "Facility") { category = iter.second->ToJsonSerializedString(); // Let's use syslog facility as Azure Monitor "category". operationName = "\"LinuxSyslogEvent\""; // Change this later as necessary } else if (iter.first == "Severity") { // Let's use syslog severity as Azure Monitor "level". if (iter.second->IsNumeric()) { level = MdsdUtil::GetSyslogSeverityStringFromValue((int)iter.second->ToDouble()); } else { // iter.second->IsString(), which is the case when syslog events are // routed from fluentd's in_syslog & out_mdsd. level = iter.second->ToJsonSerializedString(); } } } } oss << "\n },\n" " \"category\" : " << category << ",\n" " \"level\" : " << level << ",\n" " \"operationName\" : " << operationName << "\n" "}"; return oss.str(); } /* Example return Json string: { "time" : "2016-12-21T01:06:04.9067290Z", "resourceId": "/subscriptions/xxx-xxx-xxx-xxx/resourceGroups/myrg/providers/Microsoft.Compute/VirtualMachines/myvm", "timeGrain" : "PT1M", "dimensions" : { "Tenant": "JsonBlobTestTenantName", "Role": "JsonBlobTestRoleName", "RoleInstance": "JsonBlobTestRoleinstanceName" }, "metricName": "\\Processor\\PercentProcessorTime", "last": 0 } */ std::string CanonicalEntity::GetJsonRowForMetric( const std::string& resourceId, const std::string& timeGrain, const std::string& tenant, const std::string& role, const std::string& roleInstance) const { std::ostringstream oss; oss << "{ \"time\" : \"" << GetPreciseTimeStamp() << "\",\n"; oss << " \"resourceId\" : \"" << resourceId << "\",\n"; oss << " \"timeGrain\" : \"" << timeGrain << "\",\n"; oss << " \"dimensions\": {\n" " \"Tenant\": \"" << tenant << "\",\n" " \"Role\": \"" << role << "\",\n" " \"RoleInstance\": \"" << roleInstance << "\"\n" " }"; static std::unordered_map columnNameTranslations = { { "CounterName", "metricName" }, { "Average", "average" }, { "Minimum", "minimum" }, { "Maximum", "maximum" }, { "Total", "total" }, { "Last", "last" }, { "Count", "count" } }; size_t countOfTranslations = 
0; for (const auto & nameValue : _entity) { auto translationPair = columnNameTranslations.find(nameValue.first); if (translationPair != columnNameTranslations.end()) { oss << ",\n \"" << translationPair->second << "\": " << nameValue.second->ToJsonSerializedString(); countOfTranslations++; } } if (columnNameTranslations.size() != countOfTranslations) { std::ostringstream msg; msg << "Dropping invalid CanonicalEntity for metric (missing required column(s)): " << *this; throw std::runtime_error(msg.str()); } oss << "\n}"; return oss.str(); } // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/CanonicalEntity.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _CANONICALENTITY_HH_ #define _CANONICALENTITY_HH_ #include #include #include #include "MdsValue.hh" #include "MdsTime.hh" #include "SchemaCache.hh" #include // CanonicalEntity is the internal canonical form of an entity to be handed to MDS. This is a middle ground // between the form in which the information was reported to the daemon (e.g. via JSON event, OMI query, etc.) // and the form required for transmission to the actual MDS data sync (Storage SDK table row object, compressed // BOND blob, etc.) // // CanonicalEntity is the "owner" of the data handed to it. Once you pass an MdsValue* to AddColumn, // you should leave it alone (and, especially, do not delete it). // // CanonicalEntity objects can be copied until they are added to a batch. Once added to a batch, the object // might at any instant (and asynchronously) be handed to a transport, which will convert it to the form // required for transmission to MDS and then delete it. Or it could linger in a local sink and eventually // make its way into some other batch; later rinse repeat. 
class CanonicalEntity { using col_t = std::pair; friend std::ostream& operator<<(std::ostream& os, const CanonicalEntity& ce); public: enum class SourceType { Ingested, // created from original ingestion Duplicated // created from duplication (e.g. during pipeline) }; CanonicalEntity() : _timestamp(0), _schemaId(0) { _entity.reserve(16); } CanonicalEntity(int n) : _timestamp(0), _schemaId(0) { _entity.reserve(n); } CanonicalEntity(const CanonicalEntity& src); ~CanonicalEntity(); void AddColumn(const std::string name, MdsValue* val); void AddColumnIgnoreMetaData(const std::string name, MdsValue* val); std::string PartitionKey() const { return _pkey; } std::string RowKey() const { return _rkey; } void SetPreciseTime(const MdsTime& t) { _timestamp = t; } const MdsTime& GetPreciseTimeStamp() const { return _timestamp; } const MdsTime& PreciseTime() const { return _timestamp; } time_t GetApproximateTime() const { return _timestamp.to_time_t(); } MdsValue* Find(const std::string &name) const; void SetSchemaId(SchemaCache::IdType id) { _schemaId = id; } SchemaCache::IdType SchemaId() const { return _schemaId; } // Convenience functions void AddColumn(const std::string name, const std::string& val) { AddColumn(name, new MdsValue(val)); } void AddColumn(const std::string name, const char* val) { AddColumn(name, new MdsValue(val)); } // Act a bit like a container, but not all the way typedef std::vector::iterator iterator; typedef std::vector::const_iterator const_iterator; iterator begin() { return _entity.begin(); } const_iterator begin() const { return _entity.begin(); } iterator end() { return _entity.end(); } const_iterator end() const { return _entity.end(); } size_t size() const { return _entity.size(); } // For XJsonBlob & EventHub publishing support // timeGrain should be an empty string for log events, should be ISO8601 duration string (e.g., "PT1M") for metric events. // Caller is responsible to make all conditions true for metric events. 
// That is, when a non-empty timeGrain is passed (for a metric event), the row should // contain "CounterName" and "Last" columns. std::string GetJsonRow(const std::string& timeGrain, const std::string& tenant, const std::string& role, const std::string& roleInstance) const; void SetSourceType(SourceType t) { _srctype = t; } SourceType GetSourceType() const { return _srctype; } private: std::vector _entity; MdsTime _timestamp; std::string _pkey; std::string _rkey; SchemaCache::IdType _schemaId; SourceType _srctype = SourceType::Ingested; void CopyAddColumn(const col_t& col) { _entity.push_back(std::make_pair(col.first, new MdsValue(*(col.second)))); } std::string GetJsonRowForLog(const std::string& resourceId) const; std::string GetJsonRowForMetric(const std::string& resourceId, const std::string& timeGrain, const std::string& tenant, const std::string& role, const std::string& roleInstance) const; }; std::ostream& operator<<(std::ostream& os, const CanonicalEntity& ce); #endif // _CANONICALENTITY_HH_ // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/CfgContext.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "CfgContext.hh"
#include "CfgCtxError.hh"
#include "MdsdConfig.hh"
#include "Utility.hh"
// NOTE(review): include target stripped by the extraction; restored as
// <sstream> (std::ostringstream is used below).
#include <sstream>

// Build the child context for subelement <name>, or an error context if the
// current element does not permit that subelement.
CfgContext*
CfgContext::SubContextFactory(const std::string& name)
{
    if (IsErrorContext()) {
        return new CfgCtxError(this);
    }

    const subelementmap_t& subelementmap = GetSubelementMap();
    auto iter = subelementmap.find(name);
    if (iter != subelementmap.end()) {
        return (iter->second)(this);
    } else {
        std::ostringstream oss;
        oss << '<' << Name() << "> does not define subelement <" << name << '>';
        ERROR(oss.str());
        return new CfgCtxError(this);
    }
}

// Render an attribute list as name="value" pairs, comma-separated.
std::string
CfgContext::stringize_attributes(const xmlattr_t& properties)
{
    std::string result;
    bool first = true;
    for (const auto& item : properties) {
        if (!first) {
            result += ", ";
        }
        result += item.first + "=\"" + item.second + "\"";
        first = false;
    }
    return result;
}

void
CfgContext::log_entry(const xmlattr_t& properties)
{
    std::string msg;
    if (properties.size() > 0) {
        msg = "Entered " + Name() + " with attribute(s) " + stringize_attributes(properties);
    } else {
        msg = "Entered " + Name();
    }
    INFO(msg);
}

void
CfgContext::log_body(const std::string& body)
{
    INFO("Element " + Name() + " has body {" + body + "}");
}

bool
CfgContext::empty_or_whitespace()
{
    return MdsdUtil::IsEmptyOrWhiteSpace(Body);
}

// Default Leave(): warn if the element unexpectedly accumulated a body.
CfgContext*
CfgContext::Leave()
{
    if (!empty_or_whitespace()) {
        std::ostringstream oss;
        oss << '<' << Name() << "> expected empty body; did not expect {" << Body << '}';
        WARNING(oss.str());
    }
    return ParentContext;
}

void
CfgContext::warn_if_attributes(const xmlattr_t& properties)
{
    // log_entry(properties);
    if (!properties.empty()) {
        WARNING("Expected no attributes");
    }
}

// Message helpers; all funnel through MdsdConfig::AddMessage at their severity.
void CfgContext::INFO(const std::string& msg)    { Config->AddMessage(MdsdConfig::info, msg); }
void CfgContext::WARNING(const std::string& msg) { Config->AddMessage(MdsdConfig::warning, msg); }
void CfgContext::ERROR(const std::string& msg)   { Config->AddMessage(MdsdConfig::error, msg); }
void CfgContext::FATAL(const std::string& msg)   { Config->AddMessage(MdsdConfig::fatal, msg); }

// If attrname matches itemname, record itemval into attrval; complain if the
// attribute was already seen (it may appear only once).
void
CfgContext::parse_singleton_attribute(
    const std::string & itemname,
    const std::string & itemval,
    const std::string & attrname,
    std::string& attrval
    )
{
    if (attrname != itemname) {
        return;
    }
    if (attrval.empty()) {
        attrval = itemval;
    } else {
        ERROR("\"" + attrname + "\" can appear in <" + Name() + "> only once.");
    }
}

void
CfgContext::fatal_if_no_attributes(
    const std::string & attrname,
    const std::string & attrval
    )
{
    if (attrval.empty()) {
        FATAL("<" + Name() + "> requires \"" + attrname + "\" attribute.");
    }
}

// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/CfgContext.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _CFGCONTEXT_HH_
#define _CFGCONTEXT_HH_

// NOTE(review): the original <...> include targets were stripped by the
// extraction; restored with the standard headers this header actually uses.
// TODO confirm against the upstream file.
#include <functional>
#include <map>
#include <string>

#include "SaxParserBase.hh"

class CfgContext;
class MdsdConfig;

///
/// Maps from a (permitted) subelement name to the function which returns an appropriate new context.
///
// NOTE(review): template args restored (stripped by extraction); reconstructed
// from the lambdas stored in these maps elsewhere in this project, e.g.
// [](CfgContext* parent) -> CfgContext* { return new CfgCtxAccount(parent); }
typedef std::map<std::string, std::function<CfgContext* (CfgContext*)>> subelementmap_t;

///
/// XML element attribute list
///
typedef SaxParserBase::AttributeMap xmlattr_t;

///
/// const iterator on an XML attribute list
///
typedef SaxParserBase::AttributeMap::const_iterator xmlattr_iter_t;

///
/// This pure virtual class is really an Interface class for all parsing context classes.
///
class CfgContext
{
public:
    virtual ~CfgContext() {}

    ///
    /// Asks the current context to construct a child context given the name of a subelement.
    /// If the current context doesn't permit a subelement of that name, a new Error context is
    /// returned.
    ///
    CfgContext* SubContextFactory(const std::string& name);

    ///
    /// Provides attributes of the just-entered XML element to the context for the element.
    ///
    virtual void Enter(const xmlattr_t& properties) = 0;

    ///
    /// Provides the body of the current XML element to the context for the element. May be
    /// called for each chunk of characters found between subelements. By default, just
    /// accumulate chunks into the Body member variable.
    ///
    virtual void HandleBody(const std::string& body) { Body += body; }

    ///
    /// Provides the CDATA text of the current XML element to the context of the element.
    /// By default, just accumulate chunks into the CdataText member variable.
    /// Example for CDATA: (NOTE(review): the CDATA example markup was stripped
    /// by the extraction.)
    ///
    // NOTE(review): despite the doc comment above (as in the original), this
    // default implementation appends to Body, not to a CdataText member.
    virtual void HandleCdata(const std::string& cdata) { Body += cdata; }

    ///
    /// Invoked when the parser is leaving the element. The context should finish its work
    /// (e.g. finalize changes to the MdsdConfig object). Once this member is called, the class
    /// instance is ready to be destroyed. Base class implementation warns if the body is
    /// non-empty but otherwise does nothing.
    ///
    virtual CfgContext* Leave();

    /// Fetch the printable name for the context.
    virtual const std::string& Name() const = 0;

    /// Fetch the context map of permitted subelements.
    virtual const subelementmap_t& GetSubelementMap() const = 0;

    /// True if the parse is in "error" state
    virtual bool IsErrorContext() const { return false; }

    void INFO(const std::string& msg);
    void WARNING(const std::string& msg);
    void ERROR(const std::string& msg);
    void FATAL(const std::string& msg);

private:
    ///
    /// Disallow default constructor.
    ///
    CfgContext() : ParentContext(NULL), Config(NULL) {}

    ///
    /// Convert a list of XML SaxParser attributes to a printable string
    ///
    /// The attribute list for the element
    std::string stringize_attributes(const xmlattr_t& properties);

protected:
    /// The context object for the XML element that contains this one.
    CfgContext* const ParentContext;

    // Should provide an accessor to allow derived classes to call methods through this pointer, with the
    // pointer itself remaining private.
    MdsdConfig* const Config;

    /// Accumulated body of the element
    std::string Body;

    ///
    /// Creates a context representing a particular element an XML document. Knows how to handle attributes
    /// of the element and any content (body text). Knows what sub-elements are legal.
    ///
    /// A pointer to the parent (enveloping) context.
    CfgContext(CfgContext* previousContext) : ParentContext(previousContext), Config(previousContext->Config) {}

    ///
    /// Creates a context representing the root element an XML document.
    ///
    /// A pointer to the parent (enveloping) context.
    CfgContext(MdsdConfig* config) : ParentContext(NULL), Config(config) {}

    ///
    /// Add an Info message recording entry into a new element
    ///
    /// The attribute list for the element
    void log_entry(const xmlattr_t& properties);

    ///
    /// Add an Info message recording a body-chunk for the current element
    ///
    /// The body text found within the element
    void log_body(const std::string& body);

    ///
    /// Return true if the accumulated body of the element is empty or whitespace
    ///
    bool empty_or_whitespace();

    ///
    /// Add a warning message if any attributes were specified
    ///
    /// The attribute list for the element
    void warn_if_attributes(const xmlattr_t& properties);

    void parse_singleton_attribute(const std::string & itemname, const std::string & itemval,
            const std::string & attrname, std::string& attrval);

    void fatal_if_no_attributes(const std::string & attrname, const std::string & attrval);

    void warn_if_attribute_unexpected(const std::string & attrname)
    {
        WARNING("Ignoring unexpected <" + Name() + "> attribute \"" + attrname + "\"");
    }

    void fatal_if_impossible_subelement()
    {
        FATAL("Found <" + Name() + "> in <" + ParentContext->Name() + ">; that can't happen.");
    }
};

#endif //_CFGCONTEXT_HH_
// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxAccounts.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "CfgCtxAccounts.hh"
#include "MdsdConfig.hh"
#include "Credentials.hh"
#include "Utility.hh"
#include "AzureUtility.hh"
#include "cryptutil.hh"
#include "Trace.hh"
// NOTE(review): include target stripped by the extraction; restored as
// <algorithm>, which std::remove below requires.
#include <algorithm>

///////// CfgCtxAccounts

subelementmap_t CfgCtxAccounts::_subelements = {
    { "Account",
      [](CfgContext* parent) -> CfgContext* { return new CfgCtxAccount(parent); } },
    { "SharedAccessSignature",
      [](CfgContext* parent) -> CfgContext* { return new CfgCtxSAS(parent); } },
};

std::string CfgCtxAccounts::_name = "Accounts";

CfgContext*
CfgCtxAccounts::Leave()
{
    return CfgContext::Leave();
}

///////// CfgCtxAccount

// Parse an <Account> element (shared-key storage credential) and register the
// resulting credentials with the config after validation.
void
CfgCtxAccount::Enter(const xmlattr_t& properties)
{
    Trace trace(Trace::ConfigLoad, "CfgCtxAccount::Enter");

    std::string moniker, account, sharedKey, decryptKeyPath, blobEndpoint, tableEndpoint;
    bool makeDefault = false;

    for (const auto& item : properties) {
        if (item.first == "moniker") {
            if (moniker.empty()) {
                moniker = item.second;
            } else {
                // NOTE(review): the element name in this message was stripped
                // by the extraction; restored as <Account>. TODO confirm.
                ERROR("\"moniker\" can appear in <Account> only once");
            }
        } else if (item.first == "key") {
            sharedKey = item.second;
        } else if (item.first == "decryptKeyPath") {
            decryptKeyPath = item.second;
        } else if (item.first == "account") {
            account = item.second;
            size_t len = account.length();
            // Squeeze any embedded spaces from the account
            account.erase(std::remove(account.begin(), account.end(), ' '), account.end());
            if (len != account.length()) {
                WARNING("Account cannot contain spaces; blanks were removed");
            }
        } else if (item.first == "isDefault") {
            makeDefault = MdsdUtil::to_bool(item.second);
        } else if (item.first == "blobEndpoint") {
            blobEndpoint = item.second;
        } else if (item.first == "tableEndpoint") {
            tableEndpoint = item.second;
        } else {
            WARNING("Ignoring unexpected attribute \"" + item.first + "\"");
        }
    }

    if (moniker.empty()) {
        // NOTE(review): element name restored (stripped by extraction).
        FATAL("<Account> requires \"moniker\" attribute");
    } else {
        // Create the correct credential object based on the attributes
        // Must be shared key
        if (account.empty()) {
            ERROR("\"account\" must be set for shared key moniker");
        } else if (sharedKey.empty()) {
            ERROR("\"key\" must be set for shared key moniker");
        } else {
            if (!decryptKeyPath.empty()) {
                try {
                    sharedKey = cryptutil::DecodeAndDecryptString(decryptKeyPath, sharedKey);
                }
                catch (const std::exception& e) {
                    ERROR(std::string("Storage key decryption (using private key at ")
                        .append(decryptKeyPath).append(") failed with the message: ").append(e.what()));
                    return;
                }
                catch (...) {
                    ERROR("Unknown exception thrown when decrypting storage key");
                    return;
                }
            }
            auto creds = new CredentialType::SharedKey(moniker, account, sharedKey);
            if (!blobEndpoint.empty()) {
                creds->BlobUri(blobEndpoint);
            }
            if (!tableEndpoint.empty()) {
                creds->TableUri(tableEndpoint);
            }
            /* Validate storage account for table access. */
            try {
                MdsdUtil::ValidateStorageCredentialForTable(
                    creds->GetConnectionStringOnly(Credentials::ServiceType::XTable));
                Config->AddCredentials(creds, makeDefault);
            }
            catch (const std::exception& e) {
                ERROR(std::string("Storage credential validation for table storage failed: ").append(e.what()));
            }
            catch (...) {
                ERROR("Unknown exception thrown when validating storage credential for table storage");
            }
        }
    }
}

subelementmap_t CfgCtxAccount::_subelements;
std::string CfgCtxAccount::_name = "Account";

///////// CfgCtxSAS

// Parse a <SharedAccessSignature> element (SAS-token storage credential) and
// register the resulting credentials with the config after validation.
void
CfgCtxSAS::Enter(const xmlattr_t& properties)
{
    std::string moniker, account, token, decryptKeyPath, blobEndpoint, tableEndpoint;
    bool makeDefault = false;

    for (const auto& item : properties) {
        if (item.first == "moniker") {
            moniker = item.second;
        } else if (item.first == "key") {
            token = item.second;
            // NOTE(review): the extraction decoded the XML entity in this
            // call, leaving a no-op ("&" -> "&"); restored to un-escape
            // "&amp;" back to "&" in the SAS token. TODO confirm upstream.
            MdsdUtil::ReplaceSubstring(token, "&amp;", "&");
        } else if (item.first == "decryptKeyPath") {
            decryptKeyPath = item.second;
        } else if (item.first == "account") {
            account = item.second;
            size_t len = account.length();
            // Squeeze any embedded spaces from the account
            account.erase(std::remove(account.begin(), account.end(), ' '), account.end());
            if (len != account.length()) {
                WARNING("Account cannot contain spaces; blanks were removed");
            }
        } else if (item.first == "blobEndpoint") {
            blobEndpoint = item.second;
        } else if (item.first == "tableEndpoint") {
            tableEndpoint = item.second;
        } else if (item.first == "isDefault") {
            makeDefault = MdsdUtil::to_bool(item.second);
        } else {
            WARNING("Ignoring unexpected attribute \"" + item.first + "\"");
        }
    }

    if (moniker.empty()) {
        FATAL("\"moniker\" must be specified");
    } else if (account.empty()) {
        FATAL("\"account\" must be specified");
    } else if (token.empty()) {
        FATAL("\"key\" must be specified");
    } else {
        if (!decryptKeyPath.empty()) {
            try {
                token = cryptutil::DecodeAndDecryptString(decryptKeyPath, token);
                // NOTE(review): restored from no-op as above.
                MdsdUtil::ReplaceSubstring(token, "&amp;", "&");
            }
            catch (const std::exception& e) {
                ERROR(std::string("Storage account SAS token decryption (using private key at ")
                    .append(decryptKeyPath).append(") failed with the message: ").append(e.what()));
                return;
            }
            catch (...) {
                ERROR("Unknown exception thrown when decrypting storage account SAS token");
                return;
            }
        }
        try {
            auto creds = new CredentialType::SAS(moniker, account, token);
            if (!blobEndpoint.empty()) {
                creds->BlobUri(blobEndpoint);
            }
            if (!tableEndpoint.empty()) {
                creds->TableUri(tableEndpoint);
            }
            if (creds->IsAccountSas()) {
                /* Validate storage account for table access (same as above in shared key)
                 * only if it's an account SAS. */
                MdsdUtil::ValidateStorageCredentialForTable(
                    creds->GetConnectionStringOnly(Credentials::ServiceType::XTable));
            }
            Config->AddCredentials(creds, makeDefault);
        }
        catch (MdsdUtil::MdsdInvalidSASException& e) {
            ERROR(std::string("Invalid SAS token given. Reason: ").append(e.what()));
        }
        catch (const std::exception& e) {
            ERROR(std::string("Storage credential validation for table storage failed: ").append(e.what()));
        }
        catch (...) {
            ERROR("Unknown exception thrown when validating storage credential for table storage");
        }
    }
}

subelementmap_t CfgCtxSAS::_subelements;
std::string CfgCtxSAS::_name = "SharedAccessSignature";

// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxAccounts.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _CFGCTXACCOUNTS_HH_
#define _CFGCTXACCOUNTS_HH_

#include "CfgContext.hh"

// Context for the <Accounts> container element; permits <Account> and
// <SharedAccessSignature> children.
class CfgCtxAccounts : public CfgContext
{
public:
    CfgCtxAccounts(CfgContext* config) : CfgContext(config) {}
    virtual ~CfgCtxAccounts() {}

    virtual const std::string& Name() const { return _name; }
    virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }

    void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); }
    CfgContext* Leave() override;

private:
    static subelementmap_t _subelements;
    static std::string _name;
};

// Context for an <Account> (shared-key credential) element.
class CfgCtxAccount : public CfgContext
{
public:
    CfgCtxAccount(CfgContext* config) : CfgContext(config) {}
    virtual ~CfgCtxAccount() { }

    virtual const std::string& Name() const { return _name; }
    virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }

    void Enter(const xmlattr_t& properties);

private:
    static subelementmap_t _subelements;
    static std::string _name;
};

// Context for a <SharedAccessSignature> (SAS credential) element.
class CfgCtxSAS : public CfgContext
{
public:
    CfgCtxSAS(CfgContext* config) : CfgContext(config) {}
    virtual ~CfgCtxSAS() { }

    virtual const std::string& Name() const { return _name; }
    virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }

    void Enter(const xmlattr_t& properties);

private:
    static subelementmap_t _subelements;
    static std::string _name;
};

#endif //_CFGCTXACCOUNTS_HH_

================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxDerived.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "CfgCtxDerived.hh"
#include "CfgCtxError.hh"
#include "MdsdConfig.hh"
#include "Utility.hh"
#include "StoreType.hh"
#include "PipeStages.hh"
#include "LADQuery.hh"
#include "Priority.hh"
#include "Trace.hh"
#include "EventType.hh"

////////////////// CfgCtxDerived

subelementmap_t CfgCtxDerived::_subelements = {
	{ "DerivedEvent", [](CfgContext* parent) -> CfgContext* { return new CfgCtxDerivedEvent(parent); } }
};
std::string CfgCtxDerived::_name = "DerivedEvents";

////////////////// CfgCtxDerivedEvent

// Parse a DerivedEvent element: validate its attributes (eventName, account,
// duration, storeType, source, ...) and, if all required pieces are present,
// construct the DerivedEvent task and register its moniker/duration with the config.
// Sets _isOK=false on any validation failure so subelement parsing is suppressed.
void CfgCtxDerivedEvent::Enter(const xmlattr_t& properties)
{
	Trace trace(Trace::ConfigLoad, "CfgCtxDerivedEvent::Enter");

	std::string eventName, account, source, durationString;
	MdsTime duration { 0 };
	bool NoPerNDay = false, isFullName = false;
	Priority priority { "Normal" };

	_task = nullptr;
	_isOK = true;
	_storeType = StoreType::XTable;
	_doSchemaGeneration = true;

	for (const auto & iter : properties) {
		if (iter.first == "eventName") {
			if (MdsdUtil::NotValidName(iter.second)) {
				ERROR("Invalid eventName attribute");
			} else {
				eventName = iter.second;
			}
		} else if (iter.first == "priority") {
			if (! priority.Set(iter.second)) {
				WARNING("Ignoring unknown priority \"" + iter.second + "\"");
			}
		} else if (iter.first == "account") {
			if (MdsdUtil::NotValidName(iter.second)) {
				ERROR("Invalid account attribute");
			} else {
				account = iter.second;
			}
		} else if (iter.first == "dontUsePerNDayTable") {
			NoPerNDay = MdsdUtil::to_bool(iter.second);
		} else if (iter.first == "isFullName") {
			isFullName = MdsdUtil::to_bool(iter.second);
		} else if (iter.first == "duration") {
			MdsTime requestedDuration = MdsTime::FromIS8601Duration(iter.second);
			if (!requestedDuration) {
				ERROR("Invalid duration attribute");
				_isOK = false;
			} else {
				duration = requestedDuration;
				durationString = iter.second;
			}
		} else if (iter.first == "storeType") {
			_storeType = StoreType::from_string(iter.second);
			_doSchemaGeneration = StoreType::DoSchemaGeneration(_storeType);
		} else if (iter.first == "source") {
			if (MdsdUtil::NotValidName(iter.second)) {
				// NOTE(review): upstream copy-paste defect — this validates the
				// "source" attribute but the message says "account".
				ERROR("Invalid account attribute");
			} else {
				source = iter.second;
			}
		} else {
			WARNING("Ignoring unexpected attribute " + iter.first);
		}
	}

	if (!duration) {
		ERROR("The duration attribute is required");
		_isOK = false;
	}
	if (!_isOK) {
		return;
	}

	try {
		// Build target on the stack, move it into the DerivedTask
		auto target = MdsEntityName { eventName, NoPerNDay, Config, account, _storeType, isFullName };
		_task = new DerivedEvent(Config, std::move(target), priority, duration, source);

		// Centrally-stored events implicitly have Identity columns added to them as
		// defined in the element. Add them first thing so they're available
		// to subsequent stages (if any).
		// NOTE(review): the element name above was stripped by extraction —
		// presumably the <Identity> element; confirm against the repo.
		if (_storeType != StoreType::Local) {
			_task->AddStage(new Pipe::Identity(Config->GetIdentityVector()));
		}
		Config->AddMonikerEventInfo(account, eventName, _storeType, source, mdsd::EventType::DerivedEvent);
		Config->SetDurationForEventName(eventName, durationString);
	} catch (const std::exception& ex) {
		ERROR(ex.what());
		_isOK = false;
		return;
	} catch (...)
	{
		FATAL("Unknown exception; skipping");
		_isOK = false;
		return;
	}
}

// Finish the DerivedEvent: optionally append a schema-push stage, then route the
// task's output into the batch writer for its target. Deletes the task if no
// batch could be created (prior configuration errors).
CfgContext* CfgCtxDerivedEvent::Leave()
{
	Trace trace(Trace::ConfigLoad, "CfgCtxDerivedEvent::Leave");

	if(_task) {
		// If not local, add a stage to push metadata into MDS. Derived queries should produce results with
		// the same schema each time. Doing an doesn't change that.
		if (_doSchemaGeneration && _storeType != StoreType::Local) {
			_task->AddStage(new Pipe::BuildSchema(Config, _task->Target(), true));
		}
		// Find/make the batch for this task; add a final pipeline stage to write to that batch;
		// add the task to the set of tasks in this config.
		Batch *batch = Config->GetBatch(_task->Target(), _task->FlushInterval());
		if (batch) {
			_task->AddStage(new Pipe::BatchWriter(batch, Config->GetIdentityVector(), Config->PartitionCount(), _storeType));
			Config->AddTask(_task);
		} else {
			ERROR("Configuration error(s) detected; dropping this DerivedEvent.");
			delete _task;
		}
	}
	return ParentContext;
}

// Return the error-context subelement map once parsing has failed, so any
// children of a broken DerivedEvent are silently skipped.
const subelementmap_t& CfgCtxDerivedEvent::GetSubelementMap() const
{
	if (_isOK) {
		return _subelements;
	} else {
		return CfgCtxError::subelements;
	}
}

subelementmap_t CfgCtxDerivedEvent::_subelements {
	{ "LADQuery", [](CfgContext* parent) -> CfgContext* { return new CfgCtxLADQuery(parent); } }
};
std::string CfgCtxDerivedEvent::_name = "DerivedEvent";

////////////////// CfgCtxLADEvent

// Parse a LADQuery element: requires columnValue, columnName, and partitionKey
// attributes (instanceID is optional) and appends a LADQuery pipeline stage to the
// parent DerivedEvent's task.
void CfgCtxLADQuery::Enter(const xmlattr_t& properties)
{
	Trace trace(Trace::ConfigLoad, "CfgCtxLADQuery::Enter");

	std::string valueAttrName, nameAttrName, partitionKey, uuid;

	// NOTE(review): template argument stripped by extraction — presumably
	// dynamic_cast<CfgCtxDerivedEvent*>(ParentContext); confirm against the repo.
	CfgCtxDerivedEvent* query = dynamic_cast(ParentContext);
	if (!query) {
		ERROR(" is not a valid subelement of <" + ParentContext->Name() + ">");
		return;
	}
	// Bail if parent didn't parse right or didn't build an OmiTask instance
	if (! (query->isOK() && query->GetTask())) {
		return;
	}

	for (const auto& item : properties) {
		if (item.first == "columnValue") {
			valueAttrName = item.second;
		} else if (item.first == "columnName") {
			nameAttrName = item.second;
		} else if (item.first == "partitionKey") {
			partitionKey = item.second;
		} else if (item.first == "instanceID") {
			uuid = item.second;
		} else {
			WARNING("Ignoring unexpected attribute " + item.first);
		}
	}
	if (valueAttrName.empty() || nameAttrName.empty() || partitionKey.empty()) {
		ERROR("Missing one or more required attributes (columnValue, columnName, partitionKey)");
		return;
	}
	// An empty or unset uuid attribute is permitted (and meaningful)

	auto task = query->GetTask();
	task->AddStage(new Pipe::LADQuery(std::move(valueAttrName), std::move(nameAttrName), std::move(partitionKey), std::move(uuid)));
	// Centrally-stored events implicitly have Identity columns added to them as
	// defined in the element. The LADQuery stage strips them off;
	// we should put them back in.
	if (! (query->isStoredLocally()) ) {
		task->AddStage(new Pipe::Identity(Config->GetIdentityVector()));
	}
	query->SuppressSchemaGeneration();	// LAD queries don't generate entries in SchemasTable
}

subelementmap_t CfgCtxLADQuery::_subelements;
std::string CfgCtxLADQuery::_name = "LADQuery";

// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxDerived.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once #ifndef _CFGCTXDERIVED_HH_ #define _CFGCTXDERIVED_HH_ #include "CfgContext.hh" #include "DerivedEvent.hh" #include "StoreType.hh" class CfgCtxDerived : public CfgContext { public: CfgCtxDerived(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxDerived() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); } private: static subelementmap_t _subelements; static std::string _name; }; class CfgCtxDerivedEvent : public CfgContext { public: CfgCtxDerivedEvent(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxDerivedEvent() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const; void Enter(const xmlattr_t& properties); CfgContext* Leave(); bool isOK() const { return _isOK; } bool isStoredLocally() const { return StoreType::Local == _storeType; } DerivedEvent * GetTask() const { return _task; } void SuppressSchemaGeneration() { _doSchemaGeneration = false; } private: static subelementmap_t _subelements; static std::string _name; DerivedEvent *_task; bool _isOK; StoreType::Type _storeType; bool _doSchemaGeneration; }; class CfgCtxLADQuery : public CfgContext { public: CfgCtxLADQuery(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxLADQuery() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); private: static subelementmap_t _subelements; static std::string _name; }; #endif //_CFGCTXDERIVED_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxEnvelope.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "CfgCtxEnvelope.hh"
#include "MdsdConfig.hh"
#include "Utility.hh"

/////// CfgCtxEnvelope

subelementmap_t CfgCtxEnvelope::_subelements = {
	{ "Field", [](CfgContext* parent) -> CfgContext* { return new CfgCtxEnvelopeField(parent); } },
	{ "Extension", [](CfgContext* parent) -> CfgContext* { return new CfgCtxEnvelopeExtension(parent); } },
};
std::string CfgCtxEnvelope::_name = "EnvelopeSchema";

/////// CfgCtxEnvelopeField

subelementmap_t CfgCtxEnvelopeField::_subelements;
std::string CfgCtxEnvelopeField::_name = "Field";

// Record the field's value from the given source, but only if no source has
// supplied a value yet; otherwise warn and keep the first value.
void CfgCtxEnvelopeField::SetFieldValueIfUnset(CfgCtxEnvelopeField::ValueSource source, const std::string & value)
{
	if (Source != ValueSource::none) {
		WARNING(std::string("Cannot specify multiple sources for this value; using '") + FieldValue + "'");
	} else {
		FieldValue = value;
		Source = source;
	}
}

// Parse a Field element's attributes: "name" (required) plus exactly one value
// source — an environment variable ("envariable") or the agent identity
// ("useComputerName"). The element body is a third possible source (HandleBody).
void CfgCtxEnvelopeField::Enter(const xmlattr_t& properties)
{
	Source = ValueSource::none;
	for (const auto& item : properties) {
		if (item.first == "name") {
			FieldName = item.second;
		} else if (item.first == "envariable") {
			try {
				SetFieldValueIfUnset(ValueSource::environment, MdsdUtil::GetEnvironmentVariable(item.second));
			} catch (std::exception & ex) {
				// Missing/unreadable env var: warn and fall back to empty string.
				WARNING(ex.what());
				SetFieldValueIfUnset(ValueSource::environment, std::string());
			}
		} else if (item.first == "useComputerName") {
			SetFieldValueIfUnset(ValueSource::agentIdent, Config->AgentIdentity());
		} else {
			ERROR(" ignoring unexpected attribute " + item.first);
		}
	}
	if (FieldName.empty()) {
		ERROR(" missing required 'name' attribute");
	}
}

// Accumulate element body text as the field value, unless an attribute-based
// source already supplied one.
void CfgCtxEnvelopeField::HandleBody(const std::string& body)
{
	if (Source == ValueSource::environment || Source == ValueSource::agentIdent) {
		WARNING(std::string("Cannot specify multiple sources for this value; using '") + FieldValue + "'");
	} else {
		FieldValue += body;
		Source = ValueSource::configFile;
	}
}

// Commit the parsed field to the config's envelope schema (empty value allowed).
CfgContext* CfgCtxEnvelopeField::Leave()
{
	if (!FieldName.empty()) {
		if (Source == ValueSource::none) {
			WARNING("No value supplied for this column; using empty string");
		}
		Config->AddEnvelopeColumn(std::move(FieldName), std::move(FieldValue));
	}
	return ParentContext;
}

/////// CfgCtxEnvelopeExtension

subelementmap_t CfgCtxEnvelopeExtension::_subelements = {
	{ "Field", [](CfgContext* parent) -> CfgContext* { return new CfgCtxEnvelopeField(parent); } },
};
std::string CfgCtxEnvelopeExtension::_name = "Extension";

// Parse an Extension element; requires only a "name" attribute.
void CfgCtxEnvelopeExtension::Enter(const xmlattr_t& properties)
{
	for (const auto& item : properties) {
		if (item.first == "name") {
			ExtensionName = item.second;
		} else {
			ERROR(" ignoring unexpected attribute " + item.first);
		}
	}
	if (ExtensionName.empty()) {
		ERROR(" missing required 'name' attribute");
	}
}

// vim: set tabstop=4 softtabstop=4 shiftwidth=4 noexpandtab :

================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxEnvelope.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _CFGCTXENVELOPE_HH_
#define _CFGCTXENVELOPE_HH_

#include "CfgContext.hh"

class LocalSink;

// Parse context for the EnvelopeSchema container element.
class CfgCtxEnvelope : public CfgContext
{
public:
	CfgCtxEnvelope(CfgContext* config) : CfgContext(config) {}
	virtual ~CfgCtxEnvelope() { }

	virtual const std::string& Name() const { return _name; }
	virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }
	void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); }

private:
	static subelementmap_t _subelements;
	static std::string _name;
};

// Parse context for a Field element: a named envelope column whose value comes
// from exactly one of: an environment variable, the agent identity, or body text.
class CfgCtxEnvelopeField : public CfgContext
{
public:
	enum ValueSource { none, environment, agentIdent, configFile };

	CfgCtxEnvelopeField(CfgContext* config) : CfgContext(config) {}
	virtual ~CfgCtxEnvelopeField() { }

	virtual const std::string& Name() const { return _name; }
	virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }
	void Enter(const xmlattr_t& properties);
	virtual void HandleBody(const std::string& body);
	CfgContext* Leave();
	void SetFieldValueIfUnset(ValueSource, const std::string &);

private:
	std::string FieldName;
	std::string FieldValue;
	ValueSource Source;	// which source supplied FieldValue (none until set)

	static subelementmap_t _subelements;
	static std::string _name;
};

// Parse context for an Extension element (a named group of Field subelements).
class CfgCtxEnvelopeExtension : public CfgContext
{
public:
	CfgCtxEnvelopeExtension(CfgContext* config) : CfgContext(config) {}
	virtual ~CfgCtxEnvelopeExtension() { }

	virtual const std::string& Name() const { return _name; }
	virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }
	void Enter(const xmlattr_t& properties);

private:
	std::string ExtensionName;

	static subelementmap_t _subelements;
	static std::string _name;
};

#endif //_CFGCTXENVELOPE_HH_
// vim: set tabstop=4 softtabstop=4 shiftwidth=4 noexpandtab :

================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxError.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "CfgCtxError.hh"

subelementmap_t CfgCtxError::subelements;
std::string CfgCtxError::name = "(A previous error was detected)";

================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxError.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _CFGCTXERROR_HH_
#define _CFGCTXERROR_HH_

#include "CfgContext.hh"

///
/// Once an unexpected element is found while parsing the config file, this class sets up
/// an "error detected" context that is propagated until the parse leaves the unexpected
/// element. Note that any Insert elements are ignored by this context.
///
class CfgCtxError : public CfgContext
{
public:
	CfgCtxError(CfgContext* previousContext) : CfgContext(previousContext) {}
	virtual ~CfgCtxError() { }

	virtual const std::string& Name() const { return name; }
	const subelementmap_t& GetSubelementMap() const { return subelements; }

	// We're deliberately silent on the attributes and body of elements while we're in an error state
	void Enter(const xmlattr_t&) {};
	void HandleBody(const std::string&) {};
	CfgContext* Leave() { return ParentContext; }

	///
	/// An empty list of legal subelements. Any context can return this from GetSubelementMap() if the
	/// element has errors that block usage.
	///
	static subelementmap_t subelements;

	/// True if the parse is in "error" state
	virtual bool IsErrorContext() const { return true; }

private:
	static std::string name;
};

#endif //_CFGCTXERROR_HH_
// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxEtw.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "CfgCtxEtw.hh"
#include "MdsdConfig.hh"
#include "CfgCtxParser.hh"
#include "LocalSink.hh"
#include "Subscription.hh"
#include "PipeStages.hh"
#include "EtwEvent.hh"
#include "EventType.hh"

////////////////// CfgCtxEtwProviders

subelementmap_t CfgCtxEtwProviders::s_subelements = {
	{ "EtwProvider", [] (CfgContext* parent) -> CfgContext* { return new CfgCtxEtwProvider(parent); } }
};
std::string CfgCtxEtwProviders::s_name = "EtwProviders";

////////////////// CfgCtxEtwProvider

std::string CfgCtxEtwProvider::s_name = "EtwProvider";
subelementmap_t CfgCtxEtwProvider::s_subelements = {
	{"Event", [] (CfgContext* parent) -> CfgContext * { return new CfgCtxEtwEvent(parent); } }
};

// Parse an EtwProvider element: guid and priority are captured here and
// inherited by child Event elements; storeType is optional.
void CfgCtxEtwProvider::Enter(const xmlattr_t& properties)
{
	CfgCtx::CfgCtxParser parser(this);
	if (!parser.ParseEtwProvider(properties)) {
		return;
	}
	m_guid = parser.GetGuid();
	m_priority = parser.GetPriority();
	if (parser.HasStoreType()) {
		m_storeType = parser.GetStoreType();
	}
}

CfgContext* CfgCtxEtwProvider::Leave()
{
	return ParentContext;
}

////////////////// CfgCtxEtwEvent

subelementmap_t CfgCtxEtwEvent::s_subelements;
std::string CfgCtxEtwEvent::s_name = "Event";

// Parse an Event element under EtwProvider: resolve storeType/priority (falling
// back to the parent provider's values), build/look up the per-(guid,eventId)
// LocalSink, and construct the Subscription task routing that sink to storage.
void CfgCtxEtwEvent::Enter(const xmlattr_t& properties)
{
	CfgCtx::CfgCtxParser parser(this);
	if (!parser.ParseEvent(properties, CfgCtx::EventType::EtwEvent)) {
		return;
	}

	// NOTE(review): template argument stripped by extraction — presumably
	// dynamic_cast<CfgCtxEtwProvider*>(ParentContext); confirm against the repo.
	CfgCtxEtwProvider *parent = dynamic_cast(ParentContext);
	if (!parent) {
		FATAL("Found <" + s_name + "> in <" + ParentContext->Name() + ">; that can't happen.");
		return;
	}

	auto guidstr = parent->GetGuid();
	if (guidstr.empty()) {
		ERROR("<" + s_name + "> missed required GUID attribute.");
		return;
	}

	// Event's own storeType wins; else inherit from provider; else XTable.
	if (parser.HasStoreType()) {
		m_storeType = parser.GetStoreType();
	} else {
		m_storeType = parent->GetStoreType();
	}
	if (StoreType::None == m_storeType) {
		m_storeType = StoreType::XTable;
	}

	// Same inheritance rule for priority.
	Priority priority;
	if (parser.HasPriority()) {
		priority = parser.GetPriority();
	} else {
		priority = parent->GetPriority();
	}

	m_eventId = parser.GetEventId();

	// for ETW, use local table name as LocalSink source.
	std::string source = EtwEvent::BuildLocalTableName(guidstr, m_eventId);
	m_sink = LocalSink::Lookup(source);
	if (!m_sink) {
		m_sink = new LocalSink(source);
		m_sink->AllocateSchemaId();
	}

	bool isNoPerNDay = parser.IsNoPerNDay();
	std::string account = parser.GetAccount();
	std::string eventName = parser.GetEventName();
	time_t interval = parser.GetInterval();

	try {
		auto target = MdsEntityName { eventName, isNoPerNDay, Config, account, m_storeType };
		m_subscription = new Subscription(m_sink, std::move(target), priority, MdsTime(interval));
		// Centrally-stored events get the Identity columns prepended.
		if (StoreType::Local != m_storeType) {
			m_subscription->AddStage(new Pipe::Identity(Config->GetIdentityVector()));
		}
		Config->AddMonikerEventInfo(account, eventName, m_storeType, source, mdsd::EventType::EtwEvent);
	} catch(const std::invalid_argument& ex) {
		ERROR(ex.what());
		return;
	} catch(...) {
		FATAL("Unknown exception; skipping.");
		return;
	}
}

// Finish the Event: append schema generation (XTable only) and the batch-writer
// stage, then hand the subscription task to the config.
CfgContext* CfgCtxEtwEvent::Leave()
{
	if (!m_subscription) {
		return ParentContext;
	}

	if (StoreType::XTable == m_storeType) {
		m_subscription->AddStage(new Pipe::BuildSchema(Config, m_subscription->target(), true));
	}

	Batch* batch = Config->GetBatch(m_subscription->target(), m_subscription->Duration());
	if (batch) {
		m_subscription->AddStage(new Pipe::BatchWriter(batch, Config->GetIdentityVector(), Config->PartitionCount(), m_storeType));
		Config->AddTask(m_subscription);
	} else {
		ERROR("Unable to create routing for " + s_name + " id=" + std::to_string(m_eventId));
	}
	return ParentContext;
}

================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxEtw.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _CFGCTXETW_HH_
#define _CFGCTXETW_HH_

#include "CfgContext.hh"
#include "CfgCtxError.hh"
#include "StoreType.hh"
#include "Priority.hh"

// Parse context for the EtwProviders container element.
class CfgCtxEtwProviders : public CfgContext
{
public:
	CfgCtxEtwProviders(CfgContext * config) : CfgContext(config) {}
	virtual ~CfgCtxEtwProviders() {}

	virtual const std::string& Name() const { return s_name; }
	static const std::string& XmlName() { return s_name; }
	virtual const subelementmap_t& GetSubelementMap() const { return s_subelements; }
	void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); }

private:
	static subelementmap_t s_subelements;
	static std::string s_name;
};

// Parse context for an EtwProvider element; holds guid/storeType/priority
// defaults that child Event elements inherit.
class CfgCtxEtwProvider : public CfgContext
{
public:
	CfgCtxEtwProvider(CfgContext * config) : CfgContext(config) {}
	virtual ~CfgCtxEtwProvider() {}

	virtual const std::string& Name() const { return s_name; }
	static const std::string& XmlName() { return s_name; }
	// With no valid GUID the provider is in error; children are skipped.
	virtual const subelementmap_t& GetSubelementMap() const { return (m_guid.empty()? CfgCtxError::subelements : s_subelements); }
	void Enter(const xmlattr_t& properties);
	CfgContext* Leave();

	std::string GetGuid() const { return m_guid; }
	StoreType::Type GetStoreType() const { return m_storeType; }
	Priority GetPriority() const { return m_priority; }

private:
	static subelementmap_t s_subelements;
	static std::string s_name;

	std::string m_guid;
	StoreType::Type m_storeType = StoreType::None;	// None = defer to child/default
	Priority m_priority;
};

// Parse context for an Event element under EtwProvider; builds the
// LocalSink + Subscription routing for one (provider guid, event id) pair.
class CfgCtxEtwEvent : public CfgContext
{
public:
	CfgCtxEtwEvent(CfgContext * config) : CfgContext (config) {}
	virtual ~CfgCtxEtwEvent() {}

	virtual const std::string& Name() const { return s_name; }
	static const std::string& XmlName() { return s_name; }
	virtual const subelementmap_t& GetSubelementMap() const { return s_subelements; }
	void Enter(const xmlattr_t& properties);
	CfgContext* Leave();

private:
	static subelementmap_t s_subelements;
	static std::string s_name;

	StoreType::Type m_storeType = StoreType::None;
	int m_eventId = -1;
	class LocalSink* m_sink = nullptr;		// shared per-source sink (looked up or created)
	class Subscription* m_subscription = nullptr;	// built in Enter, routed in Leave
};

#endif // _CFGCTXETW_HH_

================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxEventAnnotations.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "CfgCtxEventAnnotations.hh"
#include "MdsdConfig.hh"
#include "ConfigParser.hh"
#include "MdsTime.hh"
#include "Trace.hh"
#include "CfgOboDirectConfig.hh"
#include "MdsdEventCfg.hh"
#include "EventPubCfg.hh"
#include "Utility.hh"
#include "cryptutil.hh"

///////// CfgCtxEventAnnotations

subelementmap_t CfgCtxEventAnnotations::_subelements = {
	{ "EventStreamingAnnotation", [](CfgContext* parent) -> CfgContext* { return new CfgCtxEventAnnotation(parent); } }
};
std::string CfgCtxEventAnnotations::_name = "EventStreamingAnnotations";

// Record that itemname carries the given annotation type. Types are bit-flags;
// duplicates of the same type for an item are reported, then the bit is OR'd in
// regardless.
void CfgCtxEventAnnotations::SetEventType(
	const std::string & itemname,
	EventAnnotationType::Type type
	)
{
	if (itemname.empty()) {
		ERROR("<" + Name() + "> invalid empty itemname attribute");
		return;
	}
	// if duplicate, report error
	auto item = _eventmap.find(itemname);
	if (item != _eventmap.end()) {
		if (item->second & type) {
			ERROR("<" + Name() + "> itemname " + itemname + " already defined for type " + std::to_string(type));
		}
	}
	_eventmap[itemname] = static_cast< EventAnnotationType::Type>(_eventmap[itemname] | type);
}

// Hand the accumulated annotation map to the event configuration, if non-empty.
CfgContext* CfgCtxEventAnnotations::Leave()
{
	if (_eventmap.size() > 0) {
		Config->GetMdsdEventCfg()->SetEventAnnotationTypes(std::move(_eventmap));
	}
	return ParentContext;
}

///////// CfgCtxEventAnnotation

// NOTE(review): the dynamic_cast template arguments in this map were stripped by
// extraction — presumably dynamic_cast<CfgCtxEventAnnotation*>(parent).
subelementmap_t CfgCtxEventAnnotation::_subelements = {
	{ "EventPublisher", [](CfgContext* parent) -> CfgContext* { return new CfgCtxEPA(parent); } },
	{ "OnBehalf", [](CfgContext* parent) -> CfgContext* { return new CfgCtxOnBehalf(parent, dynamic_cast(parent)->_itemName); } }
};
std::string CfgCtxEventAnnotation::_name = "EventStreamingAnnotation";

// Parse an EventStreamingAnnotation element; the "name" attribute identifies
// the annotated item (event name, source name, etc.) and is required.
void CfgCtxEventAnnotation::Enter(const xmlattr_t& properties)
{
	const std::string attrName = "name";

	for (const auto & item : properties) {
		if (attrName == item.first) {
			parse_singleton_attribute(item.first, item.second, attrName, _itemName);
		} else {
			warn_if_attribute_unexpected(item.first);
		}
	}
	fatal_if_no_attributes(attrName, _itemName);
}

// Propagate the annotation type for this item up to the container element.
void CfgCtxEventAnnotation::SetEventType(EventAnnotationType::Type eventType)
{
	// NOTE(review): template argument stripped — presumably
	// dynamic_cast<CfgCtxEventAnnotations*>(ParentContext).
	auto parentObj = dynamic_cast(ParentContext);
	if (!parentObj) {
		fatal_if_impossible_subelement();
		return;
	}
	parentObj->SetEventType(_itemName, eventType);
}

// Register the (decrypted) EventHubs SAS key for this item. Requires the config
// to carry a resourceId (Shoebox V2 requirement).
void CfgCtxEventAnnotation::SetEventSasKey(
	std::string&& saskey
	)
{
	if (saskey.empty()) {
		return;
	}
	// EventHubs publisher requires resourceId defined for Shoebox V2.
	// If another scenario needs to be supported, this code may need to be changed as well.
	if (Config->GetResourceId().empty()) {
		ERROR("<" + Name() + ">: OboDirectPartitionField resourceId is missing, when Shoebox V2 EventHubs publisher needs one.");
		return;
	}
	try {
		Config->GetEventPubCfg()->AddAnnotationKey(_itemName, std::move(saskey));
	} catch(const std::exception& ex) {
		ERROR("<" + Name() + "> exception: " + ex.what());
	}
}

///////// CfgCtxEPA

subelementmap_t CfgCtxEPA::_subelements = {
	{ "Content", [](CfgContext* parent) -> CfgContext* { return new CfgCtxEPAContent(parent); } },
	{ "Key", [](CfgContext* parent) -> CfgContext* { return new CfgCtxEPAKey(parent); } }
};

void CfgCtxEPA::Enter(const xmlattr_t& properties)
{
	warn_if_attributes(properties);

	// Set the event type (EventPublisher) for the event (this Publisher element's parent's name attribute) in the EventAnnotations (grandparent) element's event type map
	// NOTE(review): template argument stripped — presumably
	// dynamic_cast<CfgCtxEventAnnotation*>(ParentContext).
	auto parentObj = dynamic_cast(ParentContext);
	if (!parentObj) {
		fatal_if_impossible_subelement();
		return;
	}
	parentObj->SetEventType(EventAnnotationType::Type::EventPublisher);
}

// Forward the SAS key from the Key subelement up to the annotation element.
void CfgCtxEPA::SetEventSasKey(
	std::string&& saskey
	)
{
	if (saskey.empty()) {
		return;
	}
	auto parentObj = dynamic_cast(ParentContext);
	if (!parentObj) {
		fatal_if_impossible_subelement();
		return;
	}
	parentObj->SetEventSasKey(std::move(saskey));
}

std::string CfgCtxEPA::_name = "EventPublisher";

///////// CfgCtxEPAContent

subelementmap_t CfgCtxEPAContent::_subelements;
std::string CfgCtxEPAContent::_name = "Content";

///////// CfgCtxEPAKey

subelementmap_t CfgCtxEPAKey::_subelements;
std::string CfgCtxEPAKey::_name = "Key";

// Parse the Key element's attributes; decryptKeyPath (path to a private key used
// to decrypt the body) is optional.
void CfgCtxEPAKey::Enter(const xmlattr_t& properties)
{
	// Decrypt key path attribute is optional
	const std::string & decryptKeyPathAttr = "decryptKeyPath";

	for (const auto & item : properties) {
		if (decryptKeyPathAttr == item.first) {
			parse_singleton_attribute(item.first, item.second, decryptKeyPathAttr, _decryptKeyPath);
		} else {
			warn_if_attribute_unexpected(item.first);
		}
	}
}

// Process the Key element body: either XML-unescape it (plaintext key) or
// decrypt it with the private key file, then pass it up to the publisher.
CfgContext* CfgCtxEPAKey::Leave()
{
	if (Body.empty()) {
		return ParentContext;
	}

	auto parentObj = dynamic_cast(ParentContext);
	if (!parentObj) {
		fatal_if_impossible_subelement();
		return ParentContext;
	}

	if (_decryptKeyPath.empty()) {
		auto escapedConnStr = MdsdUtil::UnquoteXmlAttribute(Body);
		parentObj->SetEventSasKey(std::move(escapedConnStr));
	} else {
		if (!MdsdUtil::IsRegFileExists(_decryptKeyPath)) {
			ERROR("Cannot find decrypt key path " + _decryptKeyPath);
		} else {
			try {
				auto decryptedSas = cryptutil::DecodeAndDecryptString(_decryptKeyPath, Body);
				parentObj->SetEventSasKey(std::move(decryptedSas));
			} catch(const std::exception & ex) {
				ERROR("EventPublisher SAS key decryption using private key file '" + _decryptKeyPath + "' failed: " + ex.what());
			}
		}
	}
	return ParentContext;
}

/////////// CfgCtxOnBehalf

subelementmap_t CfgCtxOnBehalf::_subelements = {
	{ "Content", [](CfgContext* parent) -> CfgContext* { return new CfgCtxOnBehalfContent(parent, dynamic_cast(parent)->_eventName); } }
};
std::string CfgCtxOnBehalf::_name = "OnBehalf";

// Parse an OnBehalf element; only directMode="true" is currently supported.
void CfgCtxOnBehalf::Enter(const xmlattr_t& properties)
{
	std::string valDirectMode;
	const std::string attrDirectMode = "directMode";

	for (const auto& item : properties) {
		if (attrDirectMode == item.first) {
			parse_singleton_attribute(item.first, item.second, attrDirectMode, valDirectMode);
		} else {
			warn_if_attribute_unexpected(item.first);
		}
	}
	fatal_if_no_attributes(attrDirectMode, valDirectMode);

	if (valDirectMode != "true") {
		ERROR("<" + Name() + "> supports attribute " + attrDirectMode + "=\"true\" only currently");
	}
	// Set the event type (OnBehalf) for the event (this Publisher element's parent's name attribute) in the EventAnnotations (grandparent) element's event type map
	auto parentObj = dynamic_cast(ParentContext);
	if (!parentObj) {
		fatal_if_impossible_subelement();
		return;
	}
	parentObj->SetEventType(EventAnnotationType::Type::OnBehalf);
}

/////////// CfgCtxOnBehalfContent

subelementmap_t CfgCtxOnBehalfContent::_subelements = {
	{ "Config", [](CfgContext* parent) -> CfgContext* { return new CfgCtxOnBehalfConfig(parent, dynamic_cast(parent)->_eventName); } } // This is a trick to handle the CDATA XML content as a subelement...
};
std::string CfgCtxOnBehalfContent::_name = "Content";

void CfgCtxOnBehalfContent::Enter(const xmlattr_t& properties)
{
	warn_if_attributes(properties);
}

// The element body is itself an XML document (delivered as CDATA); parse it
// recursively with this context as the parent.
CfgContext* CfgCtxOnBehalfContent::Leave()
{
	if (Body.empty()) {
		ERROR("<" + Name() +"> must have a body (CDATA), but it's empty");
	} else {
		// Trick: Parse the cdata (another XML) by treating it as a subelement...
		ConfigParser xmlCdataParser(this, Config);
		xmlCdataParser.Parse(Body);
	}
	return ParentContext;
}

///////////// CfgCtxOnBehalfConfig (XML in CDATA of CfgCtxOnBehalfContent...)

subelementmap_t CfgCtxOnBehalfConfig::_subelements;
std::string CfgCtxOnBehalfConfig::_name = "Config";

// Parse the embedded Config element's attributes into a CfgOboDirectConfig and
// attach it to the global config under this annotation's event name.
void CfgCtxOnBehalfConfig::Enter(const xmlattr_t& properties)
{
	Trace trace(Trace::ConfigLoad, "CfgCtxOnBehalfConfig::Enter");

	// NOTE(review): template argument stripped by extraction — presumably
	// std::make_shared<CfgOboDirectConfig>(); confirm against the repo.
	auto oboDirectConfig = std::make_shared();

	for (const auto& item : properties) {
		if (item.first == "onBehalfFields") // Not used by mdsd yet
		{
			oboDirectConfig->onBehalfFields = item.second;
		}
		else if (item.first == "containerSuffix") // Not used by mdsd yet
		{
			oboDirectConfig->containerSuffix = item.second;
		}
		else if (item.first == "primaryPartitionField")
		{
			oboDirectConfig->primaryPartitionField = item.second;
		}
		else if (item.first == "partitionFields")
		{
			oboDirectConfig->partitionFields = item.second;
		}
		else if (item.first == "onBehalfReplaceFields") // Not used by mdsd yet
		{
			oboDirectConfig->onBehalfReplaceFields = item.second;
		}
		else if (item.first == "excludeFields") // Not used by mdsd yet
		{
			oboDirectConfig->excludeFields = item.second;
		}
		else if (item.first == "timePeriods")
		{
			if (MdsTime::FromIS8601Duration(item.second).to_time_t() == 0) {
				ERROR("Invalid ISO8601 time duration is given: " + item.second);
			} else {
				oboDirectConfig->timePeriods = item.second;
			}
		}
		else if (item.first == "priority")
		{
			oboDirectConfig->priority = item.second;
		}
		else
		{
			warn_if_attribute_unexpected(item.first);
		}
	}

	Config->AddOboDirectConfig(_eventName, std::move(oboDirectConfig));
}

================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxEventAnnotations.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once #ifndef _CFGCTXEVENTANNOTATIONS_HH_ #define _CFGCTXEVENTANNOTATIONS_HH_ #include "CfgContext.hh" #include "CfgEventAnnotationType.hh" #include class CfgCtxEventAnnotations : public CfgContext { public: CfgCtxEventAnnotations(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxEventAnnotations() {} virtual const std::string & Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); } CfgContext* Leave(); /// Set each annotation name and type. /// The itemname can be event name, source name, etc. void SetEventType(const std::string & itemname, EventAnnotationType::Type type); private: static subelementmap_t _subelements; static std::string _name; /// map key: itemname, value: annotation type std::unordered_map _eventmap; }; class CfgCtxEventAnnotation : public CfgContext { public: CfgCtxEventAnnotation(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxEventAnnotation() {} virtual const std::string & Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); void SetEventType(EventAnnotationType::Type type); void SetEventSasKey(std::string&& saskey); private: static subelementmap_t _subelements; static std::string _name; std::string _itemName; }; class CfgCtxEPA : public CfgContext { public: CfgCtxEPA(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxEPA() {} virtual const std::string & Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); void SetEventSasKey(std::string&& saskey); private: static subelementmap_t _subelements; static std::string _name; }; class CfgCtxEPAContent : public CfgContext { public: CfgCtxEPAContent(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxEPAContent() {} virtual const 
std::string & Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); } private: static subelementmap_t _subelements; static std::string _name; }; class CfgCtxEPAKey : public CfgContext { public: CfgCtxEPAKey(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxEPAKey() {} virtual const std::string & Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); CfgContext* Leave(); private: static subelementmap_t _subelements; static std::string _name; std::string _decryptKeyPath; }; class CfgCtxOnBehalf : public CfgContext { public: CfgCtxOnBehalf(CfgContext* config, const std::string& eventName) : CfgContext(config), _eventName(eventName) {} virtual ~CfgCtxOnBehalf() {} virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); private: static std::string _name; static subelementmap_t _subelements; std::string _eventName; }; class CfgCtxOnBehalfContent : public CfgContext { public: CfgCtxOnBehalfContent(CfgContext* config, const std::string& eventName) : CfgContext(config), _eventName(eventName) {} virtual ~CfgCtxOnBehalfContent() {} virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); CfgContext* Leave(); private: static std::string _name; static subelementmap_t _subelements; std::string _eventName; }; class CfgCtxOnBehalfConfig : public CfgContext { public: CfgCtxOnBehalfConfig(CfgContext* config, const std::string& eventName) : CfgContext(config), _eventName(eventName) {} virtual ~CfgCtxOnBehalfConfig() {} virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& 
GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); private: static std::string _name; static subelementmap_t _subelements; std::string _eventName; }; #endif // _CFGCTXEVENTANNOTATIONS_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxEvents.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "CfgCtxEvents.hh" #include "CfgCtxHeartBeats.hh" #include "CfgCtxOMI.hh" #include "CfgCtxMdsdEvents.hh" #include "CfgCtxDerived.hh" #include "CfgCtxExtensions.hh" #include "CfgCtxEtw.hh" ////////////////// CfgCtxEvents subelementmap_t CfgCtxEvents::_subelements = { { "HeartBeats", [](CfgContext* parent) -> CfgContext* { return new CfgCtxHeartBeats(parent); } }, { "OMI", [](CfgContext* parent) -> CfgContext* { return new CfgCtxOMI(parent); } }, { "MdsdEvents", [](CfgContext* parent) -> CfgContext* { return new CfgCtxMdsdEvents(parent); } }, { "DerivedEvents", [](CfgContext* parent) -> CfgContext* { return new CfgCtxDerived(parent); } }, { "Extensions", [](CfgContext* parent) -> CfgContext* { return new CfgCtxExtensions(parent); } }, { "EtwProviders", [](CfgContext* parent) -> CfgContext* { return new CfgCtxEtwProviders(parent); } } }; std::string CfgCtxEvents::_name = "Events"; ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxEvents.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #ifndef _CFGCTXEVENTS_HH_ #define _CFGCTXEVENTS_HH_ #include "CfgContext.hh" class CfgCtxEvents : public CfgContext { public: CfgCtxEvents(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxEvents() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); } private: static subelementmap_t _subelements; static std::string _name; }; #endif //_CFGCTXEVENTS_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxExtensions.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "CfgCtxExtensions.hh" #include "Utility.hh" #include "MdsdExtension.hh" #include "MdsdConfig.hh" #include "CmdLineConverter.hh" ////////////////// CfgCtxExtensions subelementmap_t CfgCtxExtensions::_subelements = { { "Extension", [] (CfgContext* parent) -> CfgContext* { return new CfgCtxExtension(parent); } } }; std::string CfgCtxExtensions::_name = "Extensions"; ////////////////// CfgCtxExtension void CfgCtxExtension::Enter(const xmlattr_t& properties) { const std::string extNameAttr = "extensionName"; for (const auto & item : properties) { if (extNameAttr == item.first) { _extension_name = item.second; } else { WARNING("Ignoring unexpected attribute " + item.second); } } if (_extension_name.empty()) { ERROR("<" + _name + "> requires attribute '" + extNameAttr + "'"); } _extension = new MdsdExtension(_extension_name); } CfgContext* CfgCtxExtension::Leave() { if (_extension) { Config->AddExtension(_extension); } else { ERROR("Unexpected NULL value for MdsdExtension object in CfgCtxExtension."); } return ParentContext; } subelementmap_t CfgCtxExtension::_subelements = { { "CommandLine", [] (CfgContext* parent) -> CfgContext* { return new CfgCtxExtCmdLine(parent); } }, { "Body", [] 
(CfgContext* parent) -> CfgContext* { return new CfgCtxExtBody(parent); } }, { "AlternativeExtensionLocation", [] (CfgContext* parent) -> CfgContext* { return new CfgCtxExtAlterLocation(parent); } }, { "ResourceUsage", [] (CfgContext* parent) -> CfgContext* { return new CfgCtxExtResourceUsage(parent); } } }; std::string CfgCtxExtension::_name = "Extension"; ////////////////// CfgCtxExtCmdLine CfgContext* CfgCtxExtCmdLine::Leave() { std::string cmdline = std::move(Body); if (MdsdUtil::IsEmptyOrWhiteSpace(cmdline)) { ERROR("unexpected empty or whitespace value for Extension CmdLine"); } else { CfgCtxExtension * ctxext = dynamic_cast(ParentContext); if (ctxext) { CmdLineConverter::Tokenize(cmdline, std::bind(&CfgContext::WARNING, this, std::placeholders::_1)); // To warn (if any) sooner than later ctxext->GetExtension()->SetCmdLine(cmdline); } else { FATAL("Found <" + _name + "> in <" + ParentContext->Name() + ">; that can't happen"); } } return ParentContext; } subelementmap_t CfgCtxExtCmdLine::_subelements; std::string CfgCtxExtCmdLine::_name = "CommandLine"; ////////////////// CfgCtxExtBody CfgContext* CfgCtxExtBody::Leave() { if (empty_or_whitespace()) { WARNING("<" + _name + "> expected non-empty body; did not expect '{" + Body + "}'"); } else { CfgCtxExtension * ctxext = dynamic_cast(ParentContext); if (ctxext) { ctxext->GetExtension()->SetBody(Body); } else { FATAL("Found <" + _name + "> in <" + ParentContext->Name() + ">; that can't happen"); } } return ParentContext; } subelementmap_t CfgCtxExtBody::_subelements; std::string CfgCtxExtBody::_name = "Body"; ////////////////// CfgCtxExtAlterLocation CfgContext* CfgCtxExtAlterLocation::Leave() { std::string loc = std::move(Body); if (MdsdUtil::IsEmptyOrWhiteSpace(loc)) { WARNING("<" + _name + "> value cannot be empty or whitespace."); } else { CfgCtxExtension * ctxext = dynamic_cast(ParentContext); if (ctxext) { ctxext->GetExtension()->SetAlterLocation(loc); } else { FATAL("Found <" + _name + "> in <" + 
ParentContext->Name() + ">; that can't happen"); } } return ParentContext; } subelementmap_t CfgCtxExtAlterLocation::_subelements; std::string CfgCtxExtAlterLocation::_name = "AlternativeExtensionLocation"; ////////////////// CfgCtxExtResourceUsage void CfgCtxExtResourceUsage::Enter(const xmlattr_t& properties) { CfgCtxExtension * ctxext = dynamic_cast(ParentContext); if (!ctxext) { FATAL("Found <" + _name + "> in <" + ParentContext->Name() + ">; that can't happen"); return; } MdsdExtension * ext = ctxext->GetExtension(); for (const auto & item : properties) { if ("cpuPercentUsage" == item.first) { float f = std::stof(item.second); ext->SetCpuPercentUsage(f); } else if ("cpuThrottling" == item.first) { bool b = MdsdUtil::to_bool(item.second); ext->SetIsCpuThrottling(b); } else if ("memoryLimitInMB" == item.first) { unsigned long long m = std::stoull(item.second); ext->SetMemoryLimitInMB(m); } else if ("memoryThrottling" == item.first) { bool b = MdsdUtil::to_bool(item.second); ext->SetIsMemoryThrottling(b); } else if ("ioReadLimitInKBPerSecond" == item.first) { unsigned long long n = std::stoull(item.second); ext->SetIOReadLimitInKBPerSecond(n); } else if ("ioReadThrottling" == item.first) { bool b = MdsdUtil::to_bool(item.second); ext->SetIsIOReadThrottling(b); } else if ("ioWriteLimitInKBPerSecond" == item.first) { unsigned long long n = std::stoull(item.second); ext->SetIOWriteLimitInKBPerSecond(n); } else if ("ioWriteThrottling" == item.first) { bool b = MdsdUtil::to_bool(item.second); ext->SetIsIOWriteThrottling(b); } else { WARNING("Ignoring unexpected attribute " + item.second); } } } subelementmap_t CfgCtxExtResourceUsage::_subelements; std::string CfgCtxExtResourceUsage::_name = "ResourceUsage"; ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxExtensions.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #ifndef _CFGCTXEXTENSIONS_HH_ #define _CFGCTXEXTENSIONS_HH_ #include "CfgContext.hh" #include "CfgCtxError.hh" #include class MdsdExtension; /// /// Extensions define all the monitoring agent's extensions. /// class CfgCtxExtensions : public CfgContext { public: CfgCtxExtensions(CfgContext *config) : CfgContext(config) {} virtual ~CfgCtxExtensions() { } virtual const std::string& Name() const { return _name; } static const std::string& XmlName() { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); } private: static subelementmap_t _subelements; static std::string _name; }; /// /// Extension specifies one monitoring agent extension. The Name and CommandLine of an /// extension are required. Other properperties are optional. /// class CfgCtxExtension : public CfgContext { public: CfgCtxExtension(CfgContext * config) : CfgContext(config), _extension(nullptr) {} virtual ~CfgCtxExtension() { } virtual const std::string& Name() const { return _name; } static const std::string& XmlName() { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return (_extension_name.empty())? (CfgCtxError::subelements) : (_subelements); } void Enter(const xmlattr_t& properties); CfgContext* Leave(); MdsdExtension * GetExtension() const { return _extension; } private: static subelementmap_t _subelements; static std::string _name; std::string _extension_name; MdsdExtension * _extension; }; /// /// This specifies an extension's command line. It is required. 
/// class CfgCtxExtCmdLine : public CfgContext { public: CfgCtxExtCmdLine(CfgContext * config) : CfgContext(config) {} virtual ~CfgCtxExtCmdLine() { } virtual const std::string & Name() const { return _name; } static const std::string& XmlName() { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); } CfgContext* Leave(); private: static subelementmap_t _subelements; static std::string _name; }; /// /// Body: optional XML element. It specifies an extension's config body to be passed to the /// extension via environment variable "MON_EXTENSION_BODY". /// class CfgCtxExtBody : public CfgContext { public: CfgCtxExtBody(CfgContext * config) : CfgContext(config) {} virtual ~CfgCtxExtBody() { } virtual const std::string & Name() const { return _name; } static const std::string& XmlName() { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); } CfgContext* Leave(); private: static subelementmap_t _subelements; static std::string _name; }; /// /// This specifies the extension home directory. It is optional. /// class CfgCtxExtAlterLocation : public CfgContext { public: CfgCtxExtAlterLocation(CfgContext * config) : CfgContext(config) {} virtual ~CfgCtxExtAlterLocation() { } virtual const std::string & Name() const { return _name; } static const std::string& XmlName() { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); } CfgContext* Leave(); private: static subelementmap_t _subelements; static std::string _name; }; /// /// This specifies the limits of CPU, memory, IO throttling information. They will overwrite /// the default values defined in Management\AgentResourceUsage\ExtensionResourceUsage. 
/// class CfgCtxExtResourceUsage : public CfgContext { public: CfgCtxExtResourceUsage(CfgContext * config) : CfgContext(config) { } virtual ~CfgCtxExtResourceUsage() { } virtual const std::string & Name() const { return _name; } static const std::string& XmlName() { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); private: static subelementmap_t _subelements; static std::string _name; }; #endif // _CFGCTXEXTENSIONS_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxHeartBeats.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "CfgCtxHeartBeats.hh" ////////////////// CfgCtxHeartBeats subelementmap_t CfgCtxHeartBeats::_subelements = { { "HeartBeat", [](CfgContext* parent) -> CfgContext* { return new CfgCtxHeartBeat(parent); } } }; std::string CfgCtxHeartBeats::_name = "HeartBeats"; ////////////////// CfgCtxHeartBeat subelementmap_t CfgCtxHeartBeat::_subelements; std::string CfgCtxHeartBeat::_name = "HeartBeat"; ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxHeartBeats.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #include "CfgContext.hh" class CfgCtxHeartBeats : public CfgContext { public: CfgCtxHeartBeats(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxHeartBeats() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); } private: static subelementmap_t _subelements; static std::string _name; }; class CfgCtxHeartBeat : public CfgContext { public: CfgCtxHeartBeat(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxHeartBeat() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties) { log_entry(properties); } private: static subelementmap_t _subelements; static std::string _name; }; ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxImports.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "CfgCtxImports.hh" #include "MdsdConfig.hh" ////////////// CfgCtxImports subelementmap_t CfgCtxImports::_subelements = { { "Import", [](CfgContext* parent) -> CfgContext* { return new CfgCtxImport(parent); } } }; std::string CfgCtxImports::_name = "Imports"; ////////////// CfgCtxImport subelementmap_t CfgCtxImport::_subelements; std::string CfgCtxImport::_name = "Import"; void CfgCtxImport::Enter(const xmlattr_t& properties) { std::string filename; // Find the file attribute; invoke Config->LoadFromConfigFile() on the value thereof. 
for (const auto& item : properties) { if (item.first == "file") { filename = item.second; } else { Config->AddMessage(MdsdConfig::warning, "Ignoring unknown attribute \"" + item.first + "\""); } } if (filename.empty()) { Config->AddMessage(MdsdConfig::error, ": \"file\" attribute is missing or empty"); } else { Config->LoadFromConfigFile(filename); } } ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxImports.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _CFGCTXIMPORTS_HH_ #define _CFGCTXIMPORTS_HH_ #include "CfgContext.hh" class CfgCtxImports : public CfgContext { public: CfgCtxImports(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxImports() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); } private: static subelementmap_t _subelements; static std::string _name; }; class CfgCtxImport : public CfgContext { public: CfgCtxImport(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxImport() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); private: static subelementmap_t _subelements; static std::string _name; }; #endif //_CFGCTXIMPORTS_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxManagement.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "CfgCtxManagement.hh"
#include "MdsdConfig.hh"
#include "Listener.hh"
#include "Utility.hh"
#include "Trace.hh"
// NOTE(review): this header name was lost in extraction; <string> (std::stoul)
// is the likely candidate -- confirm.
#include <string>

/////// CfgCtxManagement

subelementmap_t CfgCtxManagement::_subelements = {
    { "Identity",
      [](CfgContext* parent) -> CfgContext* { return new CfgCtxIdentity(parent); } },
    { "AgentResourceUsage",
      [](CfgContext* parent) -> CfgContext* { return new CfgCtxAgentResourceUsage(parent); } },
    { "OboDirectPartitionField",
      [](CfgContext* parent) -> CfgContext* { return new CfgCtxOboDirectPartitionField(parent); } }
};

std::string CfgCtxManagement::_name = "Management";

// Map of eventVolume attribute values (either case) to partition counts.
// NOTE(review): the std::map template arguments were lost in extraction;
// <std::string, int> is inferred from the initializers and PartitionCount().
std::map<std::string, int> CfgCtxManagement::_eventVolumes = {
    { "Small", 1 },   { "small", 1 },
    { "Medium", 10 }, { "medium", 10 },
    { "Large", 100 }, { "large", 100 }
};

// Parse <Management> attributes: eventVolume selects the partition count and
// defaultRetentionInDays the retention policy. std::stoul throws on
// non-numeric input (behavior preserved from original).
void
CfgCtxManagement::Enter(const xmlattr_t& properties)
{
    Trace trace(Trace::ConfigLoad, "CfgCtxManagement::Enter");
    for (const auto& item : properties) {
        if (item.first == "eventVolume") {
            auto numPart = _eventVolumes.find(item.second);
            if (numPart != _eventVolumes.end()) {
                Config->PartitionCount(numPart->second);
            }
            else {
                ERROR("Unknown eventVolume \"" + item.second + "\"");
            }
        }
        else if (item.first == "defaultRetentionInDays") {
            unsigned long retention = std::stoul(item.second);
            if (retention < 1) {
                ERROR("Invalid value for defaultRetentionInDays");
            }
            else {
                Config->DefaultRetention(retention);
            }
        }
        else {
            Config->AddMessage(MdsdConfig::warning,
                " ignoring unexpected attribute " + item.first);
        }
    }
}

////// CfgCtxIdentity

subelementmap_t CfgCtxIdentity::_subelements = {
    { "IdentityComponent",
      [](CfgContext* parent) -> CfgContext* { return new CfgCtxIdentityComponent(parent); } }
};

std::string CfgCtxIdentity::_name = "Identity";

// Parse <Identity>. The "type" attribute installs canned identity columns
// (TenantRole from MONITORING_* environment variables, or ComputerName from
// the agent identity); alias attributes rename the standard columns.
void
CfgCtxIdentity::Enter(const xmlattr_t& properties)
{
    Config->SetTenantAlias("Tenant");
    Config->SetRoleAlias("Role");
    Config->SetRoleInstanceAlias("RoleInstance");

    for (const auto& item : properties) {
        if (item.first == "type") {
            if (item.second == "TenantRole") {
                // Add three identity components based on envariables
                AddEnvariable("Tenant", "MONITORING_TENANT");
                AddEnvariable("Role", "MONITORING_ROLE");
                AddEnvariable("RoleInstance", "MONITORING_ROLE_INSTANCE");
                IdentityWasSet = true;
            }
            else if (item.second == "ComputerName") {
                // Add a single identity component containing the hostname
                (void)Config->AddIdentityColumn("ComputerName", Config->AgentIdentity());
                IdentityWasSet = true;
            }
            else {
                WARNING("Ignoring unknown type " + item.second);
            }
        }
        else if (item.first == "tenantNameAlias") {
            Config->SetTenantAlias(item.second);
        }
        else if (item.first == "roleNameAlias") {
            Config->SetRoleAlias(item.second);
        }
        else if (item.first == "roleInstanceNameAlias") {
            Config->SetRoleInstanceAlias(item.second);
        }
        else {
            WARNING("Ignoring unknown attribute " + item.first);
        }
    }
}

// Add a literal identity column; ignored (with a warning) if the identity was
// already fixed by the "type" attribute.
void
CfgCtxIdentity::AddString(const std::string& name, const std::string& value)
{
    if (IdentityWasSet) {
        WARNING("Ignoring extra identity column " + name);
        return;
    }
    if (!(Config->AddIdentityColumn(name, value))) {
        ERROR("Duplicate IdentityComponent " + name);
    }
}

// Add an identity column whose value comes from an environment variable;
// a missing variable degrades to a warning, not an error.
void
CfgCtxIdentity::AddEnvariable(const std::string& name, const std::string& varname)
{
    if (IdentityWasSet) {
        WARNING("Ignoring extra identity column " + name);
        return;
    }
    try {
        std::string Value = MdsdUtil::GetEnvironmentVariable(varname);
        if (!(Config->AddIdentityColumn(name, Value))) {
            ERROR("Duplicate IdentityComponent " + name);
        }
    }
    catch (std::exception & ex) {
        WARNING(std::string(ex.what()) + "; " + name + " not added to identity columns");
    }
}

////// CfgCtxIdentityComponent

subelementmap_t CfgCtxIdentityComponent::_subelements;
std::string CfgCtxIdentityComponent::_name = "IdentityComponent";

// Parse one <IdentityComponent>. Its value comes from exactly one of: an
// environment variable, the computer name, or the element body (handled in
// Leave()). NOTE(review): several diagnostic strings below appear to have had
// literal element tags stripped by extraction; left byte-identical.
void
CfgCtxIdentityComponent::Enter(const xmlattr_t& properties)
{
    IsValid = true;     // Assume this will be a valid definition
    IgnoreBody = ExtraBody = false;
    std::string Envariable;
    bool useHostname = false;

    // NOTE(review): dynamic_cast target lost in extraction; the parent of
    // <IdentityComponent> is a CfgCtxIdentity -- confirm.
    _ctxidentity = dynamic_cast<CfgCtxIdentity*>(ParentContext);
    if (!_ctxidentity) {
        FATAL("Found in <" + ParentContext->Name() + ">; that can't happen");
        IsValid = false;
        return;
    }

    for (const auto& item : properties) {
        if (item.first == "name") {
            ComponentName = item.second;
        }
        else if (item.first == "envariable") {
            Envariable = item.second;
        }
        else if (item.first == "useComputerName") {
            useHostname = MdsdUtil::to_bool(item.second);
        }
        else {
            ERROR(" ignoring unexpected attribute " + item.first);
        }
    }

    if (ComponentName.empty()) {
        ERROR(" requires attribute \"name\"");
        IsValid = false;
    }
    else if (!Envariable.empty() && useHostname) {
        ERROR("Cannot specify both useComputerName and envariable for the same ");
        IsValid = false;
    }
    else if (!Envariable.empty() || useHostname) {
        IgnoreBody = true;
        if (useHostname) {
            _ctxidentity->AddString(ComponentName, Config->AgentIdentity());
        }
        else {
            _ctxidentity->AddEnvariable(ComponentName, Envariable);
        }
    }
    // If !IgnoreBody && IsValid, then the Leave() method will add the accumulated
    // string to the Identity column set
    if (!IsValid) {
        IgnoreBody = true;
    }
}

// Accumulate body text unless the value already came from an attribute, in
// which case remember that extra content appeared so Leave() can warn.
void
CfgCtxIdentityComponent::HandleBody(const std::string& body)
{
    if (IgnoreBody) {
        ExtraBody = true;   // We'll ignore it and warn about it
    }
    else {
        Body += body;
    }
}

// Commit the body-sourced identity column (or warn about whatever went wrong).
CfgContext*
CfgCtxIdentityComponent::Leave()
{
    if (!IsValid) {
        WARNING("Skipping invalid IdentityComponent");
    }
    else if (ExtraBody) {
        WARNING("Ignoring extra content for IdentityComponent; hope that's okay");
    }
    else if (!IgnoreBody) {
        if (empty_or_whitespace()) {
            WARNING("Empty value for IdentityComponent; hope that's okay");
        }
        _ctxidentity->AddString(ComponentName, Body);
    }
    return ParentContext;
}

////// CfgCtxAgentResourceUsage

subelementmap_t CfgCtxAgentResourceUsage::_subelements;
std::string CfgCtxAgentResourceUsage::_name = "AgentResourceUsage";

// Parse agent-level quotas. dupeWindowSeconds is clamped to [60, 3600] with a
// warning rather than rejected. std::stoul throws on non-numeric input.
void
CfgCtxAgentResourceUsage::Enter(const xmlattr_t& properties)
{
    for (const auto& item : properties) {
        if (item.first == "diskQuotaInMB") {
            unsigned long diskQuota = std::stoul(item.second);
            if (diskQuota < 1) {
                ERROR("diskQuotaInMB must be greater than zero");
            }
            else {
                Config->AddQuota("disk", diskQuota);
            }
        }
        else if (item.first == "dupeWindowSeconds") {
            unsigned long dupeWindow = std::stoul(item.second);
            if (dupeWindow < 60) {
                WARNING("dupeWindowSeconds must be >= 60");
                dupeWindow = 60;
            }
            else if (dupeWindow > 3600) {
                WARNING("dupeWindowSeconds must be <= 3600");
                dupeWindow = 3600;
            }
            Listener::setDupeWindow(dupeWindow);
        }
        else {
            ERROR(" ignoring unexpected attribute " + item.first);
        }
    }
}

////// CfgCtxOboDirectPartitionField

subelementmap_t CfgCtxOboDirectPartitionField::_subelements;
std::string CfgCtxOboDirectPartitionField::_name = "OboDirectPartitionField";

// Parse one <OboDirectPartitionField name="..." value="..."/> pair; both
// attributes are required.
void
CfgCtxOboDirectPartitionField::Enter(const xmlattr_t& properties)
{
    std::string name, value;
    for (const auto& item : properties) {
        if (item.first == "name") {
            name = item.second;
        }
        else if (item.first == "value") {
            value = item.second;
        }
        else {
            WARNING("Ignoring unknown attribute " + item.first);
        }
    }
    if (name.empty() || value.empty()) {
        ERROR(" requires both 'name' and 'value' attributes.");
        return;
    }
    Config->SetOboDirectPartitionFieldNameValue(std::move(name), std::move(value));
}

// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxManagement.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once #ifndef _CFGCTXMANAGEMENT_HH_ #define _CFGCTXMANAGEMENT_HH_ #include "CfgContext.hh" #include class TableSchema; class CfgCtxManagement : public CfgContext { public: CfgCtxManagement(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxManagement() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); private: static subelementmap_t _subelements; static std::string _name; static std::map _eventVolumes; }; class CfgCtxIdentity : public CfgContext { public: CfgCtxIdentity(CfgContext* config) : CfgContext(config), IdentityWasSet(false) {} virtual ~CfgCtxIdentity() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); void AddString(const std::string& n, const std::string& str); void AddEnvariable(const std::string& n, const std::string& varname); private: bool IdentityWasSet; static subelementmap_t _subelements; static std::string _name; }; class CfgCtxIdentityComponent : public CfgContext { public: CfgCtxIdentityComponent(CfgContext* config) : CfgContext(config), _ctxidentity(nullptr) {} virtual ~CfgCtxIdentityComponent() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); virtual void HandleBody(const std::string& body); CfgContext* Leave(); private: std::string ComponentName; bool IsValid; //bool GotBody; bool ExtraBody; bool IgnoreBody; CfgCtxIdentity* _ctxidentity; static subelementmap_t _subelements; static std::string _name; }; class CfgCtxAgentResourceUsage : public CfgContext { public: CfgCtxAgentResourceUsage(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxAgentResourceUsage() { } virtual const std::string& Name() const { return _name; } 
virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); private: static subelementmap_t _subelements; static std::string _name; }; class CfgCtxOboDirectPartitionField : public CfgContext { public: CfgCtxOboDirectPartitionField(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxOboDirectPartitionField() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); private: static subelementmap_t _subelements; static std::string _name; }; #endif //_CFGCTXMANAGEMENT_HH_ // :vim set ai sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxMdsdEvents.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "CfgCtxMdsdEvents.hh" #include "MdsdConfig.hh" #include "MdsEntityName.hh" #include "Subscription.hh" #include "Utility.hh" #include "Priority.hh" #include "PipeStages.hh" #include "LocalSink.hh" #include #include "CfgCtxParser.hh" #include "EventType.hh" ////////////////// CfgCtxMdsdEvents subelementmap_t CfgCtxMdsdEvents::_subelements = { { "MdsdEventSource", [](CfgContext* parent) -> CfgContext* { return new CfgCtxMdsdEventSource(parent); } } }; std::string CfgCtxMdsdEvents::_name = "MdsdEvents"; ////////////////// CfgCtxMdsdEventSource void CfgCtxMdsdEventSource::Enter(const xmlattr_t& properties) { for (const auto& item : properties) { if (item.first == "source") { _source = item.second; } else { WARNING("Ignoring unexpected attribute " + item.first); } } if (_source.empty()) { ERROR("Missing required source attribute"); return; } if (!Config->IsValidSource(_source) && !Config->IsValidDynamicSchemaSource(_source)) { ERROR("Undefined source \"" + _source + "\""); _source.clear(); // Puts the entire element in error state } else { // The 
LocalSink object should be already created _sink = LocalSink::Lookup(_source); if (!_sink) { ERROR("Failed to find LocalSink for MdsdEventSource \"" + _source + "\""); } } } subelementmap_t CfgCtxMdsdEventSource::_subelements = { { "RouteEvent", [](CfgContext* parent) -> CfgContext* { return new CfgCtxRouteEvent(parent); } } }; std::string CfgCtxMdsdEventSource::_name = "MdsdEventSource"; ////////////////// CfgCtxRouteEvent // Construct a Subscription object to query the event sink. Build the front of the pipeline to // process entities fetched from the sink. // The duration attribute is optional. If it's not set, a duration based on priority is used. // If priority is not explicitly set, there's a default for that, which then governs the duration. void CfgCtxRouteEvent::Enter(const xmlattr_t& properties) { _subscription = 0; _storeType = StoreType::XTable; _doSchemaGeneration = true; bool addIdentity = true; _ctxEventSource = dynamic_cast(ParentContext); if (!_ctxEventSource) { FATAL("Found in <" + ParentContext->Name() + ">; that can't happen"); return; } CfgCtx::CfgCtxParser parser(this); if (!parser.ParseEvent(properties, CfgCtx::EventType::RouteEvent)) { return; } std::string eventName = parser.GetEventName(); Priority priority = parser.GetPriority(); std::string account = parser.GetAccount(); bool NoPerNDay = parser.IsNoPerNDay(); time_t interval = parser.GetInterval(); if (parser.HasStoreType()) { _storeType = parser.GetStoreType(); _doSchemaGeneration = StoreType::DoSchemaGeneration(_storeType); addIdentity = StoreType::DoAddIdentityColumns(_storeType); } try { // Build target on the stack, move it into the Subscription task auto target = MdsEntityName { eventName, NoPerNDay, Config, account, _storeType }; assert(interval != 0); _subscription = new Subscription( _ctxEventSource->Sink(), std::move(target), priority, MdsTime(interval) ); if (addIdentity) { // When we add custom identity columns per-subscription, sub them in here _subscription->AddStage(new 
Pipe::Identity(Config->GetIdentityVector())); } Config->AddMonikerEventInfo(account, eventName, _storeType, _ctxEventSource->Source(), mdsd::EventType::RouteEvent); } catch (const std::invalid_argument& ex) { ERROR(ex.what()); return; } catch (...) { FATAL("Unknown exception; skipping"); return; } } CfgContext* CfgCtxRouteEvent::Leave() { if (! _subscription) { return ParentContext; } // Non-local/file targets need to have a schema constructed and pushed. The schema for // events from a given external source is fixed, so it only needs to be computed once // and pushed once per Nday period if (_doSchemaGeneration) { _subscription->AddStage(new Pipe::BuildSchema(Config, _subscription->target(), true)); } // Find/make the batch for this task; add a final pipeline stage to write to that batch; // add the subscription to the config. Batch *batch = Config->GetBatch(_subscription->target(), _subscription->Duration()); if (batch) { _subscription->AddStage(new Pipe::BatchWriter(batch, Config->GetIdentityVector(), Config->PartitionCount(), _storeType)); // Config->AddSubscription(_ctxEventSource->Source(), _subscription); Config->AddTask(_subscription); } else { ERROR("Unable to create routing for this event"); } return ParentContext; } subelementmap_t CfgCtxRouteEvent::_subelements = { { "Filter", [](CfgContext* parent) -> CfgContext* { return new CfgCtxFilter(parent); } } }; std::string CfgCtxRouteEvent::_name = "RouteEvent"; ////////////////// CfgCtxFilter subelementmap_t CfgCtxFilter::_subelements; std::string CfgCtxFilter::_name = "Filter"; // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxMdsdEvents.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #ifndef _CFGCTXMDSDEVENTS_HH_ #define _CFGCTXMDSDEVENTS_HH_ #include "CfgContext.hh" #include "CfgCtxError.hh" #include #include "Subscription.hh" class LocalSink; class CfgCtxMdsdEvents : public CfgContext { public: CfgCtxMdsdEvents(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxMdsdEvents() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); } private: static subelementmap_t _subelements; static std::string _name; }; class CfgCtxMdsdEventSource : public CfgContext { public: CfgCtxMdsdEventSource(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxMdsdEventSource() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return (_source.empty())?(CfgCtxError::subelements):(_subelements); } void Enter(const xmlattr_t& properties); const std::string& Source() { return _source; } LocalSink * Sink() { return _sink; } private: static subelementmap_t _subelements; static std::string _name; std::string _source; LocalSink *_sink; }; class CfgCtxRouteEvent : public CfgContext { public: CfgCtxRouteEvent(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxRouteEvent() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); CfgContext* Leave(); private: static subelementmap_t _subelements; static std::string _name; Subscription* _subscription; StoreType::Type _storeType; CfgCtxMdsdEventSource* _ctxEventSource; bool _doSchemaGeneration; }; class CfgCtxFilter : public CfgContext { public: CfgCtxFilter(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxFilter() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { 
return _subelements; } void Enter(const xmlattr_t& properties) { log_entry(properties); } private: static subelementmap_t _subelements; static std::string _name; }; #endif //_CFGCTXMDSDEVENTS_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxMonMgmt.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "CfgCtxMonMgmt.hh" #include "CfgCtxImports.hh" #include "CfgCtxAccounts.hh" #include "CfgCtxManagement.hh" #include "CfgCtxSchemas.hh" #include "CfgCtxEnvelope.hh" #include "CfgCtxSources.hh" #include "CfgCtxEvents.hh" #include "CfgCtxSvcBusAccts.hh" #include "CfgCtxEventAnnotations.hh" #include "MdsdConfig.hh" #include "Trace.hh" subelementmap_t CfgCtxMonMgmt::_subelements = { { "Imports", [](CfgContext* parent) -> CfgContext* { return new CfgCtxImports(parent); } }, { "Accounts", [](CfgContext* parent) -> CfgContext* { return new CfgCtxAccounts(parent); } }, { "Management", [](CfgContext* parent) -> CfgContext* { return new CfgCtxManagement(parent); } }, { "Schemas", [](CfgContext* parent) -> CfgContext* { return new CfgCtxSchemas(parent); } }, { "EnvelopeSchema", [](CfgContext* parent) -> CfgContext* { return new CfgCtxEnvelope(parent); } }, { "Sources", [](CfgContext* parent) -> CfgContext* { return new CfgCtxSources(parent); } }, { "Events", [](CfgContext* parent) -> CfgContext* { return new CfgCtxEvents(parent); } }, { "ServiceBusAccountInfos", [](CfgContext* parent) -> CfgContext* { return new CfgCtxSvcBusAccts(parent); } }, { "EventStreamingAnnotations", [](CfgContext* parent) -> CfgContext* { return new CfgCtxEventAnnotations(parent); } } }; std::string CfgCtxMonMgmt::_name = "MonitoringManagement"; void CfgCtxMonMgmt::Enter(const xmlattr_t& properties) { Trace trace(Trace::ConfigLoad, "CfgCtxMonMgmt::Enter"); if (Config->MonitoringManagementSeen()) { return; } bool versionChecked = false; for (const auto& item : 
properties) { if (item.first == "namespace") { Config->Namespace(item.second); } else if (item.first == "eventVersion") { int ver = std::stoi(item.second); if (ver > 0) { Config->EventVersion(ver); } else { Config->AddMessage(MdsdConfig::error, "eventVersion, when present, must be a positive integer"); } } else if (item.first == "version") { versionChecked = true; if (item.second != "1.0") { Config->AddMessage(MdsdConfig::fatal, "Only config file version 1.0 is supported"); } } else if (item.first == "timestamp") { Config->Timestamp(item.second); } else { Config->AddMessage(MdsdConfig::warning, " ignoring unexpected attribute \"" + item.first + "\""); } } if (!versionChecked) { Config->AddMessage(MdsdConfig::fatal, "Must specify \"version\" attribute"); } Config->MonitoringManagementSeen(true); } CfgContext* CfgCtxMonMgmt::Leave() { Config->ValidateEvents(); return ParentContext; } // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxMonMgmt.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _CFGCTXMONMGMT_HH_ #define _CFGCTXMONMGMT_HH_ #include "CfgContext.hh" class CfgCtxMonMgmt : public CfgContext { public: CfgCtxMonMgmt(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxMonMgmt() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); CfgContext* Leave(); private: static subelementmap_t _subelements; static std::string _name; }; #endif //_CFGCTXMONMGMT_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxOMI.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "CfgCtxOMI.hh"
#include "CfgCtxError.hh"
#include "OmiTask.hh"
#include "MdsdConfig.hh"
#include "Utility.hh"
#include "StoreType.hh"
#include "PipeStages.hh"
#include "Trace.hh"
#include "EventType.hh"
#include <cstdlib>   // added: strtoul for the sampleRateInSeconds fix below
#include             // NOTE(review): two include targets lost in extraction
#include

////////////////// CfgCtxOMI

subelementmap_t CfgCtxOMI::_subelements = {
    { "OMIQuery", [](CfgContext* parent) -> CfgContext* { return new CfgCtxOMIQuery(parent); } }
};

std::string CfgCtxOMI::_name = "OMI";

////////////////// CfgCtxOMIQuery

// Parse an <OMIQuery> element and construct the OmiTask that will run the
// CQL query against the named OMI namespace at the requested sample rate.
void CfgCtxOMIQuery::Enter(const xmlattr_t& properties)
{
    Trace trace(Trace::ConfigLoad, "CfgCtxOMIQuery::Enter");
    std::string eventName, account, omiNamespace, cqlQuery;
    Priority priority;
    time_t sampleRate = 0;       // 0 means "use the task's default"
    bool NoPerNDay = false;
    _task = nullptr;
    _isOK = true;
    _storeType = StoreType::XTable;
    _doSchemaGeneration = true;
    for (const auto& item : properties) {
        if (item.first == "eventName") {
            if (MdsdUtil::NotValidName(item.second)) {
                ERROR("Invalid eventName attribute");
            } else {
                eventName = item.second;
            }
        } else if (item.first == "priority") {
            if (! priority.Set(item.second)) {
                WARNING("Ignoring unknown priority \"" + item.second + "\"");
            }
        } else if (item.first == "account") {
            if (MdsdUtil::NotValidName(item.second)) {
                ERROR("Invalid account attribute");
            } else {
                account = item.second;
            }
        } else if (item.first == "dontUsePerNDayTable") {
            NoPerNDay = MdsdUtil::to_bool(item.second);
        } else if (item.first == "omiNamespace") {
            omiNamespace = item.second;
        } else if (item.first == "cqlQuery") {
            cqlQuery = MdsdUtil::UnquoteXmlAttribute(item.second);
        } else if (item.first == "sampleRateInSeconds") {
            // BUG FIX: the original called std::stoul() here, outside the try
            // block below; a non-numeric attribute threw std::invalid_argument
            // out of the parser and crashed config loading. Parse defensively
            // and route any failure to the existing ERROR diagnostic.
            char* end = nullptr;
            unsigned long requestedRate = strtoul(item.second.c_str(), &end, 10);
            bool parsedWholeString = (end != item.second.c_str() && *end == '\0');
            if (!parsedWholeString || requestedRate == 0) {
                ERROR("Invalid sampleRateInSeconds attribute - using default");
            } else {
                sampleRate = (time_t)requestedRate;
            }
        } else if (item.first == "storeType") {
            _storeType = StoreType::from_string(item.second);
            _doSchemaGeneration = StoreType::DoSchemaGeneration(_storeType);
        } else {
            WARNING("Ignoring unexpected attribute " + item.first);
        }
    }
    try {
        // Build target on the stack, move it into the OmiTask
        auto target = MdsEntityName { eventName, NoPerNDay, Config, account, _storeType };
        _task = new OmiTask(Config, std::move(target), priority, omiNamespace, cqlQuery, sampleRate);
        // Centrally-stored events implicitly have Identity columns added to them as
        // defined in the element. Add them first thing so they're available
        // to subsequent stages (if any).
        if (_storeType != StoreType::Local) {
            _task->AddStage(new Pipe::Identity(Config->GetIdentityVector()));
        }
        Config->AddMonikerEventInfo(account, eventName, _storeType, "", mdsd::EventType::OMIQuery);
    }
    catch (const std::invalid_argument& ex) {
        ERROR(ex.what());
        _isOK = false;
        return;
    }
    catch (...) {
        FATAL("Unknown exception; skipping");
        _isOK = false;
        return;
    }
}

CfgContext* CfgCtxOMIQuery::Leave()
{
    Trace trace(Trace::ConfigLoad, "CfgCtxOMIQuery::Leave");
    if(_task) {
        // If not local/file, add a stage to push metadata into MDS. OMI queries should produce results with
        // the same schema each time.
Doing an doesn't change that. if (_doSchemaGeneration) { _task->AddStage(new Pipe::BuildSchema(Config, _task->Target(), true)); } // Find/make the batch for this task; add a final pipeline stage to write to that batch; // add the task to the set of tasks in this config. Batch *batch = Config->GetBatch(_task->Target(), _task->FlushInterval()); if (batch) { _task->AddStage(new Pipe::BatchWriter(batch, Config->GetIdentityVector(), Config->PartitionCount(), _storeType)); Config->AddOmiTask(_task); } else { ERROR("Configuration error(s) detected; dropping this OMIQuery."); delete _task; } } return ParentContext; } const subelementmap_t& CfgCtxOMIQuery::GetSubelementMap() const { if (_isOK) { return _subelements; } else { return CfgCtxError::subelements; } } subelementmap_t CfgCtxOMIQuery::_subelements { { "Unpivot", [](CfgContext* parent) -> CfgContext* { return new CfgCtxUnpivot(parent); } } }; std::string CfgCtxOMIQuery::_name = "OMIQuery"; ////////////////// CfgCtxUnpivot void CfgCtxUnpivot::Enter(const xmlattr_t& properties) { _query = dynamic_cast(ParentContext); if (!_query) { ERROR(" is not a valid subelement of <" + ParentContext->Name() + ">"); _isOK = false; return; } // Bail if parent didn't parse right or didn't build an OmiTask instance if (! 
(_query->isOK() && _query->GetTask())) { _isOK = false; return; } for (const auto &iter : properties) { if (iter.first == "columnValue") { _valueAttrName = iter.second; } else if (iter.first == "columnName") { _nameAttrName = iter.second; } else if (iter.first == "columns") { _unpivotColumns = iter.second; } else { WARNING("Ignoring unexpected attribute " + iter.first); } } if (_valueAttrName.empty() || _nameAttrName.empty() || _unpivotColumns.empty()) { ERROR("Missing one or more required attributes (columnValue, columnName, columns)"); _isOK = false; return; } } CfgContext* CfgCtxUnpivot::Leave() { if (_isOK) { auto unpivoter = new Pipe::Unpivot(_valueAttrName, _nameAttrName, _unpivotColumns, std::move(_transforms)); _query->GetTask()->AddStage(unpivoter); } return ParentContext; } void CfgCtxUnpivot::addTransform(const std::string& from, const std::string& to, double scale) { _transforms.emplace(std::piecewise_construct, std::forward_as_tuple(from), std::forward_as_tuple(to, scale)); } subelementmap_t CfgCtxUnpivot::_subelements { { "MapName", [](CfgContext* parent) -> CfgContext* { return new CfgCtxMapName(parent); } } }; std::string CfgCtxUnpivot::_name = "Unpivot"; ////////////////// CfgCtxMapName void CfgCtxMapName::Enter(const xmlattr_t& properties) { _unpivot = dynamic_cast(ParentContext); if (!_unpivot) { ERROR(" is not a valid subelement of <" + ParentContext->Name() + ">"); _isOK = false; return; } for (const auto &iter : properties) { if (iter.first == "name") { _from = iter.second; } else if (iter.first == "scaleUp") { _scale *= std::stod(iter.second); } else if (iter.first == "scaleDown") { _scale /= std::stod(iter.second); } else { WARNING("Ignoring unexpected attribute " + iter.first); } } if (_from.empty()) { ERROR("Missing required \"from\" attribute"); _isOK = false; return; } } // Process XML body; accumulate it as the value of the _to instance var void CfgCtxMapName::HandleBody(const std::string& body) { if (_isOK) { _to += body; } } // Now 
that we have the target name for the translation, let's save it. CfgContext* CfgCtxMapName::Leave() { if (_isOK) { if (_to.empty()) { _to = _from; } _unpivot->addTransform(_from, _to, _scale); } else { ERROR("Error(s) detected; ignoring this element"); } return ParentContext; } subelementmap_t CfgCtxMapName::_subelements; std::string CfgCtxMapName::_name = "MapName"; // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxOMI.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _CFGCTXOMI_HH_ #define _CFGCTXOMI_HH_ #include "CfgContext.hh" #include "StoreType.hh" #include #include "PipeStages.hh" class OmiTask; class CfgCtxOMI : public CfgContext { public: CfgCtxOMI(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxOMI() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); } private: static subelementmap_t _subelements; static std::string _name; }; class CfgCtxOMIQuery : public CfgContext { public: CfgCtxOMIQuery(CfgContext* config) : CfgContext(config) {} virtual ~CfgCtxOMIQuery() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const; void Enter(const xmlattr_t& properties); CfgContext* Leave(); OmiTask * GetTask() const { return _task; } bool isOK() const { return _isOK; } private: static subelementmap_t _subelements; static std::string _name; OmiTask *_task; bool _isOK; StoreType::Type _storeType; bool _doSchemaGeneration; }; class CfgCtxUnpivot : public CfgContext { public: CfgCtxUnpivot(CfgContext* config) : CfgContext(config), _isOK(true) {} virtual ~CfgCtxUnpivot() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& 
GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); CfgContext* Leave(); void addTransform(const std::string& from, const std::string& to, double scale = 1.0 ); private: static subelementmap_t _subelements; static std::string _name; CfgCtxOMIQuery* _query; bool _isOK; std::string _valueAttrName; std::string _nameAttrName; std::string _unpivotColumns; std::unordered_map _transforms; }; class CfgCtxMapName : public CfgContext { public: CfgCtxMapName(CfgContext* config) : CfgContext(config), _isOK(true), _scale(1.0) {} virtual ~CfgCtxMapName() { } virtual const std::string& Name() const { return _name; } virtual const subelementmap_t& GetSubelementMap() const { return _subelements; } void Enter(const xmlattr_t& properties); void HandleBody(const std::string& body); CfgContext* Leave(); private: static subelementmap_t _subelements; static std::string _name; bool _isOK; CfgCtxUnpivot* _unpivot; std::string _from; std::string _to; double _scale; }; #endif //_CFGCTXOMI_HH_ // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxParser.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "CfgCtxParser.hh"
#include "Utility.hh"
#include "MdsTime.hh"
#include <cstdlib>   // added: strtol for the ParseId fix
#include <cerrno>    // added: ERANGE detection in ParseId
#include <climits>   // added: INT_MAX

using namespace CfgCtx;

// NOTE(review): the template arguments of these std::map declarations were
// lost in extraction; restored as <std::string, typeparser_t> to match the
// lambda signatures and the s_*Parsers usage in ParseEvent/ParseEtwProvider.

// Attribute parsers shared by <RouteEvent>-style events.
std::map<std::string, CfgCtxParser::typeparser_t> CfgCtxParser::s_evtParsers = {
    { "account",
      [] (CfgCtxParser* p, const xmlattr_iter_t & iter) -> bool
      { return p->ParseName(iter->first, iter->second, p->m_account); } },
    { "dontUsePerNDayTable",
      [] (CfgCtxParser* p, const xmlattr_iter_t & iter) -> bool
      { p->m_isNoPerNDay = MdsdUtil::to_bool(iter->second); return true; } },
    { "duration",
      [] (CfgCtxParser* p, const xmlattr_iter_t & iter) -> bool
      { return p->ParseDuration(iter->first, iter->second); } },
    { "eventName",
      [] (CfgCtxParser* p, const xmlattr_iter_t & iter) -> bool
      { return p->ParseName(iter->first, iter->second, p->m_eventName); } },
    { "priority",
      [] (CfgCtxParser* p, const xmlattr_iter_t & iter) -> bool
      { return p->ParsePriority(iter->first, iter->second); } },
    { "storeType",
      [] (CfgCtxParser* p, const xmlattr_iter_t & iter) -> bool
      { return p->ParseStoreType(iter->first, iter->second); } }
};

// ETW events accept everything a RouteEvent does, plus "id".
std::map<std::string, CfgCtxParser::typeparser_t> CfgCtxParser::s_etwEvtParsers = BuildEtwParsersTable();

std::map<std::string, CfgCtxParser::typeparser_t> CfgCtxParser::BuildEtwParsersTable()
{
    std::map<std::string, CfgCtxParser::typeparser_t> tmp = s_evtParsers;
    tmp["id"] = [] (CfgCtxParser* p, const xmlattr_iter_t & iter) -> bool
                { return p->ParseId(iter->first, iter->second); };
    return tmp;
}

// Attribute parsers for <EtwProvider>.
std::map<std::string, CfgCtxParser::typeparser_t> CfgCtxParser::s_etwProviderParsers = {
    { "format",
      [] (CfgCtxParser* p, const xmlattr_iter_t & iter) -> bool
      { return p->ParseName(iter->first, iter->second, p->m_format); } },
    { "guid",
      [] (CfgCtxParser* p, const xmlattr_iter_t & iter) -> bool
      { return p->ParseName(iter->first, iter->second, p->m_guid); } },
    { "priority",
      [] (CfgCtxParser* p, const xmlattr_iter_t & iter) -> bool
      { return p->ParsePriority(iter->first, iter->second); } },
    { "storeType",
      [] (CfgCtxParser* p, const xmlattr_iter_t & iter) -> bool
      { return p->ParseStoreType(iter->first, iter->second); } }
};

// Parse event-element attributes per the table selected by eventType, then
// validate required attributes (eventName always; id for ETW events).
// Returns false on any error.
bool CfgCtxParser::ParseEvent(
    const xmlattr_t& properties,
    EventType eventType
    )
{
    if (!m_context) {
        return false;
    }
    bool resultOK = true;
    auto & parsersTable = (EventType::RouteEvent == eventType) ? s_evtParsers : s_etwEvtParsers;
    for (xmlattr_iter_t iter = properties.begin(); iter != properties.end(); ++iter) {
        auto parserIter = parsersTable.find(iter->first);
        if (parserIter != parsersTable.end()) {
            // NOTE: && short-circuits, so once one attribute fails, later
            // recognized attributes are not parsed.
            resultOK = resultOK && parserIter->second(this, iter);
        } else {
            LogUnexpectedAttrNameWarn(iter->first);
        }
    }
    // validate required attributes
    if (m_eventName.empty()) {
        LogRequiredAttrError("eventName");
        resultOK = false;
    }
    if (EventType::EtwEvent == eventType && m_eventId < 0) {
        LogRequiredAttrError("id");
        resultOK = false;
    }
    return resultOK;
}

// Parse <EtwProvider> attributes; "guid" is required, and "format" (when
// present) must be "EventSource". Returns false on any error.
bool CfgCtxParser::ParseEtwProvider(
    const xmlattr_t& properties
    )
{
    if (!m_context) {
        return false;
    }
    bool resultOK = true;
    const char* supportedFormat = "EventSource"; // only this is supported for now
    for (xmlattr_iter_t iter = properties.begin(); iter != properties.end(); ++iter) {
        auto parserIter = s_etwProviderParsers.find(iter->first);
        if (parserIter != s_etwProviderParsers.end()) {
            resultOK = resultOK && parserIter->second(this, iter);
        } else {
            LogUnexpectedAttrNameWarn(iter->first);
        }
    }
    if (m_guid.empty()) {
        LogRequiredAttrError("guid");
        resultOK = false;
    }
    if (!m_format.empty() && supportedFormat != m_format) {
        LogInvalidValueError("format", m_format);
        resultOK = false;
    }
    return resultOK;
}

// Copy attrValue into result; reject (and clear) invalid names.
bool CfgCtxParser::ParseName(
    const std::string & attrName,
    const std::string & attrValue,
    std::string & result)
{
    result = attrValue;
    if (MdsdUtil::NotValidName(result)) {
        result.clear();
        LogInvalidValueError(attrName, attrValue);
        return false;
    }
    return true;
}

bool CfgCtxParser::ParseStoreType(const std::string & attrName, const std::string & attrValue)
{
    bool resultOK = true;
    m_storeType = StoreType::from_string(attrValue);
    if (StoreType::None == m_storeType) {
        LogInvalidValueError(attrName, attrValue);
        resultOK = false;
    } else {
        m_hasStoreType = true;
    }
    return resultOK;
}

// An unknown priority is only a warning (never fails the parse); a valid
// priority also supplies the default interval when no duration was given.
bool CfgCtxParser::ParsePriority(const std::string & attrName, const std::string & attrValue)
{
    m_hasPriority = true;
    if (!m_priority.Set(attrValue)) {
        LogUnknownAttrValueWarn(attrName, attrValue);
        m_hasPriority = false;
    } else if (0 == m_interval) {
        m_interval = m_priority.Duration();
    }
    return true;
}

// Parse an ISO8601 duration; clamp to the 10-second minimum.
bool CfgCtxParser::ParseDuration(const std::string & attrName, const std::string & attrValue)
{
    m_interval = MdsTime::FromIS8601Duration(attrValue).to_time_t();
    if (0 == m_interval) {
        LogInvalidValueError(attrName, attrValue);
        return false;
    } else if (10 > m_interval) {
        m_context->WARNING("Minimum supported duration is ten (10) seconds; using minimum");
        m_interval = 10;
    }
    return true;
}

// Parse the non-negative event id.
bool CfgCtxParser::ParseId(const std::string & attrName, const std::string & attrValue)
{
    // BUG FIX: the original used atoi() and tested 'tmp > INT_MAX', which is
    // always false for an int, and atoi() returns 0 for non-numeric text, so
    // id="abc" was silently accepted as id 0. Use strtol and require the whole
    // attribute to be a number in [0, INT_MAX].
    char* end = nullptr;
    errno = 0;
    long tmp = strtol(attrValue.c_str(), &end, 10);
    if (end == attrValue.c_str() || *end != '\0' || ERANGE == errno
            || tmp < 0 || tmp > INT_MAX) {
        LogInvalidValueError(attrName, attrValue);
        return false;
    }
    m_eventId = static_cast<int>(tmp);
    return true;
}
================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxParser.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _CFGCTXPARSER_HH_
#define _CFGCTXPARSER_HH_

#include "ConfigParser.hh"
#include             // NOTE(review): two include targets lost in extraction
#include
#include "StoreType.hh"
#include "Priority.hh"
#include "CfgContext.hh"

extern "C" {
#include             // NOTE(review): include target lost in extraction
}

namespace CfgCtx {

/// mdsd Event types
enum class EventType {
    RouteEvent,
    EtwEvent
};

/// A utility class to parse mdsd configuration XML.
/// It implements parsing routines for common XML properties like
/// priority, storeType, etc.
class CfgCtxParser
{
    // NOTE(review): std::function template argument lost in extraction;
    // restored from the lambda signatures in CfgCtxParser.cc.
    using typeparser_t = std::function<bool(CfgCtxParser*, const xmlattr_iter_t&)>;

public:
    ///
    /// Create a new parser instance.
    ///
    /// Context where the parser is called.
    CfgCtxParser(CfgContext * context) : m_context(context)
    {
    }

    ~CfgCtxParser() {}

    ///
    /// Parse properties of an EventType. After parsing, the results will
    /// be available from GetXXX() functions.
    ///
    /// properties to parse
    /// EventType
    /// Return true if no error, false if any error.
bool ParseEvent(const xmlattr_t& properties, EventType eventType); /// /// Parse XML configuration. /// /// properties to parse /// Return true if no error, false if any error. bool ParseEtwProvider(const xmlattr_t& properties); std::string GetAccount() const { return m_account; } bool IsNoPerNDay() const { return m_isNoPerNDay; } time_t GetInterval() const { return (0 == m_interval)? m_priority.Duration() : m_interval; } std::string GetEventName() const { return m_eventName; } std::string GetFormat() const { return m_format; } std::string GetGuid() const { return m_guid; } int GetEventId() const { return m_eventId; } Priority GetPriority() const { return m_priority; } bool HasPriority() const { return m_hasPriority; } StoreType::Type GetStoreType() const { return m_storeType; } bool HasStoreType() const { return m_hasStoreType; } private: bool ParseName(const std::string & attrName, const std::string & attrValue, std::string & result); bool ParsePriority(const std::string & attrName, const std::string & attrValue); bool ParseStoreType(const std::string & attrName, const std::string & attrValue); bool ParseDuration(const std::string & attrName, const std::string & attrValue); bool ParseId(const std::string & attrName, const std::string & attrValue); void LogInvalidValueError(const std::string & attrName, const std::string & attrValue) { m_context->ERROR("<" + m_context->Name() + "> attribute '" + attrName + "' has invalid value '" + attrValue + "'."); } void LogUnknownAttrValueWarn(const std::string & attrName, const std::string & attrValue) { m_context->WARNING("<" + m_context->Name() + ">: ignoring unknown '" + attrName + "'' value '" + attrValue + "'"); } void LogRequiredAttrError(const std::string & attrName) { m_context->ERROR("<" + m_context->Name() + "> requires attribute '" + attrName + "'"); } void LogUnexpectedAttrNameWarn(const std::string & attrName) { m_context->WARNING("<" + m_context->Name() + "> ignoring unexpected attribute '" + attrName + "'."); } 
static std::map BuildEtwParsersTable(); private: CfgContext * const m_context; std::string m_account; bool m_isNoPerNDay = false; time_t m_interval = 0; std::string m_eventName; std::string m_format; std::string m_guid; int m_eventId = -1; Priority m_priority; bool m_hasPriority = false; StoreType::Type m_storeType = StoreType::None; bool m_hasStoreType = false; static std::map s_evtParsers; static std::map s_etwEvtParsers; static std::map s_etwProviderParsers; }; } // namespace CfgCtx #endif // _CFGCTXPARSER_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxRoot.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "CfgCtxRoot.hh" #include "CfgCtxMonMgmt.hh" subelementmap_t CfgCtxRoot::_subelements = { { "MonitoringManagement", [](CfgContext* parent) -> CfgContext* { return new CfgCtxMonMgmt(parent); } } }; std::string CfgCtxRoot::_name = "(Document Root)"; ================================================ FILE: Diagnostic/mdsd/mdsd/CfgCtxRoot.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _CFGCTXROOT_HH_ #define _CFGCTXROOT_HH_ #include "CfgContext.hh" #include "MdsdConfig.hh" #include #include class CfgCtxRoot : public CfgContext { public: /// /// The root context for a document. Tracks no information from prior context. Is neither entered /// nor left, in the sense of document parsing. 
    CfgCtxRoot(MdsdConfig* config) : CfgContext(config) {}
    virtual ~CfgCtxRoot() {};
    virtual const std::string& Name() const { return _name; }
    virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }
    void Enter(const xmlattr_t& properties) { log_entry(properties); }
private:
    static subelementmap_t _subelements;
    static std::string _name;
};

#endif //_CFGCTXROOT_HH_
================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxSchemas.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "CfgCtxSchemas.hh"
#include "MdsdConfig.hh"
#include "Engine.hh"
#include "TableSchema.hh"
// NOTE(review): include target lost in extraction (std::ostringstream is used
// below, so presumably <sstream>)
#include

subelementmap_t CfgCtxSchemas::_subelements = {
    { "Schema", [](CfgContext* parent) -> CfgContext* { return new CfgCtxSchema(parent); } }
};

std::string CfgCtxSchemas::_name = "Schemas";

///////////////////////

subelementmap_t CfgCtxSchema::_subelements = {
    { "Column", [](CfgContext* parent) -> CfgContext* { return new CfgCtxColumn(parent); } }
};

std::string CfgCtxSchema::_name = "Schema";

// When no valid schema exists (Enter() failed or a dupe column destroyed it),
// route children to the error context.
const subelementmap_t& CfgCtxSchema::GetSubelementMap() const
{
    if (_schema) {
        return _subelements;
    } else {
        return CfgCtxError::subelements;
    }
}

// Parse the <Schema> element; the single required "name" attribute creates
// the TableSchema that columns are subsequently added to.
void CfgCtxSchema::Enter(const xmlattr_t& properties)
{
    _schema = 0;
    for (const auto& item : properties) {
        if (item.first == "name") {
            if (_schema == 0) {
                _schema = new TableSchema(item.second);
            } else {
                // NOTE(review): element name inside this message appears lost
                // in extraction.
                Config->AddMessage(MdsdConfig::error, "\"name\" can appear in only once");
            }
        } else {
            Config->AddMessage(MdsdConfig::warning, "Ignoring unexpected attribute \"" + item.first + "\"");
        }
    }
    if (_schema == 0) {
        // NOTE(review): element name inside this message appears lost in extraction.
        Config->AddMessage(MdsdConfig::fatal, " requires \"name\" attribute");
    }
}

// Called from CfgCtxColumn::Enter()
void CfgCtxSchema::AddColumn(const std::string& n, const std::string& srctype, const std::string& mdstype)
{
    // If we have no valid schema, or we've seen the column before, skip it.
    if (!_schema) return;
    auto result = _schema->AddColumn(n, srctype, mdstype);
    // NOTE(review): if TableSchema::Ok is the zero enumerator this early return
    // already covers success and the "case Ok" below is defensive — confirm
    // against TableSchema.hh.
    if (!result) {
        return;
    }
    std::ostringstream msg;
    switch (result) {
    case TableSchema::Ok:
        return; // !!! Return, not break
    case TableSchema::NoConverter:
        msg << "Can't convert " << srctype << " to " << mdstype << " - ignoring column " << n;
        msg << ". Known converters: " << Engine::ListConverters();
        break;
    case TableSchema::DupeColumn:
        // A duplicate column invalidates the whole schema, not just the column.
        msg << "Column " << n << " already added to Schema " << _schema->Name();
        delete _schema;
        _schema = 0; // Throw away the schema, we're broken
        break;
    case TableSchema::BadSrcType:
        msg << "Unknown source type " << srctype << " - ignoring column " << n;
        msg << ". Known converters: " << Engine::ListConverters();
        break;
    case TableSchema::BadMdsType:
        msg << "Unknown MDS type " << mdstype << " - ignoring column " << n;
        msg << ". Known converters: " << Engine::ListConverters();
        break;
    }
    Config->AddMessage(MdsdConfig::error, msg.str());
}

// Register the completed schema; a destroyed schema is reported and dropped.
CfgContext* CfgCtxSchema::Leave()
{
    if (_schema) {
        Config->AddSchema(_schema); // All the way through without a fatal error - add it to the config
    } else {
        Config->AddMessage(MdsdConfig::error, "Schema dropped from active configuration due to errors");
    }
    return ParentContext;
}

///////////////////////

subelementmap_t CfgCtxColumn::_subelements;
std::string CfgCtxColumn::_name = "Column";

// Parse one <Column>: all of name/type/mdstype are required; the column is
// handed to the parent <Schema> context.
void CfgCtxColumn::Enter(const xmlattr_t& properties)
{
    std::string colname;
    std::string srctype, mdstype;
    for (const auto& item : properties) {
        if (item.first == "name") {
            colname = item.second;
        } else if (item.first == "type") {
            srctype = item.second;
        } else if (item.first == "mdstype") {
            mdstype = item.second;
        } else {
            Config->AddMessage(MdsdConfig::warning, "Ignoring unexpected attribute \"" + item.first + "\"");
        }
    }
    if (colname.empty() || srctype.empty() || mdstype.empty()) {
        Config->AddMessage(MdsdConfig::error, "Missing required attributes (name, type, mdstype)");
    } else {
        // NOTE(review): dynamic_cast template argument lost in extraction;
        // presumably CfgCtxSchema*.
        CfgCtxSchema* ctxschema = dynamic_cast(ParentContext);
        if (ctxschema) {
            ctxschema->AddColumn(colname, srctype, mdstype);
        } else {
            // NOTE(review): element name inside this message appears lost in extraction.
            Config->AddMessage(MdsdConfig::fatal, "Found in <" + ParentContext->Name() + ">; that can't happen");
        }
    }
}

// vim: se sw=8 :
================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxSchemas.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _CFGCTXSCHEMAS_HH_
#define _CFGCTXSCHEMAS_HH_

#include "CfgContext.hh"
#include "CfgCtxError.hh"
// NOTE(review): include target lost in extraction (std::set is used below,
// so presumably <set>)
#include

class TableSchema;

// Context for the <Schemas> container element.
class CfgCtxSchemas : public CfgContext
{
public:
    CfgCtxSchemas(CfgContext* config) : CfgContext(config) {}
    virtual ~CfgCtxSchemas() { }
    virtual const std::string& Name() const { return _name; }
    virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }
    void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); }
private:
    static subelementmap_t _subelements;
    static std::string _name;
};

// Context for one <Schema>; owns the TableSchema until Leave() hands it to Config.
class CfgCtxSchema : public CfgContext
{
public:
    CfgCtxSchema(CfgContext* config) : CfgContext(config) {}
    virtual ~CfgCtxSchema() { }
    virtual const std::string& Name() const { return _name; }
    virtual const subelementmap_t& GetSubelementMap() const;
    void Enter(const xmlattr_t& properties);
    CfgContext* Leave();
    void AddColumn(const std::string& n, const std::string& srctype, const std::string& mdstype);
private:
    TableSchema* _schema;       // owned until AddSchema(); null after fatal errors
    // NOTE(review): template argument lost in extraction (presumably
    // std::set<std::string>); appears unused in the visible .cc code.
    std::set _columnNames;
    static subelementmap_t _subelements;
    static std::string _name;
};

// Context for one <Column> inside a <Schema>.
class CfgCtxColumn : public CfgContext
{
public:
    CfgCtxColumn(CfgContext* config) : CfgContext(config) {}
    virtual ~CfgCtxColumn() { }
    virtual const std::string& Name() const { return _name; }
    virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }
    void Enter(const xmlattr_t& properties);
private:
    static subelementmap_t _subelements;
    static std::string _name;
};

#endif //_CFGCTXSCHEMAS_HH_
================================================
FILE:
Diagnostic/mdsd/mdsd/CfgCtxSources.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "CfgCtxSources.hh"
#include "MdsdConfig.hh"
#include "LocalSink.hh"
#include "EventType.hh"
#include "Utility.hh"

subelementmap_t CfgCtxSources::_subelements = {
    { "Source", [](CfgContext* parent) -> CfgContext* { return new CfgCtxSource(parent); } }
};

std::string CfgCtxSources::_name = "Sources";

////////////

// Parse one <Source>: requires "name" and exactly one of a non-empty "schema"
// or dynamic_schema="true"; creates/reuses the LocalSink of that name and
// registers the source with the config.
void CfgCtxSource::Enter(const xmlattr_t& properties)
{
    std::string name, schema, dynamic_schema;
    for (const auto& item : properties) {
        if (item.first == "name") {
            name = item.second;
        } else if (item.first == "schema") {
            schema = item.second;
        } else if (item.first == "dynamic_schema") {
            dynamic_schema = item.second;
        } else {
            Config->AddMessage(MdsdConfig::warning, "Ignoring unexpected attribute \"" + item.first + "\"");
        }
    }
    auto isOK = true;
    if (name.empty()) {
        // NOTE(review): element name inside this message appears lost in extraction.
        Config->AddMessage(MdsdConfig::fatal, " requires a \"name\" attribute");
        isOK = false;
    }
    auto isDynamicSchema = MdsdUtil::to_bool(dynamic_schema);
    // Exactly one of schema / dynamic_schema="true" must be supplied.
    if ((!schema.empty() && isDynamicSchema) || (schema.empty() && (dynamic_schema.empty() || !isDynamicSchema))) {
        // NOTE(review): this failure only queues a fatal message and does NOT
        // set isOK, so processing continues below — presumably the fatal
        // message aborts the config later; confirm against MdsdConfig.
        Config->AddMessage(MdsdConfig::fatal, " requires either a valid \"schema\" attribute or that the \"dynamic_schema\" attribute be set to \"true\", but not both.");
    }
    if (!isOK) {
        return;
    }
    // Reuse an existing sink of this name if one was already created.
    // NOTE(review): the new LocalSink is never explicitly stored here —
    // presumably the LocalSink constructor registers itself for Lookup(); verify.
    auto sink = LocalSink::Lookup(name);
    if (!sink) {
        sink = new LocalSink(name);
    }
    if (!isDynamicSchema) {
        Config->AddSource(name, schema);
        sink->AllocateSchemaId();
    } else {
        Config->AddDynamicSchemaSource(name);
    }
    Config->AddMonikerEventInfo("", "", StoreType::Local, name, mdsd::EventType::None);
}

subelementmap_t CfgCtxSource::_subelements;
std::string CfgCtxSource::_name = "Source";
================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxSources.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _CFGCTXSOURCES_HH_
#define _CFGCTXSOURCES_HH_

#include "CfgContext.hh"

// Context for the <Sources> container element.
class CfgCtxSources : public CfgContext
{
public:
    CfgCtxSources(CfgContext* config) : CfgContext(config) {}
    virtual ~CfgCtxSources() { }
    virtual const std::string& Name() const { return _name; }
    virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }
    void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); }
private:
    static subelementmap_t _subelements;
    static std::string _name;
};

// Context for one <Source> element.
class CfgCtxSource : public CfgContext
{
public:
    CfgCtxSource(CfgContext* config) : CfgContext(config) {}
    virtual ~CfgCtxSource() { }
    virtual const std::string& Name() const { return _name; }
    virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }
    void Enter(const xmlattr_t& properties);
private:
    static subelementmap_t _subelements;
    static std::string _name;
};

#endif //_CFGCTXSOURCES_HH_
================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxSvcBusAccts.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "CfgCtxSvcBusAccts.hh"
#include "MdsdConfig.hh"
#include "EventPubCfg.hh"
#include "Trace.hh"
#include "Utility.hh"
#include "cryptutil.hh"

///////// CfgCtxSvcBusAccts

// <ServiceBusAccountInfos> may contain only <ServiceBusAccountInfo> children.
subelementmap_t CfgCtxSvcBusAccts::_subelements = {
    { "ServiceBusAccountInfo", [](CfgContext* parent) -> CfgContext* { return new CfgCtxSvcBusAcct(parent); } }
};

std::string CfgCtxSvcBusAccts::_name = "ServiceBusAccountInfos";

///////// CfgCtxSvcBusAcct

// <ServiceBusAccountInfo> may contain only <EventPublisher> children.
subelementmap_t CfgCtxSvcBusAcct::_subelements = {
    { "EventPublisher", [](CfgContext* parent) -> CfgContext* { return new CfgCtxEventPublisher(parent); } }
};

std::string CfgCtxSvcBusAcct::_name = "ServiceBusAccountInfo";

// Capture the required "name" attribute (the account moniker) of a
// <ServiceBusAccountInfo> element; any other attribute draws a warning,
// and a missing moniker is fatal.
void
CfgCtxSvcBusAcct::Enter(const xmlattr_t& properties)
{
    Trace trace(Trace::ConfigLoad, "CfgCtxSvcBusAcct::Enter");

    const std::string & attrMoniker = "name";
    for (const auto& item : properties) {
        if (attrMoniker == item.first) {
            parse_singleton_attribute(item.first, item.second, attrMoniker, _moniker);
        } else {
            warn_if_attribute_unexpected(item.first);
        }
    }
    fatal_if_no_attributes(attrMoniker, _moniker);
}

///////// CfgCtxEventPublisher

subelementmap_t CfgCtxEventPublisher::_subelements;
std::string CfgCtxEventPublisher::_name = "EventPublisher";

// Register a service-bus connection string for the enclosing account.
//
// Attributes:
//   connectionString (required) SAS connection string, possibly encrypted.
//   decryptKeyPath   optional path to a key used to decrypt connectionString;
//                    when absent the string is used as-is (XML-unquoted).
void
CfgCtxEventPublisher::Enter(const xmlattr_t& properties)
{
    std::string valConnStr;
    std::string valDecryptKeyPath;
    const std::string & attrConnStr = "connectionString";
    const std::string & attrDecryptKeyPath = "decryptKeyPath";

    for (const auto & item : properties) {
        if (attrConnStr == item.first) {
            parse_singleton_attribute(item.first, item.second, attrConnStr, valConnStr);
        } else if (attrDecryptKeyPath == item.first) {
            parse_singleton_attribute(item.first, item.second, attrDecryptKeyPath, valDecryptKeyPath);
        } else {
            warn_if_attribute_unexpected(item.first);
        }
    }
    fatal_if_no_attributes(attrConnStr, valConnStr);

    // The parent element must be a <ServiceBusAccountInfo>; its moniker keys
    // the account entry.
    // NOTE(review): the dynamic_cast template argument (presumably
    // <CfgCtxSvcBusAcct*>) appears stripped by the extraction tool.
    auto sbObj = dynamic_cast(ParentContext);
    if (!sbObj) {
        fatal_if_impossible_subelement();
        return;
    }
    auto sbmoniker = sbObj->GetMoniker();
    try {
        if (valDecryptKeyPath.empty()) {
            // Plain-text connection string: just undo XML attribute quoting.
            auto escapedConnStr = MdsdUtil::UnquoteXmlAttribute(valConnStr);
            Config->GetEventPubCfg()->AddServiceBusAccount(sbmoniker, std::move(escapedConnStr));
        } else {
            // Encrypted connection string: decrypt with the given key file.
            if (!MdsdUtil::IsRegFileExists(valDecryptKeyPath)) {
                ERROR("Cannot find decrypt key path " + valDecryptKeyPath);
            } else {
                auto decryptedSas = cryptutil::DecodeAndDecryptString(valDecryptKeyPath, std::move(valConnStr));
                Config->GetEventPubCfg()->AddServiceBusAccount(sbmoniker, std::move(decryptedSas));
            }
        }
    } catch(const std::exception & ex) {
        ERROR("<" + Name() + "> exception: " + ex.what());
    }
}

================================================
FILE: Diagnostic/mdsd/mdsd/CfgCtxSvcBusAccts.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _CFGCTXSVCBUSACCTS_HH_
#define _CFGCTXSVCBUSACCTS_HH_

#include "CfgContext.hh"

// Parse context for the <ServiceBusAccountInfos> container; attribute-free.
class CfgCtxSvcBusAccts : public CfgContext
{
public:
    CfgCtxSvcBusAccts(CfgContext* config) : CfgContext(config) {}
    virtual ~CfgCtxSvcBusAccts() {}

    virtual const std::string & Name() const { return _name; }
    virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }
    void Enter(const xmlattr_t& properties) { warn_if_attributes(properties); }

private:
    static subelementmap_t _subelements;
    static std::string _name;
};

// Parse context for one <ServiceBusAccountInfo>; holds the account moniker
// read from its "name" attribute for use by nested <EventPublisher> elements.
class CfgCtxSvcBusAcct : public CfgContext
{
public:
    CfgCtxSvcBusAcct(CfgContext* config) : CfgContext(config) {}
    virtual ~CfgCtxSvcBusAcct() {}

    virtual const std::string& Name() const { return _name; }
    virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }
    void Enter(const xmlattr_t& properties);
    std::string GetMoniker() const { return _moniker; }

private:
    static subelementmap_t _subelements;
    static std::string _name;
    std::string _moniker;  // account moniker from the "name" attribute
};

// Parse context for an <EventPublisher> element (leaf; no children).
class CfgCtxEventPublisher : public CfgContext
{
public:
    CfgCtxEventPublisher(CfgContext* config) : CfgContext(config) {}
    virtual ~CfgCtxEventPublisher()
{}

    virtual const std::string& Name() const { return _name; }
    virtual const subelementmap_t& GetSubelementMap() const { return _subelements; }
    void Enter(const xmlattr_t& properties);

private:
    static subelementmap_t _subelements;
    static std::string _name;
};

#endif // _CFGCTXSVCBUSACCTS_HH_

================================================
FILE: Diagnostic/mdsd/mdsd/CfgEventAnnotationType.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _CFGEVENTANNOTATIONTYPE_HH_
#define _CFGEVENTANNOTATIONTYPE_HH_

namespace EventAnnotationType
{
    // Because one event can be multiple types,
    // each type should be a power of 2 (used as a bit mask).
    enum Type
    {
        None = 0,
        EventPublisher = 1 << 0,
        OnBehalf = 1 << 1
    };
};

#endif // _CFGEVENTANNOTATIONTYPE_HH_

================================================
FILE: Diagnostic/mdsd/mdsd/CfgOboDirectConfig.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _CFGOBODIRECTCONFIG_HH_
#define _CFGOBODIRECTCONFIG_HH_

// NOTE(review): the system-include name (presumably <string>) appears
// stripped from the line below by the extraction tool.
#include

// struct to hold OBO direct upload config data
namespace mdsd
{
    struct OboDirectConfig
    {
        // Currently all fields are as is from the XML CDATA config (e.g., "ProviderName,AnsiString").
        // Parse out as desired.
        std::string onBehalfFields;
        std::string containerSuffix;
        std::string primaryPartitionField;
        std::string partitionFields;
        std::string onBehalfReplaceFields;
        std::string excludeFields;
        std::string timePeriods = "PT1H"; // timePeriods is optional and "PT1H" by default if not given.
        std::string priority;
    };
}

#endif // _CFGOBODIRECTCONFIG_HH_

================================================
FILE: Diagnostic/mdsd/mdsd/CmdLineConverter.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "CmdLineConverter.hh"
#include "Utility.hh"
#include "Trace.hh"
// NOTE(review): three system-include names appear stripped from the lines
// below by the extraction tool, as do template arguments throughout this
// file (e.g. std::vector<...>, std::function<...>, static_cast<...>).
#include
#include
#include

// Shell-like tokenizer: split cmdline into argv words, honoring backslash
// escapes, single quotes, and double quotes (with escapes inside them).
// Warnings about malformed input (unterminated quote, trailing escape) go
// to the trace log and to ctxLogOnWarning.
std::vector CmdLineConverter::Tokenize(const std::string& cmdline,
    std::function ctxLogOnWarning)
{
    Trace trace(Trace::Extensions, "CmdLineConverter::Tokenize");

    auto current = cmdline.begin();
    // NOTE(review): pos is incremented each iteration but never read —
    // apparently a leftover position counter; confirm before removing.
    size_t pos = 1;
    std::vector argv;
    std::string element;
    // State machine: outside any word / inside a word / after backslash /
    // inside '...' / inside "..." / after backslash inside "...".
    enum TokenizerState { outside, within, escape, singlequote, doublequote, doubleescape };
    TokenizerState state = outside;

    while (current != cmdline.end()) {
        // Generally, state transitions consume the character that causes the transition.
        // (See bottom of loop.) Any exceptions to this rule are clearly noted (by using "continue").
        switch (state) {
        case outside:
            // Advance past whitespace, else transition to state=within
            switch (*current) {
            case ' ':
            case '\n':
                break;
            default:
                state = within;
                // NOTE: This state transition does NOT consume the character
                continue;
            }
            break;
        case within:
            switch (*current) {
            case '\\':
                // escape character - change to matching state
                state = escape;
                break;
            case '\'':
                // start single quote - change to matching state
                state = singlequote;
                break;
            case '"':
                // start double quote - change to matching state
                state = doublequote;
                break;
            case ' ':   // whitespace terminates the element, which we can push
            case '\n':  // into the vector; change to "outside" state
                argv.emplace_back(std::move(element));
                element.clear();
                state = outside;
                break;
            default:
                element.push_back(*current);
                break;
            }
            break;
        case escape:
            // Only blank, newline, backslash, singlequote, and doublequote can be escaped; if the
            // character isn't one of those, put the backslash into the element along with the
            // shouldn't-have-been-escaped character.
            if (std::string(" \n\\'\"").find_first_of(*current) == std::string::npos) {
                element.push_back('\\');
            }
            element.push_back(*current);
            state = within;
            break;
        case singlequote:
            // Everything is literal inside single quotes.
            if (*current != '\'') {
                element.push_back(*current);
            } else {
                state = within;
            }
            break;
        case doublequote:
            switch (*current) {
            case '"':
                state = within;
                break;
            case '\\':
                state = doubleescape;
                break;
            default:
                element.push_back(*current);
                break;
            }
            break;
        case doubleescape:
            // If it's not a backslash or a doublequote, it can't be escaped, so flow the escape char through
            if (std::string("\\\"").find_first_of(*current) == std::string::npos) {
                element.push_back('\\');
            }
            element.push_back(*current);
            state = doublequote;
            break;
        }
        current++;
        pos++;
    }

    // End of input: flush any partial word and warn about malformed endings.
    std::string warnMsg;
    switch (state) {
    case outside:
        break;
    case within:
        if (element.size()) {
            argv.emplace_back(std::move(element));
        }
        break;
    case singlequote:
    case doublequote:
        // Issue config-file parsing warning about an unterminated quote at the end of a cmdline
        warnMsg = "Unterminated quote at the end of the command line";
        trace.NOTEWARN(warnMsg);
        ctxLogOnWarning(warnMsg);
        // Auto-close it and add it, even it if's an empty string
        argv.emplace_back(std::move(element));
        break;
    case escape:
    case doubleescape:
        // Issue config-file warning about incomplete escape at the end of the cmdline
        warnMsg = "Incomplete escape at the end of the command line";
        trace.NOTEWARN(warnMsg);
        ctxLogOnWarning(warnMsg);
        // Add what we have
        argv.emplace_back(std::move(element));
        break;
    }

    return argv;
}

// Build a NULL-terminated, heap-allocated argv suitable for execvp() from
// the tokenized command line. Each element is malloc'ed (freed in the dtor).
// NOTE(review): the malloc return value is not checked for NULL — confirm
// whether allocation failure is considered fatal elsewhere.
CmdLineConverter::CmdLineConverter(const std::string & cmdline)
{
    Trace trace(Trace::Extensions, "CmdLineConverter::CmdLineConverter");
    try {
        std::vector strarray = Tokenize(cmdline);
        execvp_nargs = strarray.size();
        execvp_args = new char*[execvp_nargs+1];
        size_t i = 0;
        for (const auto& x : strarray) {
            size_t len = x.length();
            execvp_args[i] = static_cast(malloc(len+1));
            strncpy(execvp_args[i], x.c_str(), len);
            execvp_args[i][len] = '\0';   // strncpy with exact len does not NUL-terminate
            i++;
        }
        execvp_args[execvp_nargs] = NULL; // execvp() requires a NULL sentinel
    }
    catch (const std::exception& e) {
        trace.NOTEERR("Failed to parse cmdline: '" + cmdline + "'. Error=" + e.what());
    }
}

// Free each malloc'ed argument, then the new[]'ed pointer array.
CmdLineConverter::~CmdLineConverter()
{
    if (execvp_args) {
        for (size_t i = 0; i < execvp_nargs; i++) {
            free(execvp_args[i]);
            execvp_args[i] = NULL;
        }
        delete [] execvp_args;
        execvp_args = NULL;
    }
}

================================================
FILE: Diagnostic/mdsd/mdsd/CmdLineConverter.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _CMDLINECONVERTER_HH_
#define _CMDLINECONVERTER_HH_

// NOTE(review): system-include names stripped by the extraction tool.
#include
#include
#include
#include "CfgContext.hh"

// Converts a single command-line string into the (argc, argv) form needed
// by execvp(); owns the argv memory for its lifetime.
class CmdLineConverter
{
public:
    CmdLineConverter(const std::string & cmdline);
    virtual ~CmdLineConverter();

    static std::vector Tokenize(const std::string& cmdline,
        std::function ctxLogOnWarning = [](const std::string&){} // Don't do any warning logging by default
        );

    ///
    /// Returns the char* array that can be used for execvp() directly.
    /// The caller shouldn't free the memory from this function.
    /// NOTE: the last item of the array is always NULL.
    ///
    char** argv() const { return execvp_args; }

    ///
    /// Returns the number of items in execvp args. This doesn't include
    /// the last NULL element.
    ///
    size_t argc() const { return execvp_nargs; }

private:
    size_t execvp_nargs = 0;
    char** execvp_args = NULL;
};

#endif // _CMDLINECONVERTER_HH_

================================================
FILE: Diagnostic/mdsd/mdsd/ConfigParser.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "ConfigParser.hh"

ConfigParser::~ConfigParser()
{
}

// SAX callback: on an opening tag, ask the current context to create the
// child context for this element name, descend into it, and let it process
// the element's attributes.
void
ConfigParser::OnStartElement(const std::string& name, const xmlattr_t& properties)
{
    currentContext = currentContext->SubContextFactory(name);
    currentContext->Enter(properties);
}

// SAX callback: on a closing tag, pop back to the parent context and delete
// the finished child context (Leave() returns the parent).
void
ConfigParser::OnEndElement(const std::string&)
{
    CfgContext* tmp = currentContext;
    currentContext = currentContext->Leave();
    delete tmp;
}

================================================
FILE: Diagnostic/mdsd/mdsd/ConfigParser.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _CONFIGPARSER_HH_
#define _CONFIGPARSER_HH_

#include "SaxParserBase.hh"
#include "CfgContext.hh"
#include "MdsdConfig.hh"
// NOTE(review): system-include name stripped by the extraction tool.
#include

// SAX-driven config-file parser: maintains a stack of CfgContext objects
// (via currentContext) mirroring the XML element nesting, and routes parser
// warnings/errors into the MdsdConfig message list.
class ConfigParser : public SaxParserBase
{
public:
    ///
    /// Initialize a parser to handle a config file.
    ///
    /// A CfgContext class whose factory can construct contexts for the expected root element
    /// The MdsdConfig to which this parse should log any warnings or errors
    ConfigParser(CfgContext* Root, MdsdConfig* Config) : currentContext(Root), config(Config) {};
    virtual ~ConfigParser();

private:
    CfgContext* currentContext;   // context for the element being parsed
    MdsdConfig* const config;     // destination for parse diagnostics

protected:
    virtual void OnStartDocument() override {};
    virtual void OnEndDocument() override {};
    virtual void OnComment(const std::string&) override {};
    virtual void OnStartElement(const std::string& name, const xmlattr_t& properties) override;
    virtual void OnCharacters(const std::string& characters) override { currentContext->HandleBody(characters); };
    virtual void OnEndElement(const std::string& name) override;
    virtual void OnWarning(const std::string& text) override { config->AddMessage(MdsdConfig::warning, text); }
    virtual void OnError(const std::string& text) override { config->AddMessage(MdsdConfig::error, text); }
    virtual void OnFatalError(const std::string& text) override { config->AddMessage(MdsdConfig::fatal, text); }
    virtual void
OnCDataBlock(const std::string& text) override { currentContext->HandleCdata(text); }
};

#endif //_CONFIGPARSER_HH_

================================================
FILE: Diagnostic/mdsd/mdsd/Constants.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "Constants.hh"
// NOTE(review): system-include names stripped by the extraction tool.
#include
#include

// Defines both a narrow and a wide (name ## W) constant for each string.
#define DEFINE_STRING(name, value) const std::string name { value }; const std::wstring name ## W { L ## value };

uint64_t Constants::_UniqueId { 0 };

namespace Constants
{
    const std::string TIMESTAMP { "TIMESTAMP" };
    const std::string PreciseTimeStamp { "PreciseTimeStamp" };

    namespace Compression
    {
        const std::string lz4hc { "lz4hc" };
    } // namespace Compression

    namespace EventCategory
    {
        const std::string Counter { "counter" };
        const std::string Trace { "trace" };
    } // namespace EventCategory

    namespace AzurePropertyNames
    {
        DEFINE_STRING(Namespace, "namespace")
        DEFINE_STRING(EventName, "eventname")
        DEFINE_STRING(EventVersion, "eventversion")
        DEFINE_STRING(EventCategory, "eventcategory")
        DEFINE_STRING(BlobVersion, "version")
        DEFINE_STRING(BlobFormat, "format")
        DEFINE_STRING(DataSize, "datasizeinbytes")
        DEFINE_STRING(BlobSize, "blobsizeinbytes")
        DEFINE_STRING(MonAgentVersion, "monagentversion")
        DEFINE_STRING(CompressionType, "compressiontype")
        DEFINE_STRING(MinLevel, "minlevel")
        DEFINE_STRING(AccountMoniker, "accountmoniker")
        DEFINE_STRING(Endpoint, "endpoint")
        DEFINE_STRING(OnbehalfFields, "onbehalffields")
        DEFINE_STRING(OnbehalfServiceId, "onbehalfid")
        DEFINE_STRING(OnbehalfAnnotations, "onbehalfannotations")
    } // namespace AzurePropertyNames

    // Lazily compute a per-boot 64-bit id by packing the first 16 hex digits
    // of /proc/sys/kernel/random/boot_id. Falls back to 1 if the file can't
    // be read or yields zero. Cached in _UniqueId after the first call.
    uint64_t UniqueId()
    {
        static std::string digits { "0123456789ABCDEFabcdef" };
        if (!Constants::_UniqueId) {
            std::ifstream bootid("/proc/sys/kernel/random/boot_id", std::ifstream::in);
            if (bootid.is_open()) {
                uint64_t id = 0;
                int nybbles = 16;
                while (nybbles && bootid.good()) {
                    char c = bootid.get();
                    size_t pos = digits.find(c);
                    if (pos != std::string::npos) {
                        if (pos > 15) {
                            pos -= 6;   // fold lowercase a-f onto 10-15
                        }
                        id <<= 4;
                        id += pos;
                        nybbles--;
                    }
                }
                if (id == 0) {
                    id = 1; // Backstop in case something got weird
                }
                Constants::_UniqueId = id;
            } else {
                Constants::_UniqueId = 1; // Backstop in case something got weird
            }
        }
        return Constants::_UniqueId;
    }
} // namespace Constants

// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/Constants.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#ifndef _CONSTANTS_HH_
#define _CONSTANTS_HH_
#pragma once

// Declares both a narrow and a wide (name ## W) constant for each string.
#define DECLARE_STRING(name) extern const std::string name; extern const std::wstring name ## W;

// NOTE(review): system-include name stripped by the extraction tool.
#include

namespace Constants
{
    extern const std::string TIMESTAMP;
    extern const std::string PreciseTimeStamp;

    // ETW-style severity levels (lower value = more severe).
    enum class ETWlevel : unsigned char {
        LogAlways = 0,
        Critical = 1,
        Error = 2,
        Warning = 3,
        Information = 4,
        Verbose = 5
    };

    static constexpr uint32_t MDS_blob_version { 1 };
    static constexpr uint32_t MDS_blob_format { 2 };

    extern uint64_t _UniqueId;  // cached per-boot id; use UniqueId() to read
    uint64_t UniqueId();

    namespace Compression
    {
        extern const std::string lz4hc;
    }

    namespace EventCategory
    {
        extern const std::string Counter;
        extern const std::string Trace;
    } // namespace EventCategory

    namespace AzurePropertyNames
    {
        DECLARE_STRING(Namespace)
        DECLARE_STRING(EventName)
        DECLARE_STRING(EventVersion)
        DECLARE_STRING(EventCategory)
        DECLARE_STRING(BlobVersion)
        DECLARE_STRING(BlobFormat)
        DECLARE_STRING(DataSize)
        DECLARE_STRING(BlobSize)
        DECLARE_STRING(MonAgentVersion)
        DECLARE_STRING(CompressionType)
        DECLARE_STRING(MinLevel)
        DECLARE_STRING(AccountMoniker)
        DECLARE_STRING(Endpoint)
        DECLARE_STRING(OnbehalfFields)
        DECLARE_STRING(OnbehalfServiceId)
        DECLARE_STRING(OnbehalfAnnotations)
    } // namespace AzurePropertyNames
};

#undef DECLARE_STRING

#endif // _CONSTANTS_HH_

// vim: se sw=8

================================================
FILE: Diagnostic/mdsd/mdsd/Credentials.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

// NOTE(review): <>-delimited text (system-include names, template arguments)
// appears stripped throughout this dump by the extraction tool; the bare
// "#include" lines and unparameterized std::map/dynamic_cast uses below are
// artifacts of that, not the real source.
#include "Credentials.hh"
#include
#include
#include "Trace.hh"
#include "Logger.hh"
#include "MdsdConfig.hh"
#include "Utility.hh"
#include "AzureUtility.hh"

using std::string;

// Stream a credential as "<addr>=(Moniker <m> type <t>)" for tracing.
std::ostream& operator<<(std::ostream &os, const Credentials& creds)
{
    os << &creds << "=(Moniker " << creds.Moniker() << " type " << creds.TypeName() << ")";
    return os;
}

// Map a ServiceType to a printable name; unknown values share one
// "Unknown ServiceType" string (returned by reference, so statics are used).
const std::string&
Credentials::ServiceType_to_string(ServiceType svcType)
{
    static std::map stmap = {
        { Credentials::ServiceType::XTable, "XTable" },
        { Credentials::ServiceType::Blob, "Blob" },
        { Credentials::ServiceType::EventPublish, "EventPublish" }
    };
    static std::string UnknownType { "Unknown ServiceType" };

    auto iter = stmap.find(svcType);
    if (iter == stmap.end()) {
        return UnknownType;
    } else {
        return iter->second;
    }
}

std::ostream& operator<<(std::ostream &os, Credentials::ServiceType svcType)
{
    os << Credentials::ServiceType_to_string(svcType);
    return os;
}

// Extract the "se" part of the query string and expire 30-60 minutes before then
MdsTime
CredentialType::AutoKey::GetExpireTimeFromSasSE(const std::string & sas)
{
    std::map qry;
    MdsdUtil::ParseQueryString(sas, qry);
    auto exp = qry.find("se");
    if (exp == qry.end()) {
        // Shouldn't happen, but if it does, the URI should be good for 11-12 hours.
        return (MdsTime::Now() + MdsTime(11 * 3600 + random()%3600));
    } else {
        return (MdsTime(exp->second) - MdsTime(1800 + random()%1800));
    }
}

// Three output parameters are set by the ConnectionString() methods
//
// For XTable, fullSvcName will be set to the actual XStore table name to be used. The namespace prefix and the version
// and perNDay suffixes will be applied as appropriate. The perNDay selected is "right now".
// connstr will be set to the connection string.
// expires will be set to the expiration time of the connection string (i.e. the time at which a new
// connection string should be requested).
//
// Returns true if a connection string could be constructed; false if not.

// Local credentials have no remote service; always fails.
bool
CredentialType::Local::ConnectionString(const MdsEntityName &target, ServiceType svcType, string &fullSvcName,
    string &connstr, MdsTime &expires) const
{
    Trace trace(Trace::Credentials, "ConnectionString Local");
    Logger::LogError("Can't make connection string for Local moniker " + Moniker());
    return false;
}

// Shared-key credential: endpoint + AccountName + AccountKey. The key never
// expires, so expiry is only driven by the ten-day table-name rotation.
bool
CredentialType::SharedKey::ConnectionString(const MdsEntityName &target, ServiceType svcType, string &fullSvcName,
    string &connstr, MdsTime &expires) const
{
    Trace trace(Trace::Credentials, "ConnectionString SharedKey");
    try {
        connstr = GetConnectionStringOnly(svcType);
    } catch (std::invalid_argument& e) {
        trace.NOTE(e.what());
        Logger::LogError(e.what());
        return false;
    }

    fullSvcName = target.Name();
    if (target.IsConstant()) {
        expires = MdsTime::Max();
    } else {
        // Rebuild connection string at next ten-day interval
        expires = (MdsTime::Now() + 10*24*3600).RoundTenDay();
    }
    return true;
}

// Compose "BlobEndpoint=..."/"TableEndpoint=..." plus account name and key.
// Throws invalid_type for any service other than Blob or XTable.
std::string
CredentialType::SharedKey::GetConnectionStringOnly(ServiceType svcType) const
{
    std::ostringstream conn;
    if (ServiceType::Blob == svcType) {
        conn << "BlobEndpoint=" << _blobUri;
    } else if (ServiceType::XTable == svcType) {
        conn << "TableEndpoint=" << _tableUri;
    } else {
        throw invalid_type(svcType);
    }
    conn << ";AccountName=" << _accountName << ";AccountKey=" << _secret;
    return conn.str();
}

// AutoKey credential: fetch a SAS ("autokey") from the config for the target
// service, split it into endpoint and SAS parts, and derive the expiry from
// the SAS's "se" parameter (minus a safety margin).
bool
CredentialType::AutoKey::ConnectionString(
    const MdsEntityName &target,
    ServiceType svcType,
    string &fullSvcName,
    string &connstr,
    MdsTime &expires) const
{
    Trace trace(Trace::Credentials, "ConnectionString AutoKey");
    std::ostringstream conn;
    string autokey;

    switch (svcType) {
    case ServiceType::EventPublish:
        fullSvcName = target.EventName();
        autokey = _config->GetEventPublishCmdXmlItems(Moniker(), fullSvcName).sas;
        break;
    case ServiceType::Blob:
    case ServiceType::XTable:
        fullSvcName = target.Name();
        autokey = _config->GetAutokey(Moniker(), fullSvcName);
        break;
    default:
        std::ostringstream strm;
        strm << "Error: AutoKey credential doesn't support service " << svcType;
        trace.NOTE(strm.str());
        Logger::LogError(strm.str());
        return false;
    }

    if (autokey.empty()) {
        std::ostringstream strm;
        strm << "Can't find autokey for moniker " << Moniker() << ", " << svcType << " " << fullSvcName;
        trace.NOTE(strm.str());
        Logger::LogError(strm.str());
        return false;
    }

    // The endpoint and the SAS are concatenated in the autokey; the separator
    // depends on the service type.
    string endpointName;
    string endpointSep;
    if (ServiceType::XTable == svcType) {
        endpointName = "TableEndpoint";
        endpointSep = "/$batch?";
    } else if (ServiceType::Blob == svcType) {
        endpointName = "BlobEndpoint";
        endpointSep = "/" + fullSvcName + "?";
    }
    size_t pos = autokey.find(endpointSep); // Separates endpoint from SAS
    if (pos == string::npos) {
        std::ostringstream msg;
        msg << "Improperly formatted autokey for " << Moniker() << ", " << svcType << " " << fullSvcName;
        msg << ": \"" << autokey << "\"";
        trace.NOTE(msg.str());
        Logger::LogError(msg.str());
        return false;
    }
    conn << endpointName << "=" << autokey.substr(0, pos);
    conn << ";SharedAccessSignature=" << autokey.substr(pos+endpointSep.size());

    if (!autokey.empty()) {
        expires = GetExpireTimeFromSasSE(autokey);
    }
    // If the tablename can change, rebuild at the change time, if that's sooner
    if (!target.IsConstant()) {
        MdsTime proposed = (MdsTime::Now() + 10*24*3600).RoundTenDay();
        if (proposed < expires) {
            expires = proposed;
        }
    }

    connstr = conn.str();
    trace.NOTE("AutoKey ConnectionString='" + connstr + "'.");
    return true;
}

// Event-hub parameters are only defined for EventPublish; anything else throws.
mdsd::EhCmdXmlItems
CredentialType::AutoKey::GetEhParameters(const std::string& eventName, Credentials::ServiceType eventType ) const
{
    if (Credentials::ServiceType::EventPublish == eventType) {
        return _config->GetEventPublishCmdXmlItems(Moniker(), eventName);
    }
    throw invalid_type(eventType);
}

// SAS credential: validate the token (also detecting account-SAS vs
// service-SAS) and precompute public-cloud blob/table endpoints.
CredentialType::SAS::SAS(const std::string& moniker, const std::string& acct, const std::string &token)
    : Credentials(moniker, SecretType::SAS), _secret(token), _accountName(acct),
      _blobUri(MakePublicCloudEndpoint(acct, ServiceType::Blob)),
      _tableUri(MakePublicCloudEndpoint(acct, ServiceType::XTable))
{
    MdsdUtil::ValidateSAS(token, _isAccountSas);
}

// Compose endpoint + SharedAccessSignature; throws invalid_type for services
// other than Blob or XTable.
std::string
CredentialType::SAS::GetConnectionStringOnly(ServiceType svcType) const
{
    std::ostringstream conn;
    if (ServiceType::XTable == svcType) {
        conn << "TableEndpoint=" << _tableUri;
    } else if (ServiceType::Blob == svcType) {
        conn << "BlobEndpoint=" << _blobUri;
    } else {
        throw invalid_type(svcType);
    }
    conn << ";SharedAccessSignature=" << _secret;
    return conn.str();
}

// SAS credential connection string: service name depends on whether this is
// an account SAS (rotating table names allowed) or a table-scoped SAS (the
// "tn" query parameter pins the table); expiry comes from the "se" parameter.
bool
CredentialType::SAS::ConnectionString(const MdsEntityName &target, ServiceType svcType, string &fullSvcName,
    string &connstr, MdsTime &expires) const
{
    Trace trace(Trace::Credentials, "ConnectionString SAS");
    try {
        connstr = GetConnectionStringOnly(svcType);
    } catch (std::invalid_argument& e) {
        trace.NOTE(e.what());
        Logger::LogError(e.what());
        return false;
    }

    std::map qry;
    MdsdUtil::ParseQueryString(_secret, qry);

    if (IsAccountSas()) {
        // The SAS is an account SAS, replacing the storage shared key, and the svc name should be a name with
        // the 10-day suffix, not the base name.
        fullSvcName = target.Name();
    } else if (ServiceType::XTable == svcType) {
        // SAS (non-account SAS) includes the tablename; update to match, otherwise the SAS won't work.
        auto item = qry.find("tn");
        if (item != qry.end()) {
            fullSvcName = item->second;
        } else {
            Logger::LogError("SAS for MDS moniker " + Moniker() + " missing tn= component");
            fullSvcName = target.Basename();
        }
    }

    auto exp = qry.find("se");
    if (exp == qry.end()) {
        expires = MdsTime::Max(); // No expiration in SAS
    } else {
        expires = MdsTime(exp->second);
        if (MdsTime::Now() > expires) {
            Logger::LogError("Expired SAS for MDS moniker " + Moniker());
        }
    }
    if (IsAccountSas()) {
        // Set expires for next ten-day interval, following the storage shared key credential logic.
// (Note: The account SAS itself will/should never expire, like a storage shared key)
        expires = std::min(expires, (MdsTime::Now() + 10*24*3600).RoundTenDay());
    }
    return true;
}

// vim: se sw=4 expandtab :

================================================
FILE: Diagnostic/mdsd/mdsd/Credentials.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _CREDENTIALS_HH_
#define _CREDENTIALS_HH_

// NOTE(review): system-include names stripped by the extraction tool.
#include
#include
#include "MdsTime.hh"
#include "MdsEntityName.hh"
#include "EventHubCmd.hh"

class MdsdConfig;

// Abstract base for storage credentials (shared key, SAS, autokey, local).
// A credential is identified by an MDS "moniker" and can build a connection
// string for a given target entity and service type.
class Credentials
{
    friend std::ostream& operator<<(std::ostream &os, const Credentials& creds);

public:
    enum SecretType { None, Key, SAS };
    // EventPublish: Event data directly publishing to EventHub.
    enum class ServiceType { XTable, Blob, EventPublish };
    static const std::string& ServiceType_to_string(ServiceType svcType);

    Credentials(const std::string& moniker, SecretType type) : _moniker(moniker), _secretType(type) {}
    virtual ~Credentials() {}

    const std::string Moniker() const { return _moniker; }
    SecretType Type() const { return _secretType; }
    virtual bool useAutoKey() const { return false; }
    virtual std::string AccountName() const = 0;
    // Build a connection string for target; see Credentials.cc for the
    // contract of the three output parameters.
    virtual bool ConnectionString(const MdsEntityName &target, ServiceType svcType, std::string &fullSvcName,
        std::string &connstr, MdsTime &expires) const = 0;
    virtual const std::string TypeName() const = 0;
    // True if this credential can access any table in the account (shared
    // key or autokey), as opposed to a single-table SAS.
    virtual bool accessAnyTable() const { return (Type() == Key || useAutoKey() ); }

private:
    const std::string _moniker;
    SecretType _secretType;

    Credentials() = delete;
};

std::ostream& operator<<(std::ostream &os, const Credentials& creds);
std::ostream& operator<<(std::ostream &os, Credentials::ServiceType svcType);

namespace CredentialType {

// Thrown when an operation is asked for a service type it doesn't support.
class invalid_type : public std::logic_error
{
public:
    invalid_type(Credentials::ServiceType svcType)
        : std::logic_error("Service type [" + Credentials::ServiceType_to_string(svcType)
                           + "] not supported by this operation")
    { }
};

// Compose "https://<acct>.blob.core.windows.net" or the .table equivalent;
// throws invalid_type for other service types.
static inline std::string MakePublicCloudEndpoint(const std::string& acct, Credentials::ServiceType svcType)
{
    std::string result;
    result.reserve(33 + acct.size());
    result.append("https://").append(acct);
    if (svcType == Credentials::ServiceType::Blob) {
        result.append(".blob.core.windows.net");
    } else if (svcType == Credentials::ServiceType::XTable) {
        result.append(".table.core.windows.net");
    } else {
        throw invalid_type(svcType);
    }
    return result;
}

// Storage-account shared-key credential.
class SharedKey : public Credentials
{
public:
    SharedKey(const std::string& moniker, const std::string &name, const std::string &key)
        : Credentials(moniker, SecretType::Key), _accountName(name), _secret(key),
          _blobUri(MakePublicCloudEndpoint(name, ServiceType::Blob)),
          _tableUri(MakePublicCloudEndpoint(name, ServiceType::XTable))
    {}

    std::string AccountName() const { return _accountName; }
    bool ConnectionString(const MdsEntityName &target, ServiceType svcType, std::string &fullSvcName,
        std::string &connstr, MdsTime &expires) const;
    const std::string TypeName() const { return std::string{"SharedKey"}; }
    void TableUri(const std::string& uri) { _tableUri = uri; }
    void BlobUri(const std::string& uri) { _blobUri = uri; }

    // To get the connection string only, without passing target. Will throw if svcType is neither blob nor table.
    std::string GetConnectionStringOnly(ServiceType svcType) const;

private:
    std::string _accountName;
    std::string _secret;      // account key
    std::string _blobUri;
    std::string _tableUri;
};

// Credential whose SAS tokens ("autokeys") are fetched from the MdsdConfig
// on demand per target service.
class AutoKey : public Credentials
{
public:
    AutoKey(const std::string& moniker, MdsdConfig *config) : Credentials(moniker, SecretType::SAS), _config(config) {}

    std::string AccountName() const { return std::string{"AutoKey"}; }
    bool ConnectionString(const MdsEntityName &target, ServiceType svcType, std::string &fullSvcName,
        std::string &connstr, MdsTime &expires) const;
    const std::string TypeName() const { return std::string{"AutoKey"}; }
    bool useAutoKey() const { return true; }
    static MdsTime GetExpireTimeFromSasSE(const std::string & sas);
    mdsd::EhCmdXmlItems GetEhParameters(const std::string& eventName, Credentials::ServiceType eventType) const;

private:
    MdsdConfig *_config;  // source of autokeys; not owned
};

// Shared-access-signature credential (account SAS or single-service SAS).
class SAS : public Credentials
{
public:
    SAS(const std::string& moniker, const std::string& acct, const std::string &token);

    std::string AccountName() const { return _accountName; }
    bool ConnectionString(const MdsEntityName &target, ServiceType svcType, std::string &fullSvcName,
        std::string &connstr, MdsTime &expires) const;
    const std::string TypeName() const { return std::string{"SAS"}; }
    const std::string Token() const { return _secret; }
    bool IsAccountSas() const { return _isAccountSas; }
    bool accessAnyTable() const { return _isAccountSas; }
    void BlobUri(const std::string& uri) { _blobUri = uri; }
    void TableUri(const std::string& uri) { _tableUri = uri; }

    // To get the connection string only, without passing target. Will throw if svcType is neither blob nor table.
std::string GetConnectionStringOnly(ServiceType svcType) const;

private:
    std::string _secret;       // the SAS token
    std::string _accountName;
    std::string _blobUri;
    std::string _tableUri;
    bool _isAccountSas;        // set by MdsdUtil::ValidateSAS in the ctor
};

// Placeholder credential for purely local storage; cannot build a
// connection string.
class Local : public Credentials
{
public:
    Local() : Credentials(std::string{"(LOCAL)"}, SecretType::None) {}

    std::string AccountName() const { return std::string{"Local"}; }
    bool ConnectionString(const MdsEntityName &target, ServiceType svcType, std::string &fullSvcName,
        std::string &connstr, MdsTime &expires) const;
    const std::string TypeName() const { return std::string{"Local"}; }
};

}

#endif // _CREDENTIALS_HH_

// vim: set ai sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/DaemonConf.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "DaemonConf.hh"
#include "Logger.hh"
#include "Trace.hh"
#include "Utility.hh"
#include "Version.hh"

// NOTE(review): <>-style include names stripped by the extraction tool.
#include
#include

extern "C" {
#include
#include
#include
#include
#include
#include
#include
#include
}

// Look up a user name via getpwnam_r; returns the uid, or 0 if the user is
// not found (or username is NULL). Failures are logged, not thrown.
uid_t
DaemonConf::GetUidFromName(const char* username)
{
    Trace trace(Trace::Daemon, "GetUidFromName");
    uid_t uid = 0;
    if (!username) {
        Logger::LogError("Error: GetUidFromName(): unexpected NULL pointer for username.");
        return uid;
    }

    struct passwd *resultObj;
    struct passwd wrkObj;
    char buf[2048];

    getpwnam_r(username, &wrkObj, buf, sizeof(buf), &resultObj);
    if (resultObj == NULL) {
        Logger::LogWarn("WARN: GetUidFromName(): No user called '" + std::string(username) + "' is found.");
    } else {
        uid = resultObj->pw_uid;
    }
    trace.NOTE("Name='" + std::string(username) + "'. UID=" + std::to_string(uid));
    return uid;
}

// Look up a group name via getgrnam_r; returns the gid, or 0 if the group is
// not found (or groupname is NULL).
gid_t
DaemonConf::GetGidFromName(const char* groupname)
{
    Trace trace(Trace::Daemon, "GetGidFromName");
    gid_t gid = 0;
    if (!groupname) {
        Logger::LogError("Error: GetGidFromName(): unexpected NULL for groupname");
        return gid;
    }

    struct group *resultObj;
    struct group wrkObj;
    char buf[2048];

    getgrnam_r(groupname, &wrkObj, buf, sizeof(buf), &resultObj);
    if (resultObj == NULL) {
        Logger::LogWarn("WARN: GetGidFromName(): No group called '" + std::string(groupname) + "' is found.");
    } else {
        gid = resultObj->gr_gid;
    }
    trace.NOTE("GetGidFromName() returned: Group='" + std::string(groupname) + "'. GID=" + std::to_string(gid));
    return gid;
}

// Drop privileges: setgid first (must happen while still root), then setuid.
// Refuses to switch to uid/gid 0; any setgid/setuid failure is fatal (exit).
void
DaemonConf::SetPriv(uid_t uid, gid_t gid)
{
    Trace trace(Trace::Daemon, "SetPriv");
    std::string uidstr = std::to_string(uid);
    std::string gidstr = std::to_string(gid);
    if (0 == uid) {
        Logger::LogError("Error: unexpected user id " + uidstr + ". Do nothing.");
        return;
    }
    if (0 == gid) {
        Logger::LogError("Error: unexpected group id " + gidstr + ". Do nothing.");
        return;
    }
    int r2 = setgid(gid);
    if (r2) {
        int errnum = errno;
        std::string errstr = MdsdUtil::GetErrnoStr(errnum);
        Logger::LogError("Error: fatal error. setgid() failed to set id " + gidstr + ". error: " + errstr);
        exit(1);
    }
    trace.NOTE("mdsd's groupid changed to " + gidstr);
    int r1 = setuid(uid);
    if (r1) {
        int errnum = errno;
        std::string errstr = MdsdUtil::GetErrnoStr(errnum);
        Logger::LogError("Error: fatal error. setuid() failed to set id " + uidstr + ". error: " + errstr);
        exit(1);
    } else {
        trace.NOTE("mdsd's userid changed to id " + uidstr);
    }
}

/* Run mdsd in daemon mode by forking the child process. */
// The parent exits; the child writes its pid file, starts a new session,
// chdirs to /, closes the standard fds, and drops privileges to the
// runAsUser/runAsGroup identities.
void
DaemonConf::RunAsDaemon(const std::string & pidfile)
{
    Trace trace(Trace::Daemon, "RunAsDaemon");
    pid_t ppid = getpid();
    pid_t pid = fork();
    if (-1 == pid) {
        int errnum = errno;
        std::string errstr = MdsdUtil::GetErrnoStr(errnum);
        Logger::LogError("Fork child process failed with -1. error: " + errstr);
        exit(1);
    }
    if (pid > 0) {
        // Parent: informational exit message.
        // NOTE(review): logged via LogError although it is not an error.
        Logger::LogError("Parent process " + std::to_string(ppid) + " exit. child process id=" + std::to_string(pid));
        exit(0);
    }

    if (WritePid(pidfile) == false) {
        exit(1);
    }

    umask(0);

    // Create a new session for the child process
    pid_t sid = setsid();
    if (sid < 0) {
        int errnum = errno;
        std::string errstr = MdsdUtil::GetErrnoStr(errnum);
        Logger::LogError("child process setsid() returned " + std::to_string(sid) + ". error: " + errstr);
        exit(1);
    }

    if ((chdir("/")) < 0) {
        int errnum = errno;
        std::string errstr = MdsdUtil::GetErrnoStr(errnum);
        Logger::LogError("Chdir() to root directory failed: " + errstr);
        exit(1);
    }

    close(STDIN_FILENO);
    close(STDOUT_FILENO);
    close(STDERR_FILENO);

    // NOTE(review): uid/gid are declared int but GetUidFromName returns the
    // unsigned uid_t, so "uid >= 0" is effectively always true — confirm
    // whether the intent was to skip SetPriv when lookup fails (returns 0;
    // SetPriv itself refuses 0).
    int uid = GetUidFromName(runAsUser);
    int gid = GetGidFromName(runAsGroup);
    if (uid >= 0 && gid >= 0) {
        SetPriv(uid, gid);
    }

    std::ostringstream msg;
    msg << "START mdsd daemon ver(" << Version::Version << ") pid(" << getpid()
        << ") uid(" << uid << ") gid (" << gid << ")" << std::endl;
    Logger::LogError(msg.str());
    Logger::LogWarn(msg.str());
    Logger::LogInfo(msg.str());
}

// Write the current pid (plus newline) to pidfile; returns false on any
// open/write failure. The fd is closed by the FdCloser RAII guard.
bool
DaemonConf::WritePid(const std::string & pidfile)
{
    Trace trace(Trace::Daemon, "WritePid");

    int fd = open(pidfile.c_str(), O_WRONLY|O_CREAT|O_CLOEXEC, 0644);
    MdsdUtil::FdCloser fdCloser(fd);
    if (fd < 0) {
        int errnum = errno;
        std::ostringstream buf;
        buf << "Error: failed to open or create Pid file: " << pidfile << ". " << MdsdUtil::GetErrnoStr(errnum);
        Logger::LogError(buf.str());
        return false;
    }

    bool status = true;
    try{
        MdsdUtil::WriteBufferAndNewline(fd, std::to_string(getpid()));
    }
    catch (const std::runtime_error & e) {
        Logger::LogError(std::string("Error writing pid file: ") + e.what());
        status = false;
    }
    return status;
}

// chown filepath to the daemon's runtime user/group; no-op (returns true)
// unless both lookups yield non-root ids.
bool
DaemonConf::Chown(const std::string& filepath)
{
    bool isOK = true;
    uid_t uid = GetUidFromName(runAsUser);
    gid_t gid = GetGidFromName(runAsGroup);
    if (uid > 0 && gid > 0) {
        int r = chown(filepath.c_str(), uid, gid);
        if (r) {
            int errnum = errno;
            std::string errstr = MdsdUtil::GetErrnoStr(errnum);
            Logger::LogError("Error: Chown() failed. logfile='" + filepath + "' user='" + runAsUser +
                "' group='" + runAsGroup + "' . error: " + errstr);
            isOK = false;
        }
    }
    return isOK;
}

// vim: se ai sw=4 expandtab tabstop=4 :

================================================
FILE: Diagnostic/mdsd/mdsd/DaemonConf.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _DAEMONCONF_HH_
#define _DAEMONCONF_HH_

// NOTE(review): system-include names stripped by the extraction tool.
#include
#include

class DaemonConf
{
public:
    /* Run mdsd in daemon mode by forking the child process. */
    static void RunAsDaemon(const std::string& pidfile);

    /* Change a file's user and group to the daemon runtime user/group. */
    static bool Chown(const std::string& filepath);

private:
    /* Get a given username's user id. If user is not found, return 0. */
    static uid_t GetUidFromName(const char* username);

    /* Get a given groupname's groupid. If group is not found, return 0. */
    static gid_t GetGidFromName(const char* groupname);

    /* Set daemon userid and groupid to given Ids. If uid or gid are 0, do nothing. */
    static void SetPriv(uid_t uid, gid_t gid);

    /* Write final daemon process's process Id to pid file.
*/ static bool WritePid(const std::string & pidfile); private: constexpr static const char * runAsUser = "syslog"; constexpr static const char * runAsGroup = "syslog"; }; #endif ================================================ FILE: Diagnostic/mdsd/mdsd/DerivedEvent.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "DerivedEvent.hh" #include "MdsdConfig.hh" #include "Pipeline.hh" #include "CanonicalEntity.hh" #include "Logger.hh" #include "Trace.hh" #include "LocalSink.hh" DerivedEvent::DerivedEvent(MdsdConfig * config, const MdsEntityName &target, Priority prio, const MdsTime &interval, std::string source) : ITask(interval), _config(config), _target(target), _prio(prio), _head(nullptr), _tail(nullptr) { Trace trace(Trace::DerivedEvent, "DerivedEvent constructor"); // Find the source; make sure it exists _localSink = LocalSink::Lookup(source); if (! _localSink) { std::ostringstream msg; msg << "DerivedEvent " << target << " references undefined source " << source; Logger::LogError(msg.str()); throw std::runtime_error(msg.str()); } _localSink->SetRetentionPeriod(interval); } DerivedEvent::~DerivedEvent() { } // Initial start time is a few seconds past the end of the current interval MdsTime DerivedEvent::initial_start() { Trace trace(Trace::DerivedEvent, "DerivedEvent::initial_start"); MdsTime start; // Default constructor sets it to "now" start += interval(); start = start.Round(interval().to_time_t()); start += MdsTime(2 + random()%5, random()%1000000); if (trace.IsActive()) { std::ostringstream msg; msg << "Initial time for event: " << start; trace.NOTE(msg.str()); } return start; } void DerivedEvent::AddStage(PipeStage *stage) { Trace trace(Trace::DerivedEvent, "DerivedEvent::AddStage"); if (trace.IsActive()) { std::ostringstream msg; msg << "DerivedEvent " << this << " adding stage " << stage->Name(); trace.NOTE(msg.str()); } if (! 
_tail) { // This is the first stage in the pipeline; set the head to point here _head = stage; } else { // There's already a pipeline; make the old tail point to the newly-added stage _tail->AddSuccessor(stage); } // Either way, we have a new tail in the pipeline _tail = stage; } // Pull all the CanonicalEntity instances from the source that match the interval and send a dupe // into the processing pipeline; signal "done" after the last instance. void DerivedEvent::execute(const MdsTime& startTime) { Trace trace(Trace::DerivedEvent, "DerivedEvent::execute"); if (trace.IsActive()) { std::ostringstream msg; msg << "Start time " << startTime << ", end time " << startTime + interval(); trace.NOTE(msg.str()); } auto head = _head; _head->Start(startTime); _localSink->Foreach(startTime, interval(), [head](const CanonicalEntity& ce){ head->Process(new CanonicalEntity(ce)); }); _localSink->Flush(); // Tell the sink to do its housekeeping _head->Done(); } // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/DerivedEvent.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _DERIVEDEVENT_HH_ #define _DERIVEDEVENT_HH_ #include "ITask.hh" #include "MdsEntityName.hh" #include "Priority.hh" class MdsdConfig; class PipeStage; class LocalSink; class DerivedEvent : public ITask { public: DerivedEvent(MdsdConfig * config, const MdsEntityName &target, Priority prio, const MdsTime &interval, std::string source); // I want a move constructor... 
DerivedEvent(DerivedEvent &&orig); // But do not want a copy constructor nor a default constructor DerivedEvent(DerivedEvent &) = delete; DerivedEvent() = delete; virtual ~DerivedEvent(); const MdsEntityName & Target() const { return _target; } int FlushInterval() const { return _prio.Duration(); } void AddStage(PipeStage *); protected: // Subclasses *must* override the execute() method, which is called to perform the actual // time-scheduled class. virtual void execute(const MdsTime&); #if 0 // Dunno if I need these.... // Subclass gets notified via this callout when start() is called. If the subclass returns false, // the start operation aborts. In this case, start() can be called again; a failed startup is different // from a successful start followed by a cancel(). virtual bool on_start() { return true; } // Subclass gets notified when cancel() is called. virtual void on_cancel() { } #endif // We'll want the initial start time to be shortly after the end of the next "interval". // We'll add some hysteresis to that start time. virtual MdsTime initial_start(); private: MdsdConfig *_config; MdsEntityName _target; Priority _prio; LocalSink *_localSink; PipeStage *_head; PipeStage *_tail; }; #endif // _DERIVEDEVENT_HH_ // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/Engine.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "Engine.hh"
// NOTE(review): system header names below were lost in extraction; preserved as-is.
#include
#include
#include
#include
#include
#include
#include "MdsValue.hh"
#include "TableSchema.hh"
#include "MdsdConfig.hh"
#include "Credentials.hh"
#include "EventJSON.hh"
#include "CanonicalEntity.hh"
#include "OmiTask.hh"
#include "Trace.hh"
#include "LocalSink.hh"
#include "EtwEvent.hh"
#include "Utility.hh"
#include "EventHubUploaderMgr.hh"

using std::string;

Engine::Engine() : blackholeEvents(false), _startTime(time(0)), current_config(nullptr)
{
}

Engine::~Engine()
{
}

Engine* Engine::engine = 0;

// Swap in a new configuration; the old one self-destructs after a grace
// period so in-flight work holding the old pointer remains valid.
void Engine::SetConfiguration(MdsdConfig* newconfig)
{
    Trace trace(Trace::ConfigLoad, "Engine::SetConfiguration");

    Engine *current = Engine::GetEngine();
    static std::mutex mtx;
    std::unique_lock lock(mtx);
    MdsdConfig *prev_config = current->current_config;
    current->current_config = newconfig;
    lock.unlock();

    //current->PushSchemas(newconfig);
    newconfig->Initialize();
    newconfig->StartScheduledTasks();
    if (prev_config) {
        prev_config->SelfDestruct(900); // Old config will delete itself in 900 seconds
    }
}

#ifdef DOING_MEMCHECK
void Engine::ClearConfiguration()
{
    current_config->StopScheduledTasks();
    delete current_config;
    current_config = nullptr;
}
#endif

// Convert an incoming JSON event into a CanonicalEntity and hand it to the
// LocalSink named by the event's SOURCE. Events with no matching sink or
// schema, or with unconvertible/mismatched columns, are logged and dropped.
void Engine::ProcessEvent(EventJSON& event)
{
    Trace trace(Trace::EventIngest, "Engine::ProcessEvent");

    // Grab the config pointer at the beginning of processing; if the config gets
    // swapped while we're working, we won't care. The engine is careful to hold on
    // to previous MdsdConfig objects for a lengthy period of time after they're
    // swapped out.
    MdsdConfig* Config = GetConfig();

    if (blackholeEvents) {
        return;
    }

    // Actual processing goes here
    // Listener() did basic validation before calling ProcessEvent()
    string Source(event.GetSource());
    auto sink = LocalSink::Lookup(Source);
    if (!sink) {
        Logger::LogWarn("Received an event from source \"" + Source + "\" not used elsewhere in the active configuration");
        return;
    }

    // ETW events carry their own schema handling.
    if (event.IsEtwEvent()) {
        EtwEvent etwevt(event);
        etwevt.Process(sink);
        return;
    }

    TableSchema* Schema = Config->GetSchema(Source);
    if (!Schema) {
        Logger::LogWarn("Received an event from source \"" + Source + "\" with no defined schema.");
        return;
    }

    // Build the CanonicalEntity to hold this event by running through the elements of the input event
    // and using the metadata in the schema to add columns
    auto ce = std::make_shared( Schema->Size() );
    ce->SetPreciseTime(event.GetTimestamp());
    ce->SetSchemaId(sink->SchemaId());

    auto datum = event.data_begin();
    TableSchema::const_iterator iter = Schema->begin();
    while (datum != event.data_end() && iter != Schema->end()) {
        auto value = (*iter)->Convert(&(*datum));
        if (!value) {
            std::ostringstream msg;
            msg << "Bad event (source " << Source << ", schema " << Schema->Name() << "): couldn't convert value for ";
            msg << (*iter)->Name();
            msg << " to " << (*iter)->MdsType();
            msg << ". Raw event: " << event;
            Logger::LogError(msg.str());
            return;
        }
        ce->AddColumn((*iter)->Name(), value);
        ++datum;
        ++iter;
    }

    // Both iterators must be exhausted together, or the column counts differ.
    if (datum != event.data_end() || iter != Schema->end()) {
        std::stringstream msg;
        msg << "Event from source '" << Source << "' contained unexpected number of columns. ";
        msg << Source << " has " << event.data_count() << "; ";
        msg << "Schema '" << Schema->Name() << "' has " << Schema->Size() << ".";
        Logger::LogError(msg.str());
    } else {
        // Add the CanonicalEntity object to the sink we found (above).
        sink->AddRow(ce);
    }
}

Engine* Engine::GetEngine()
{
    if (!engine) {
        engine = new Engine();
    }
    return engine;
}

// Look up a converter by source/target pair; the key format is "src/dst".
bool Engine::GetConverter(const string& sourcetype, const string& targettype, typeconverter_t& converter)
{
    std::string inOutType;
    inOutType.reserve(sourcetype.size() + 1 + targettype.size());
    inOutType.append(sourcetype);
    inOutType.append(1, '/');
    inOutType.append(targettype);
    return GetConverter(inOutType, converter);
}

bool Engine::GetConverter(const std::string & inOutType, typeconverter_t& converter)
{
    auto iter = convertermap.find(inOutType);
    if (iter == convertermap.end()) {
        return false;
    }
    converter = iter->second;
    return true;
}

// Space-separated, quoted list of every registered converter key.
std::string Engine::ListConverters()
{
    std::ostringstream msg;
    bool isFirst = true;
    for (const auto& item : convertermap) {
        if (isFirst) {
            isFirst = false;
        } else {
            msg << " ";
        }
        msg << "'" << item.first << "'";
    }
    return msg.str();
}

// Registry of "jsonType/mdsType" -> conversion lambda. Each lambda returns a
// freshly allocated MdsValue, or null when the cJSON node has the wrong type.
std::map Engine::convertermap = {
    { "bool/mt:bool", [](cJSON * src) -> MdsValue* {
        if (src->type == cJSON_False) return new MdsValue(false);
        if (src->type == cJSON_True) return new MdsValue(true);
        return 0; } },
    { "str/mt:bool", [](cJSON * src) -> MdsValue* {
        if (cJSON_String == src->type && src->valuestring) {
            bool b = MdsdUtil::to_bool(src->valuestring);
            return new MdsValue(b);
        }
        return nullptr; } },
    { "str/mt:wstr", [](cJSON * src) -> MdsValue* {
        return (src->type == cJSON_String) ? ( new MdsValue(src->valuestring)) : 0; } },
    { "double/mt:float64", [](cJSON * src) -> MdsValue* {
        return (src->type == cJSON_Number) ? ( new MdsValue(src->valuedouble)) : 0; } },
    { "str/mt:float64", [](cJSON * src) -> MdsValue* {
        if (cJSON_String == src->type && src->valuestring) {
            return new MdsValue(atof(src->valuestring));
        }
        return nullptr; } },
    { "int/mt:int32", [](cJSON * src) -> MdsValue* {
        return (src->type == cJSON_Number) ? ( new MdsValue(long(src->valueint))) : 0; } },
    { "str/mt:int32", [](cJSON * src) -> MdsValue* {
        return (src->type == cJSON_String) ? ( new MdsValue(atol(src->valuestring))) : 0; } },
    { "int/mt:int64", [](cJSON * src) -> MdsValue* {
        return (src->type == cJSON_Number) ? ( new MdsValue(src->valueint)) : 0; } },
    { "str/mt:int64", [](cJSON * src) -> MdsValue* {
        return (src->type == cJSON_String) ? ( new MdsValue(strtoll(src->valuestring, NULL, 10))) : 0; } },
    { "int-timet/mt:utc", [](cJSON * src) -> MdsValue* { return MdsValue::time_t_to_utc(src); } },
    { "double-timet/mt:utc", [](cJSON * src) -> MdsValue* { return MdsValue::double_time_t_to_utc(src); } },
    { "str-rfc3339/mt:utc", [](cJSON * src) -> MdsValue* { return MdsValue::rfc3339_to_utc(src); } }
};

// vim: se sw=8 :

================================================ FILE: Diagnostic/mdsd/mdsd/Engine.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _ENGINE_HH_
#define _ENGINE_HH_

// NOTE(review): system header names below were lost in extraction; preserved as-is.
#include
#include
#include
#include
#include
#include
#include "MdsValue.hh"
#include "EventJSON.hh"
#include "MdsSchemaMetadata.hh"
#include "MdsEntityName.hh"

class MdsdConfig;
class Credentials;

class Engine
{
public:
    ~Engine();

    /// Get the singleton Engine instance. Not thread-safe for creation.
    static Engine* GetEngine();

    /// Cause incoming events to be blackholed instead of being sent to MDS
    void BlackholeEvents() { blackholeEvents = true; }

    /// Process an event
    /// The event to be processed
    void ProcessEvent(EventJSON& event);

    /// Transfer a configuration into the active engine. The previous configuration will remain undeleted
    /// for a time; when the engine believes it's no longer in use, the engine will delete it.
    /// The new configuration object.
    static void SetConfiguration(MdsdConfig* newconfig);

    /// Fetch type converter. Returns false if sourcetype can't be converted to targettype
    /// Name of the original (JSON) type (e.g. "str", "int-timet")
    /// Name of the destination (MDS) type (e.g. "mt_bool")
    /// The type converter function, if one was found
    bool GetConverter(const std::string& sourcetype, const std::string& targettype, typeconverter_t& converter);

    /// Fetch type converter. Return false if inOutType cannot be found.
    /// Name pairs in the format of "jsonType/mdsType". (e.g. "bool/mt:bool")
    /// The type converter function, if one was found
    bool GetConverter(const std::string & inOutType, typeconverter_t& converter);

    /// Get a list of all configured type converters, suitable for display in error messages.
    static std::string ListConverters();

    MdsdConfig* GetConfig() { return current_config; }

    /// Determines if the schema has been pushed for this account and tablename. Calling this
    /// method updates the cache of which schemas have been pushed.
    /// True if this is the first time NeedsPush has been called with these args.
    //bool NeedsPush(Credentials* creds, const MdsEntityName& target, const MdsSchemaMetadata*);

#ifdef DOING_MEMCHECK
    void ClearPushedCache() { std::unique_lock lock(_schemaCacheMutex);_pushedEvents.clear(); }
    void ClearConfiguration();
#endif

private:
    Engine();

    static Engine* engine;

    bool blackholeEvents;
    time_t _startTime;
    MdsdConfig* current_config;
    static std::map convertermap;
    std::set > _pushedEvents;
    std::mutex _schemaCacheMutex;
};

#endif //_ENGINE_HH_

// vim: se sw=8 :

================================================ FILE: Diagnostic/mdsd/mdsd/EtwEvent.cc ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "EtwEvent.hh"
#include "Logger.hh"
#include "Trace.hh"
#include "CanonicalEntity.hh"
#include "LocalSink.hh"
#include "MdsValue.hh"
#include "Engine.hh"
// NOTE(review): header name lost in extraction; preserved as-is.
#include

std::unordered_map EtwEvent::m_schemaIdMap;

// Parse the wrapped ETW JSON event into a CanonicalEntity and store it in
// the given sink. Any validation/conversion failure is logged and the whole
// event is dropped.
void EtwEvent::Process(LocalSink *sink)
{
    Trace trace(Trace::EventIngest, "EtwEvent::Process");

    if (!sink) {
        Logger::LogError("Error: unexpected NULL pointer for LocalSink.");
        return;
    }
    if (!m_event.IsEtwEvent()) {
        Logger::LogError("Error: the input event is not an ETW event. Do nothing.");
        return;
    }

    std::string guidstr = ParseGuid();
    if (guidstr.empty()) {
        return;
    }
    int eventId = ParseEventId();
    if (eventId < 0) {
        return;
    }

    unsigned int ncolumns = m_event.data_count() + 2;
    CanonicalEntity ce(ncolumns);
    ce.SetPreciseTime(m_event.GetTimestamp());

    auto schemaId = GetSchemaId(guidstr, eventId);
    ce.SetSchemaId(schemaId);

    bool hasError = false;
    auto datum = m_event.data_begin();
    while (datum != m_event.data_end()) {
        std::string name;
        auto mdsValue = ConvertData(&(*datum), name);
        if (!mdsValue) {
            hasError = true;
            break;
        }
        ce.AddColumn(name, mdsValue);
        ++datum;
    }

    if (!hasError) {
        sink->AddRow(ce, 0);
    }
}

// input cJSON is an array with 3 elements ["Name", "Value", "srctype/mdstype"]
MdsValue* EtwEvent::ConvertData(cJSON* tuple, std::string & name)
{
    if (!ValidateJSON(tuple, cJSON_Array)) {
        return nullptr;
    }

    const int ETW_TUPLE_SIZE = 3;
    int arraySize = cJSON_GetArraySize(tuple);
    if (ETW_TUPLE_SIZE != arraySize) {
        std::ostringstream ss;
        ss << "Error: invalid data format: expected ETW tuple size=" << ETW_TUPLE_SIZE << "; actual size=" << arraySize;
        Logger::LogError(ss.str());
        return nullptr;
    }

    // Element 1: column name.
    cJSON* head = tuple->child;
    if (!head || !GetJSONString(head, name)) {
        return nullptr;
    }

    // Element 2: raw value (any JSON type; the converter validates it).
    head = head->next;
    cJSON * jvalue = head;
    if (!jvalue) {
        return nullptr;
    }

    // Element 3: "jsonType/mdsType" converter key.
    head = head->next;
    std::string inOutType;
    if (!head || !GetJSONString(head, inOutType)) {
        return nullptr;
    }

    typeconverter_t converter;
    if (! Engine::GetEngine()->GetConverter(inOutType, converter)) {
        std::ostringstream ss;
        ss << "Error: failed to get type converter '" << inOutType << "'. Supported converters: " << Engine::ListConverters();
        Logger::LogError(ss.str());
        return nullptr;
    }
    return converter(jvalue);
}

// Extract the GUID field; returns an empty string (after logging) when absent.
std::string EtwEvent::ParseGuid()
{
    std::string guidstr;
    if (!m_event.GetGuid(guidstr)) {
        std::ostringstream ss;
        ss << "Error: invalid event format: no expected '" << s_GUIDName << "' found. Do nothing.";
        Logger::LogError(ss.str());
        return std::string();
    }
    return guidstr;
}

// Extract the EVENTID field; returns -1 (after logging) when absent.
int EtwEvent::ParseEventId()
{
    int eventId = -1;
    if (!m_event.GetEventId(eventId)) {
        std::ostringstream ss;
        ss << "Error: invalid event format: no expected '" << s_EventIdName << "' found. Do nothing.";
        Logger::LogError(ss.str());
        return -1;
    }
    return eventId;
}

bool EtwEvent::GetJSONString(cJSON* obj, std::string& value)
{
    if (!ValidateJSON(obj, cJSON_String)) {
        return false;
    }
    value.assign(obj->valuestring);
    return true;
}

bool EtwEvent::ValidateJSON(cJSON* obj, int expectedType)
{
    if (!obj) {
        Logger::LogError("Error: unexpected NULL pointer for cJSON object.");
        return false;
    }
    if (expectedType != obj->type) {
        std::ostringstream ss;
        ss << "Error: cJSON type: expected=" << expectedType << "; actual=" << obj->type << ".";
        Logger::LogError(ss.str());
        return false;
    }
    return true;
}

// Return the cached schema id for this guid/eventid pair, allocating a new
// one on first sight.
SchemaCache::IdType EtwEvent::GetSchemaId(const std::string & guidstr, int eventid)
{
    auto key = guidstr + std::to_string(eventid);
    const auto & iter = m_schemaIdMap.find(key);
    if (iter == m_schemaIdMap.end()) {
        auto id = SchemaCache::Get().GetId();
        m_schemaIdMap[key] = id;
        return id;
    } else {
        return iter->second;
    }
}

================================================ FILE: Diagnostic/mdsd/mdsd/EtwEvent.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef __ETWEVENT_HH__
#define __ETWEVENT_HH__

// NOTE(review): header names lost in extraction; preserved as-is.
#include
#include
#include "EventJSON.hh"
#include "SchemaCache.hh"

class LocalSink;
class MdsValue;

/// This class implements functions to parse ETW JSON events. Each JSON message will
/// follow a format like below (placeholder values were stripped from this doc
/// during extraction):
/// {"TAG":"",
/// "SOURCE":"ETW",
/// "EVENTID" : ,
/// "GUID" : "", // NOTE: there is no {} around
/// "DATA":[["name1","val1", "jsonType/mdsType"],["name2", "val2", "jsonType/mdsType"]]}
class EtwEvent
{
public:
    EtwEvent(EventJSON& event) : m_event(event) {}
    ~EtwEvent() {}

    /// Process current event. Create a new CanonicalEntity object with the event
    /// data. Then save the CanonicalEntity into the given sink.
    /// If there is any error with the event data, nothing will be saved to sink.
    /// Sink to save CanonicalEntity
    void Process(LocalSink* sink);

    static const char* ETWName() { return s_ETWName; }
    static const char* GUIDName() { return s_GUIDName; }
    static const char* EventIDName() { return s_EventIdName; }

    /// Build and return a local table name given ETW GUID and EventID.
    static std::string BuildLocalTableName(const std::string & guid, int eventId)
    {
        return (std::string(s_ETWName) + "_" + guid + "_" + std::to_string(eventId));
    }

private:
    std::string ParseGuid();
    int ParseEventId();
    MdsValue* ConvertData(cJSON* item, std::string & name);
    bool GetJSONString(cJSON* obj, std::string& value);
    bool ValidateJSON(cJSON* obj, int expectedType);
    static SchemaCache::IdType GetSchemaId(const std::string & guidstr, int eventid);

private:
    EventJSON& m_event;

    // Each ETW guid/eventid should correspond to a specific schema
    static std::unordered_map m_schemaIdMap;

    constexpr static const char* s_ETWName = "ETW";
    constexpr static const char* s_GUIDName = "GUID";
    constexpr static const char* s_EventIdName = "EVENTID";
};

#endif // __ETWEVENT_HH__

================================================ FILE: Diagnostic/mdsd/mdsd/EventJSON.cc ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "EventJSON.hh"
// NOTE(review): header names lost in extraction; preserved as-is.
#include
extern "C" {
#include
}
#include "EtwEvent.hh"
#include "Logger.hh"

using std::string;

// Copy the SOURCE field into value. For ETW events ("SOURCE":"ETW") the
// source is rewritten to the synthesized per-guid/eventid local table name
// and _isEtwEvent is latched.
bool EventJSON::GetSource(string& value)
{
    cJSON* item = cJSON_GetObjectItem(_event, "SOURCE");
    if (!ValidateJSON("SOURCE", item, cJSON_String)) {
        return false;
    } else {
        value.assign(item->valuestring);
        if (value == EtwEvent::ETWName()) {
            if (!GetEtwEventSource(value)) {
                return false;
            }
            _isEtwEvent = true;
        }
        return true;
    }
}

// Same as the bool overload, but returns the source by value; an empty
// string indicates a missing/invalid SOURCE field.
string EventJSON::GetSource()
{
    cJSON* item = cJSON_GetObjectItem(_event, "SOURCE");
    if (!ValidateJSON("SOURCE", item, cJSON_String)) {
        return std::string("");
    } else {
        std::string source = std::string(item->valuestring);
        if (source == EtwEvent::ETWName()) {
            if (!GetEtwEventSource(source)) {
                return source;
            }
            _isEtwEvent = true;
        }
        return source;
    }
}

// Copy the GUID field into value; false when missing or not a string.
bool EventJSON::GetGuid(std::string& value)
{
    cJSON* guid = cJSON_GetObjectItem(_event, EtwEvent::GUIDName());
    if (!ValidateJSON(EtwEvent::GUIDName(), guid, cJSON_String)) {
        return false;
    }
    value.assign(guid->valuestring);
    return true;
}

// Copy the EVENTID field into eventId; false when missing or not a number.
bool EventJSON::GetEventId(int & eventId)
{
    cJSON* obj = cJSON_GetObjectItem(_event, EtwEvent::EventIDName());
    if (!ValidateJSON(EtwEvent::EventIDName(), obj, cJSON_Number)) {
        return false;
    }
    eventId = obj->valueint;
    return true;
}

// Derive the local table name "ETW_<guid>_<eventid>" for an ETW event.
bool EventJSON::GetEtwEventSource(std::string& value)
{
    std::string guidstr;
    int eventId = -1;
    if (!GetGuid(guidstr) || !GetEventId(eventId)) {
        return false;
    }
    value = EtwEvent::BuildLocalTableName(guidstr, eventId);
    return true;
}

// Copy the TAG field into value; false when missing or not a string.
bool EventJSON::GetTag(string& value)
{
    cJSON* item = cJSON_GetObjectItem(_event, "TAG");
    if (!ValidateJSON("TAG", item, cJSON_String)) {
        return false;
    } else {
        value.assign(item->valuestring);
        return true;
    }
}

// Log-and-fail helper: obj must be non-null and of expectedType.
bool EventJSON::ValidateJSON(const char* name, cJSON* obj, int expectedType)
{
    if (!obj) {
        Logger::LogError("Error: unexpected NULL pointer for cJSON object.");
        return false;
    }
    if (expectedType != obj->type) {
        std::ostringstream ss;
        ss << "Error: ";
        if (name) {
            ss << "'" << name << "' ";
        }
        ss << "JSON type: expected=" << expectedType << "; actual=" << obj->type << ".";
        Logger::LogError(ss.str());
        return false;
    }
    return true;
}

// Iterator over the first element of the DATA array (end() when absent/empty).
EventJSON::DataIterator EventJSON::data_begin()
{
    cJSON* array = cJSON_GetObjectItem(_event, "DATA");
    if (!array || !(array->child)) {
        return EventJSON::DataIterator((cJSON*)0);
    } else {
        return EventJSON::DataIterator(array->child);
    }
}

unsigned int EventJSON::data_count()
{
    cJSON* array = cJSON_GetObjectItem(_event, "DATA");
    if (array) {
        return cJSON_GetArraySize(array);
    } else {
        return 0;
    }
}

std::ostream& operator<<(std::ostream& os, const EventJSON& ev)
{
    char *buf = cJSON_Print(ev._event);
    // BUGFIX: cJSON_Print() returns NULL on allocation failure; the previous
    // code streamed and free()d the pointer unconditionally (UB on NULL <<).
    if (buf) {
        os << (const char*)buf;
        free(buf);
    }
    return os;
}

// vim: se sw=8:

================================================ FILE: Diagnostic/mdsd/mdsd/EventJSON.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _EVENTCJSON_HH_
#define _EVENTCJSON_HH_

// NOTE(review): header names lost in extraction; preserved as-is.
#include
#include
#include "Logger.hh"
#include "MdsTime.hh"
extern "C" {
#include "cJSON.h"
}
#include

// Thin wrapper over a parsed cJSON event; does NOT own the cJSON tree.
class EventJSON
{
public:
    EventJSON(cJSON* event) : _event(event), _isEtwEvent(false) {}

    // Pretty-print the event to the info log (best effort).
    void PrintEvent()
    {
        char *buf = cJSON_Print(_event);
        // BUGFIX: guard against cJSON_Print() returning NULL on allocation failure.
        if (buf) {
            Logger::LogInfo(buf);
            free(buf);
        }
    }

    bool GetSource(std::string& source);
    std::string GetSource();
    bool GetTag(std::string& tag);
    const MdsTime& GetTimestamp() const { return _timestamp; }
    bool GetGuid(std::string& guid);
    bool GetEventId(int & eventId);
    bool IsEtwEvent() const { return _isEtwEvent; }

    // Forward iterator over the children of the DATA array.
    class DataIterator : public std::iterator
    {
    private:
        cJSON* _current;
    public:
        DataIterator(cJSON* item) : _current(item) {}
        DataIterator(const DataIterator& other) : _current(other._current) {}
        DataIterator& operator++() { _current = _current->next; return *this; }
        DataIterator operator++(int) { DataIterator tmp(*this); operator++(); return tmp; }
        bool operator==(const DataIterator& other) { return _current == other._current; }
        bool operator!=(const DataIterator& other) { return _current != other._current; }
        cJSON& operator*() { return *_current; }
        cJSON* operator->() { return _current; }
    };

    DataIterator data_begin();
    DataIterator data_end() { return DataIterator((cJSON*)0); }
    unsigned int data_count();

    friend std::ostream& operator<<(std::ostream& os, const EventJSON& ev);

private:
    EventJSON();
    bool GetEtwEventSource(std::string& value);
    bool ValidateJSON(const char* name, cJSON* obj, int expectedType);

    cJSON* _event;
    MdsTime _timestamp;
    bool _isEtwEvent;
};

#endif //_EVENTCJSON_HH_

// vim: se sw=8:

================================================ FILE: Diagnostic/mdsd/mdsd/ExtensionMgmt.cc ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "ExtensionMgmt.hh"
#include "Logger.hh"
#include "Utility.hh"
#include "Trace.hh"
#include "MdsdExtension.hh"
#include "MdsdConfig.hh"
#include "CmdLineConverter.hh"
// NOTE(review): header names lost in extraction; preserved as-is.
#include
#include
#include
extern "C" {
#include
#include
#include
}

std::map ExtensionList::_extlistByName;
std::map ExtensionList::_extlistByPid;
std::unordered_set ExtensionList::_killList;
std::mutex ExtensionList::_listmutex;
std::mutex ExtensionList::_klmutex;

// Two ExtensionMetaData are equal iff all four identifying fields match.
bool ExtensionMetaData::operator==(const ExtensionMetaData & other) const
{
    return Name == other.Name && CommandLine == other.CommandLine
        && Body == other.Body && AlterLocation == other.AlterLocation;
}

ExtensionInfo::ExtensionInfo() : StopTimer(nullptr), StopTimerCancelled(false)
{
}

ExtensionInfo::~ExtensionInfo()
{
    if (StopTimer) {
        // Mark cancelled before deleting so a firing timer can notice.
        StopTimerCancelled = true;
        delete StopTimer;
        StopTimer = nullptr;
    }
}

std::string ExtensionInfo::GetStatus() const
{
    return ExtensionInfo::StatusToString(Status);
}

std::map ExtensionInfo::_statusMap = {
    { ExtStatus::NORMAL, "NORMAL" },
    { ExtStatus::BAD, "BAD" },
    { ExtStatus::KILLING, "KILLING" },
    { ExtStatus::EXIT, "EXIT" }
};

std::string
ExtensionInfo::StatusToString(ExtStatus s) { const auto &iter = _statusMap.find(s); if (_statusMap.end() == iter) { return "UNKNOWN"; } return iter->second; } size_t ExtensionList::GetSize() { std::unique_lock lock(_listmutex); return _extlistByName.size(); } bool ExtensionList::AddItem(ExtensionInfo * extObj) { Trace trace(Trace::Extensions, "ExtensionList::AddItem"); if (!extObj) { Logger::LogError("Error: unexpected NULL value for ExtensionInfo object."); return false; } const std::string & extname = extObj->MetaData.Name; if (MdsdUtil::IsEmptyOrWhiteSpace(extname)) { Logger::LogError("Error: unexpected empty or whitespace value for ExtensionName"); return false; } std::unique_lock lock(_listmutex); // search for the item. If found, delete the old one. const auto & iter = _extlistByName.find(extname); if (iter != _extlistByName.end()) { ExtensionInfo *oldExtObj = iter->second; delete oldExtObj; oldExtObj = nullptr; } _extlistByName[extname] = extObj; assert(0 != extObj->Pid); _extlistByPid[extObj->Pid] = extObj; trace.NOTE("Successfully added ExtensionInfo object with Name='" + extname + "'"); return true; } void ExtensionList::AddPid(pid_t pid) { Trace trace(Trace::Extensions, "ExtensionList::AddPid"); std::unique_lock lock(_klmutex); if (0 < _killList.count(pid)) { Logger::LogError("Error: duplicate pid found: " + std::to_string(pid)); } else { _killList.insert(pid); } } std::unordered_set ExtensionList::GetAndClearPids() { Trace trace(Trace::Extensions, "ExtensionList::GetAndClearPids"); std::unique_lock lock(_klmutex); std::unordered_set r = _killList; _killList.clear(); return r; } ExtensionInfo * ExtensionList::GetItem(const std::string & extname) { Trace trace(Trace::Extensions, "ExtensionList::GetItem(extname)"); if (MdsdUtil::IsEmptyOrWhiteSpace(extname)) { Logger::LogError("Error: unexpected empty or whitespace value for ExtensionName"); return nullptr; } ExtensionInfo * obj = nullptr; std::unique_lock lock(_listmutex); const auto & iter = 
_extlistByName.find(extname); if (iter != _extlistByName.end()) { obj = iter->second; trace.NOTE("Got ExtensionInfo object: '" + extname + "'"); } else { trace.NOTE("ExtensionInfo is not found: '" + extname + "'."); } return obj; } ExtensionInfo * ExtensionList::GetItem(pid_t extPid) { Trace trace(Trace::Extensions, "ExtensionList::GetItem(pid_t)"); if (0 >= extPid) { Logger::LogError("Error: unexpected value for pid: " + std::to_string(extPid)); return nullptr; } ExtensionInfo * obj = nullptr; std::unique_lock lock(_listmutex); const auto & iter = _extlistByPid.find(extPid); if (iter != _extlistByPid.end()) { obj = iter->second; trace.NOTE("Got ExtensionInfo with pid=" + std::to_string(extPid)); } else { trace.NOTE("ExtensionInfo is not found with pid=" + std::to_string(extPid)); } return obj; } bool ExtensionList::UpdateItem(pid_t oldpid, pid_t newpid) { Trace trace(Trace::Extensions, "ExtensionList::UpdateItem"); assert(0 < oldpid); assert(0 < newpid); bool resultOK = true; std::unique_lock lock(_listmutex); const auto & iter = _extlistByPid.find(oldpid); if (iter != _extlistByPid.end()) { ExtensionInfo *obj = iter->second; _extlistByPid.erase(iter); _extlistByPid[newpid] = obj; trace.NOTE("Extension is updated: from pid " + std::to_string(oldpid) + " to pid " + std::to_string(newpid)); } else { resultOK = false; Logger::LogError("Extension is not found with pid=" + std::to_string(oldpid)); } return resultOK; } bool ExtensionList::DeleteItem(const std::string & extname) { Trace trace(Trace::Extensions, "ExtensionList::DeleteItem"); if (MdsdUtil::IsEmptyOrWhiteSpace(extname)) { Logger::LogError("Error: unexpected empty or whitespace for ExtensionName"); return false; } bool resultOK = true; std::unique_lock lock(_listmutex); const auto & iter = _extlistByName.find(extname); if (iter != _extlistByName.end()) { ExtensionInfo * obj = iter->second; _extlistByName.erase(iter); _extlistByPid.erase(obj->Pid); lock.unlock(); trace.NOTE("Deleted item: '" + extname + "'"); 
delete obj; obj = nullptr; resultOK = true; } else { trace.NOTE("Extension is not found: '" + extname + "'"); resultOK = false; } return resultOK; } bool ExtensionList::DeleteItems(const std::set& extnames) { Trace trace(Trace::Extensions, "ExtensionList::DeleteItems"); if (0 == extnames.size()) { return true; } bool resultOK = true; std::unique_lock lock(_listmutex); for(const auto & extname : extnames) { const auto & iter = _extlistByName.find(extname); if (iter != _extlistByName.end()) { ExtensionInfo * obj = iter->second; _extlistByName.erase(iter); _extlistByPid.erase(obj->Pid); trace.NOTE("Deleted item: '" + extname + std::string("'")); delete obj; obj = nullptr; } else { trace.NOTE("Extension is not found: '" + extname + std::string("'")); resultOK = false; } } return resultOK; } void ExtensionList::DeleteAllItems() { Trace trace(Trace::Extensions, "ExtensionList::DeleteAllItems"); std::unique_lock lock(_listmutex); for (auto x : _extlistByPid) { delete x.second; } _extlistByPid.clear(); _extlistByName.clear(); } void ExtensionList::ForeachExtension(const std::function& fn) { Trace trace(Trace::Extensions, "ExtensionList::ForeachExtension"); std::unique_lock lock(_listmutex); for (const auto & kv : _extlistByName) { trace.NOTE(std::string("Walking ExtensionInfo with name='") + kv.first + "'"); fn(kv.second); } } ExtensionMgmt * ExtensionMgmt::_extInstance = nullptr; ExtensionMgmt* ExtensionMgmt::GetInstance() { if (!_extInstance) { _extInstance = new ExtensionMgmt(); if(!_extInstance->InitSem()) { delete _extInstance; _extInstance = nullptr; } } return _extInstance; } ExtensionMgmt::ExtensionMgmt() : _extsemInitOK(false) { } ExtensionMgmt::~ExtensionMgmt() { Trace trace(Trace::Extensions, "ExtensionMgmt::~ExtensionMgmt"); if (_extsemInitOK) { if (-1 == sem_destroy(&_extsem)) { std::string errstr = MdsdUtil::GetErrnoStr(errno); Logger::LogError("Error: sem_destroy() failed: " + errstr); } } } bool ExtensionMgmt::InitSem() { Trace trace(Trace::Extensions, 
"ExtensionMgmt::InitSem"); if (-1 == sem_init(&_extsem, 0, 0)) { std::string errstr = MdsdUtil::GetErrnoStr(errno); Logger::LogError("Error: sem_init() failed: " + errstr); _extsemInitOK = false; return false; } _extsemInitOK = true; return true; } bool ExtensionMgmt::StartExtensions(MdsdConfig * config) { Trace trace(Trace::Extensions, "ExtensionMgmt::StartExtensions"); if (!config) { trace.NOTE("MdsdConfig* is NULL. Do nothing."); return true; } bool resultOK = false; try { ExtensionMgmt* extmgmt = GetInstance(); if (extmgmt) { std::set extlistInConfig; resultOK = extmgmt->StartExtensionsFromConfig(config, extlistInConfig); resultOK = resultOK && extmgmt->StopObsoleteExtensions(extlistInConfig); } } catch(const std::exception & ex) { Logger::LogError(std::string("Error: StartExtensions failed: ") + ex.what()); resultOK = false; } return resultOK; } void ExtensionMgmt::StartExtensionsAsync(MdsdConfig * config) { if (!config) { return; } // If there is no old and new extension, do nothing if (0 == config->GetNumExtensions() && 0 == ExtensionList::GetSize()) { return; } static std::future lastTask; static std::mutex mtx; try { // multiple threads may call this function when automatic configuration mgr // and main thread starts up std::lock_guard lock(mtx); if (lastTask.valid()) { if (!lastTask.get()) { Logger::LogError("Previous StartExtensions() failed."); } } lastTask = std::async(std::launch::async, StartExtensions, config); } catch(const std::system_error& ex) { Logger::LogError(std::string("Error: std::async failed calling 'StartExtensions': ") + ex.what()); } } bool ExtensionMgmt::StartExtensionsFromConfig( MdsdConfig * config, std::set& extlistInConfig) { Trace trace(Trace::Extensions, "ExtensionMgmt::StartExtensionsFromConfig"); bool resultOK = true; std::vector changedList; // key is old extension's pid. 
std::map newDataList; std::function Visitor = [this,&trace,&extlistInConfig,&resultOK,&changedList,&newDataList](MdsdExtension * extObj) { const std::string & extname = extObj->Name(); const std::string & cmdline = extObj->GetCmdLine(); const std::string & body = extObj->GetBody(); const std::string & alterLocation = extObj->GetAlterLocation(); assert(false == MdsdUtil::IsEmptyOrWhiteSpace(extname)); assert(false == MdsdUtil::IsEmptyOrWhiteSpace(cmdline)); extlistInConfig.insert(extname); // check with ExtensionList ExtensionInfo* oldExtInfo = ExtensionList::GetItem(extname); if (!oldExtInfo) { resultOK = resultOK && StartExtension(extname, cmdline, body, alterLocation); } else { ExtensionMetaData newMetaData(extname, cmdline, body, alterLocation); bool sameMetaData = (oldExtInfo->MetaData == newMetaData); if (!sameMetaData) { trace.NOTE("Found new metadata for " + extname); changedList.push_back(oldExtInfo); newDataList[oldExtInfo->Pid] = newMetaData; } else { trace.NOTE("No metadata were changed for " + extname); } } }; config->ForeachExtension(Visitor); if (0 < changedList.size()) { resultOK = resultOK && RestartChangedExtensions(changedList, newDataList); } trace.NOTE("Finished with success = " + MdsdUtil::ToString(resultOK)); return resultOK; } // terminate current Extension processes. each process will send SIGCHLD, which // will be handled in signal handler. The extension will be deleted in the signal handler. 
bool ExtensionMgmt::RestartChangedExtensions( const std::vector & changedList, const std::map & newDataList) { Trace trace(Trace::Extensions, "ExtensionMgmt::RestartChangedExtensions"); bool resultOK = true; if (0 == changedList.size()) { return resultOK; } assert(changedList.size() == newDataList.size()); for (const auto & ext : changedList) { StopExtension(ext); } trace.NOTE("Wait for all changed extensions to be stopped ..."); for (size_t i = 0; i < newDataList.size(); i++) { bool extStopOK = WaitForAnyExtStop(); if (extStopOK) { std::unordered_set changedPids = ExtensionList::GetAndClearPids(); for (const auto & pid : changedPids) { trace.NOTE("GetAndClearPids(): pid=" + std::to_string(pid)); } resultOK = StartAllChangedExts(changedPids, newDataList); } } trace.NOTE("Finished with success = " + MdsdUtil::ToString(resultOK)); return resultOK; } bool ExtensionMgmt::WaitForAnyExtStop() { Trace trace(Trace::Extensions, "ExtensionMgmt::WaitForAnyExtStop"); bool resultOK = true; struct timespec ts; if (-1 == clock_gettime(CLOCK_REALTIME, &ts)) { resultOK = false; } else { ts.tv_sec += EXT_TERMINATE_GRACE_SECONDS + 1; int waitstatus = 0; time_t semStartTime = time(0); while((waitstatus = sem_timedwait(&_extsem, &ts)) == -1 && EINTR == errno) { semStartTime = time(0); continue; } int waiterrno = errno; if (-1 == waitstatus) { if (ETIMEDOUT == waiterrno) { long waitTime = (long)(time(0) - semStartTime); Logger::LogError("Error: sem_timedwait() timed out after " + std::to_string(waitTime) + " seconds."); } else { std::string errstr = MdsdUtil::GetErrnoStr(waiterrno); Logger::LogError("Error: sem_timedwait() failed. 
Error string: " + errstr); } resultOK = false; } else { trace.NOTE("sem_timedwait() succeeded."); } } trace.NOTE("Finished with success = " + MdsdUtil::ToString(resultOK)); return resultOK; } bool ExtensionMgmt::StartAllChangedExts( const std::unordered_set changedPids, const std::map & newDataList) const { Trace trace(Trace::Extensions, "ExtensionMgmt::StartAllChangedExts"); bool resultOK = true; for (const auto & pid : changedPids) { const auto & iter = newDataList.find(pid); if (newDataList.end() == iter) { Logger::LogError("Error: old extension pid is not found: " + std::to_string(pid)); resultOK = false; } else { ExtensionMetaData metadata = iter->second; assert(metadata.Name.empty() == false); resultOK = resultOK && StartOneChangedExt(pid, metadata); } } trace.NOTE("Finished with success = " + MdsdUtil::ToString(resultOK)); return resultOK; } bool ExtensionMgmt::StartOneChangedExt(pid_t changedPid, const ExtensionMetaData & metadata) const { Trace trace(Trace::Extensions, "ExtensionMgmt::StartOneChangedExt"); bool resultOK = true; // only start new one when old one was terminated. if (-1 == waitpid(changedPid, NULL, WNOHANG) && ECHILD == errno) { trace.NOTE(metadata.Name + " with pid " + std::to_string(changedPid) + " was terminated. 
Start new one."); resultOK = resultOK && StartExtension(metadata); } else { Logger::LogError("Error: " + metadata.Name + " with pid " + std::to_string(changedPid) + " was not terminated properly."); resultOK = false; } trace.NOTE("Finished with success = " + MdsdUtil::ToString(resultOK)); return resultOK; } bool ExtensionMgmt::StopObsoleteExtensions(const std::set & extlistInConfig) const { Trace trace(Trace::Extensions, "ExtensionMgmt::StopObsoleteExtensions"); if (0 == ExtensionList::GetSize()) { return true; } std::set obsoleteExtNames; std::unordered_set obsoleteExtObjs; std::function Visitor = [&extlistInConfig,&obsoleteExtNames,&obsoleteExtObjs](ExtensionInfo * extObj) { assert(nullptr != extObj); if (extlistInConfig.find(extObj->MetaData.Name) == extlistInConfig.end()) { obsoleteExtNames.insert(extObj->MetaData.Name); obsoleteExtObjs.insert(extObj); } }; ExtensionList::ForeachExtension(Visitor); // The extensions must be stopped first before being deleted bool resultOK = true; for (const auto & extObj : obsoleteExtObjs) { resultOK = resultOK && StopExtension(extObj); } resultOK = resultOK && ExtensionList::DeleteItems(obsoleteExtNames); trace.NOTE("Finished with success = " + MdsdUtil::ToString(resultOK)); return resultOK; } bool ExtensionMgmt::StopAllExtensions() { Trace trace(Trace::Extensions, "ExtensionMgmt::StopAllExtensions"); size_t nitems = ExtensionList::GetSize(); if (0 == nitems) { return true; } bool resultOK = true; unsigned int nexists = 0; std::function StopExtFunc = [this,&nexists,&resultOK](ExtensionInfo * extObj) { assert(nullptr != extObj); if (-1 != waitpid(extObj->Pid, NULL, WNOHANG)) { nexists++; } resultOK = resultOK && StopExtension(extObj); }; ExtensionList::ForeachExtension(StopExtFunc); trace.NOTE("Found " + std::to_string(nexists) + " running extensions. 
Wait for them to finish."); for (size_t i = 0; i < nexists; i++) { resultOK = resultOK & WaitForAnyExtStop(); } ExtensionList::DeleteAllItems(); trace.NOTE("Finished with success = " + MdsdUtil::ToString(resultOK)); return resultOK; } bool ExtensionMgmt::MaskSignal(bool isBlock, int signum) const { Trace trace(Trace::Extensions, "ExtensionMgmt::MaskSignal"); sigset_t ss; std::string errmsg = ""; int errnum = 0; if (-1 == sigemptyset(&ss)) { errnum = errno; errmsg = "Error: sigemptyset() failed."; } else { if (-1 == sigaddset(&ss, signum)) { errnum = errno; errmsg = "Error: sigaddset() failed on signal: " + std::to_string(signum); } else { int how = isBlock? SIG_BLOCK : SIG_UNBLOCK; if (-1 == sigprocmask(how, &ss, NULL)) { errnum = errno; errmsg = "Error: sigprocmask() failed."; } } } bool resultOK = true; if (errmsg != "") { errmsg += " Error string: " + MdsdUtil::GetErrnoStr(errnum); Logger::LogError(errmsg); resultOK = false; } return resultOK; } bool ExtensionMgmt::StartExtension(const ExtensionMetaData & metaData) const { return StartExtension(metaData.Name, metaData.CommandLine, metaData.Body, metaData.AlterLocation); } bool ExtensionMgmt::StartExtension( const std::string & extName, const std::string & cmdline, const std::string & body, const std::string & alterLocation ) const { Trace trace(Trace::Extensions, "ExtensionMgmt::StartExtension"); bool resultOK = true; ExtensionInfo * oldExtInfo = ExtensionList::GetItem(extName); if (oldExtInfo) { sleep(EXT_RETRY_WAIT_SECONDS); } if (!MaskSignal(true, SIGCHLD)) { return false; } CmdLineConverter cconverter(cmdline); char** cargv = cconverter.argv(); // use pipe to send child error to parent int pipefds[2]; if (-1 == pipe(pipefds)) { Logger::LogError("Error: pipe() failed: Error string: " + MdsdUtil::GetErrnoStr(errno)); return false; } // Use FD_CLOEXEC so that if exec() succeeds, fd will be closed automatically. 
if (fcntl(pipefds[1], F_SETFD, fcntl(pipefds[1], F_GETFD) | FD_CLOEXEC)) { Logger::LogError("Error: fcntl() failed: Error string: " + MdsdUtil::GetErrnoStr(errno)); return false; } pid_t pid = fork(); int forkerr = errno; if (-1 == pid) { Logger::LogError("Error: fork() failed: Error string: '" + MdsdUtil::GetErrnoStr(forkerr) + "'."); return false; } if (0 == pid) { // child process close(pipefds[0]); int childerr = 0; if (!MdsdUtil::IsEmptyOrWhiteSpace(body)) { if (-1 == setenv(BODYENV, body.c_str(), 1)) { childerr = errno; } } if (0 == childerr) { childerr = UnblockSignals(); if (0 == childerr) { std::string fullpath = alterLocation + "/" + cargv[0]; execvp(fullpath.c_str(), cargv); // child has error if it reaches here childerr = errno; } } // send error code to parent if (write(pipefds[1], &childerr, sizeof(int)) < 0) { Logger::LogError("Error: write() failed: Error string: '" + MdsdUtil::GetErrnoStr(errno) + "'."); } _exit(0); } // parent process close(pipefds[1]); // read child error if any. int readcount = 0; int childerr = 0; while (-1 == (readcount = read(pipefds[0], &childerr, sizeof(int)))) { if (EAGAIN != errno && EINTR != errno) { break; } } bool childFailed = false; if (readcount && childerr > 0) { Logger::LogError("Error: create " + extName + " process failed. pid=" + std::to_string(pid) + ". 
Error: " + MdsdUtil::GetErrnoStr(childerr)); childFailed = true; } else { trace.NOTE("Created process " + extName + ": cmdline=" + cmdline + "; pid=" + std::to_string(pid)); } resultOK = resultOK && UpdateExtensionList(oldExtInfo, extName, cmdline, body, alterLocation, pid, childFailed); resultOK = resultOK && MaskSignal(false, SIGCHLD); return resultOK; } bool ExtensionMgmt::UpdateExtensionList( ExtensionInfo * oldExtInfo, const std::string & extName, const std::string & cmdline, const std::string & body, const std::string & alterLocation, pid_t pid, bool extFailed) const { Trace trace(Trace::Extensions, "ExtensionMgmt::UpdateExtensionList"); bool resultOK = true; if (!oldExtInfo) { trace.NOTE("Get a new extension definition. Add it to cache."); ExtensionInfo * extInfo = new ExtensionInfo(); extInfo->MetaData.Name = extName; extInfo->MetaData.CommandLine = cmdline; extInfo->MetaData.Body = body; extInfo->MetaData.AlterLocation = alterLocation; extInfo->Pid = pid; extInfo->StartTime = time(NULL); extInfo->Status = ExtensionInfo::NORMAL; extInfo->RetryCount = extFailed? (EXT_MAX_RETRIES+1) : 0; resultOK = ExtensionList::AddItem(extInfo); if (!resultOK) { delete extInfo; extInfo = nullptr; } } else { pid_t oldpid = oldExtInfo->Pid; trace.NOTE("Get existing extension. Update its pid from " + std::to_string(oldpid) + " to " + std::to_string(pid)); oldExtInfo->Pid = pid; oldExtInfo->StartTime = time(NULL); oldExtInfo->Status = ExtensionInfo::NORMAL; resultOK = ExtensionList::UpdateItem(oldpid, pid); } return resultOK; } int ExtensionMgmt::UnblockSignals() const { Trace trace(Trace::Extensions, "ExtensionMgmt::UnblockSignals"); int sigerr = 0; sigset_t ss; if (-1 == sigfillset(&ss)) { sigerr = errno; Logger::LogError("Error: sigfillset() failed. Error string: " + MdsdUtil::GetErrnoStr(sigerr)); } else { if (-1 == sigprocmask(SIG_UNBLOCK, &ss, NULL)) { sigerr = errno; Logger::LogError("Error: sigprocmask() failed. 
Error string: " + MdsdUtil::GetErrnoStr(sigerr)); } } return sigerr; } bool ExtensionMgmt::StopExtension(ExtensionInfo * extObj) const { Trace trace(Trace::Extensions, "ExtensionMgmt::StopExtension"); if (!extObj) { trace.NOTE("ExtensionInfo object is NULL. Do nothing."); return true; } pid_t extpid = extObj->Pid; std::string extname = extObj->MetaData.Name; trace.NOTE("Stopping " + extname + " pid=" + std::to_string(extpid) + " status=" + extObj->GetStatus()); bool stopOK = false; bool isPsExist = false; ExtensionInfo::ExtStatus oldStatus = extObj->Status; assert(ExtensionInfo::ExtStatus::EXIT != oldStatus); extObj->Status = ExtensionInfo::ExtStatus::KILLING; trace.NOTE("Set " + extname + "'s status to be KILLING. Pid=" + std::to_string(extpid)); if (ExtensionInfo::ExtStatus::NORMAL == oldStatus || ExtensionInfo::ExtStatus::BAD == oldStatus) { stopOK = SendSignalToProcess(extpid, SIGINT, &isPsExist); } if (isPsExist) { trace.NOTE("Set timer to KillProcessByForce ..."); extObj->StopTimer = new boost::asio::deadline_timer(crossplat::threadpool::shared_instance().service()); extObj->StopTimer->expires_from_now(boost::posix_time::seconds(EXT_TERMINATE_GRACE_SECONDS)); extObj->StopTimer->async_wait(boost::bind(&ExtensionMgmt::KillProcessByForce, this, extpid, boost::asio::placeholders::error)); } trace.NOTE("Finished with success = " + MdsdUtil::ToString(stopOK)); return stopOK; } void ExtensionMgmt::CatchSigChld(int signo) { Trace trace(Trace::Extensions, "ExtensionMgmt::CatchSigChld"); trace.NOTE(std::string("Caught signal=") + std::to_string(signo) + " : " + std::string(strsignal(signo))); assert(SIGCHLD == signo); pid_t chldpid = 0; int waitpiderr = 0; bool haschild = false; while(true) { chldpid = waitpid((pid_t)-1, NULL, WNOHANG); waitpiderr = errno; trace.NOTE("waitpid() returned id=" + std::to_string(chldpid) + "\n"); if (0 < chldpid) { UpdateStoppedExtension(chldpid); haschild = true; } else { break; } } if (-1 == chldpid && ECHILD == waitpiderr && !haschild) 
{ if (-1 == sem_post(&_extsem)) { std::string errstr = MdsdUtil::GetErrnoStr(errno); trace.NOTE("Error: CatchSigchld: sem_post() failed: " + errstr); } } } bool ExtensionMgmt::UpdateStoppedExtension(pid_t extpid) { Trace trace(Trace::Extensions, "ExtensionMgmt::UpdateStoppedExtension"); ExtensionInfo * extObj = ExtensionList::GetItem(extpid); if (!extObj) { Logger::LogError("no ExtensionInfo object found in cache for pid=" + std::to_string(extpid)); return false; } ExtensionInfo::ExtStatus status = extObj->Status; std::string extname = extObj->MetaData.Name; trace.NOTE("Extension pid=" + std::to_string(extpid) + "; Status=" + ExtensionInfo::StatusToString(status)); bool resultOK = true; assert(ExtensionInfo::ExtStatus::NORMAL == status || ExtensionInfo::ExtStatus::KILLING == status); if (ExtensionInfo::ExtStatus::NORMAL == status) { resultOK = HandleExtensionFailure(extObj); } else if (ExtensionInfo::ExtStatus::KILLING == status) { trace.NOTE("Change extension status to EXIT. Delete it from cache. Call sem_post()."); extObj->Status = ExtensionInfo::ExtStatus::EXIT; resultOK = resultOK && ExtensionList::DeleteItem(extname); ExtensionList::AddPid(extpid); if (-1 == sem_post(&_extsem)) { std::string errstr = MdsdUtil::GetErrnoStr(errno); trace.NOTE("Error: UpdateStoppedExtension: sem_post() failed: " + errstr); resultOK = false; } } else { resultOK = false; Logger::LogError("Unexpected extension status. 
expected=NORMAL/KILLING; actual=" + extObj->GetStatus()); } trace.NOTE("Finished with success = " + MdsdUtil::ToString(resultOK)); return resultOK; } bool ExtensionMgmt::HandleExtensionFailure(ExtensionInfo * extObj) { Trace trace(Trace::Extensions, "ExtensionMgmt::HandleExtensionFailure"); bool resultOK = true; if (!extObj) { Logger::LogError("Unexpected nullptr for extension object."); return false; } extObj->Status = ExtensionInfo::ExtStatus::BAD; unsigned int extlife = static_cast((time(NULL) - extObj->StartTime)); if (EXT_RETRY_TIMEOUT_SECONDS >= extlife) { extObj->RetryCount++; } else { extObj->RetryCount = 0; } trace.NOTE("Extension last life: " + std::to_string(extlife) + " seconds, retry count: " + std::to_string(extObj->RetryCount)); if (EXT_MAX_RETRIES >= extObj->RetryCount) { trace.NOTE("Meet retry criteria. Restart extension."); resultOK = resultOK && StartExtension(extObj->MetaData); } else { trace.NOTE("Exceed max retries. Stop retrying. Delete it from cache."); extObj->Status = ExtensionInfo::ExtStatus::EXIT; resultOK = resultOK && ExtensionList::DeleteItem(extObj->MetaData.Name); } trace.NOTE("Finished with success = " + MdsdUtil::ToString(resultOK)); return resultOK; } bool ExtensionMgmt::KillProcessByForce(pid_t pid, const boost::system::error_code& error) const { Trace trace(Trace::Extensions, "ExtensionMgmt::KillProcessByForce"); bool resultOK = true; TRACEINFO(trace, "pid=" << pid); if (boost::asio::error::operation_aborted == error) { trace.NOTE("Operation is aborted. Do nothing."); resultOK = false; } else { ExtensionInfo * obj = ExtensionList::GetItem(pid); if (obj->StopTimerCancelled) { trace.NOTE("Extension with pid " + std::to_string(pid) + " is already cancelled. 
Stop further action."); } else { bool isPsExist = true; resultOK = SendSignalToProcess(pid, SIGKILL, &isPsExist); } } return resultOK; } bool ExtensionMgmt::SendSignalToProcess(pid_t pid, int signum, bool *pIsPsExist) const { Trace trace(Trace::Extensions, "ExtensionMgmt::SendSignalToProcess"); assert(0 < pid); assert(0 < signum); bool resultOK = true; trace.NOTE("Start to send signal " + std::to_string(signum) + " to pid " + std::to_string(pid)); (*pIsPsExist) = true; if (-1 == kill(pid, signum)) { int killerr = errno; std::string errstr = MdsdUtil::GetErrnoStr(errno); if (ESRCH == killerr) { trace.NOTE("process was not found with pid=" + std::to_string(pid)); (*pIsPsExist) = false; } else { Logger::LogError("Error: failed to send signal. Error string: " + errstr); resultOK = false; } } else { trace.NOTE("Sucessfully sent signal."); } return resultOK; } extern "C" { void CatchSigChld(int signo) { ExtensionMgmt *e = ExtensionMgmt::GetInstance(); if (e) { e->CatchSigChld(signo); } } void CleanupExtensions() { ExtensionMgmt *e = ExtensionMgmt::GetInstance(); if (e) { e->StopAllExtensions(); } } } ================================================ FILE: Diagnostic/mdsd/mdsd/ExtensionMgmt.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _EXTENSIONINFO_HH_ #define _EXTENSIONINFO_HH_ #include #include #include #include #include #include #include #include #include #include #include #include extern "C" { #include #include } /// /// Keep track of an extension's meta data. Any of these data change /// will mean the extension definition changed. 
/// struct ExtensionMetaData { std::string Name; std::string CommandLine; std::string Body; std::string AlterLocation; ExtensionMetaData() { } ExtensionMetaData( const std::string & name, const std::string & cmdline, const std::string & body, const std::string & loc ) : Name(name), CommandLine(cmdline), Body(body), AlterLocation(loc) {} /// /// Compare this extension's meta data with some other meta data. /// Return true if they are the same, return false otherwise. /// bool operator==(const ExtensionMetaData & other) const; }; /// /// Keep track an extension process's information. /// class ExtensionInfo { public: enum ExtStatus { NORMAL, // A new extension. BAD, // An extension starts to run but failed in the middle. KILLING, // An extension killed by SIGINT, or killed externally. EXIT, // An extension stopped and killed already. UNKNOWN // Unknown. Should be never in this status. }; // An extension's metadata. ExtensionMetaData MetaData; // Process id of the extension process. pid_t Pid = 0; // Process start time in number of seconds in UTC. time_t StartTime = 0; // Number of times the extension is retried since last reset. unsigned int RetryCount = 0; // Extension status. ExtStatus Status = UNKNOWN; // asio timer to kill the extension by force. boost::asio::deadline_timer * StopTimer; // whether the extension timer is already cancelled or not. std::atomic StopTimerCancelled; ExtensionInfo(); ~ExtensionInfo(); /// /// Get the string format of the status. /// std::string GetStatus() const; /// /// Get the string format of the status. /// static std::string StatusToString(ExtStatus s); private: static std::map _statusMap; }; /// /// Keep track of all extension processes. /// class ExtensionList { public: /// /// Get number of items. /// static size_t GetSize(); /// /// Add an item to the list. If the item already exists, free the memory of /// existing one and add the new one. /// Return true if no error. Return false if the input is invalid. 
/// static bool AddItem(ExtensionInfo * extObj); /// /// Get an item given its name. /// Return the object pointer. /// Return nullptr if not found or given name is invalid. /// The caller shouldn't free the object pointer. /// static ExtensionInfo * GetItem(const std::string & extname); /// /// Get an item given its process id. /// Return the object pointer. /// Return nullptr if not found or given pid is invalid. /// The caller shouldn't free the object pointer. /// static ExtensionInfo * GetItem(pid_t extPid); /// /// Update an existing item's pid. /// Return true if success, false if any error. /// static bool UpdateItem(pid_t oldpid, pid_t newpid); /// /// Delete an item from the list. /// Return true if the item is actually deleted. /// Return false if the given name is invalid or not found. /// static bool DeleteItem(const std::string & extname); /// /// Delete a set of items with given names. /// Return true if all items are actually deleted, or set is empty. /// Return false if any item is not found. /// static bool DeleteItems(const std::set& extnames); static void DeleteAllItems(); /// /// Use a given function to iterate over each extension object. /// static void ForeachExtension(const std::function& fn); /// /// Add a pid to the pid set. /// static void AddPid(pid_t pid); /// /// Get all pids of the pid set. Clear the original one. /// static std::unordered_set GetAndClearPids(); private: static std::map _extlistByName; static std::map _extlistByPid; static std::mutex _listmutex; /// Store a list of PIDs that needs to be killed because their /// meta data are changed. static std::unordered_set _killList; static std::mutex _klmutex; }; class MdsdExtension; class MdsdConfig; /// /// Use configuration to create new extension processes, then manage /// the extension processes. /// class ExtensionMgmt { public: /// /// Free all resources. /// ~ExtensionMgmt(); /// /// Get a singleton instance. 
/// static ExtensionMgmt* GetInstance(); /// /// Start all extensions given a config synchronously. It will also /// stop any obsolete extension. /// Return true if success; Return false for any error. /// static bool StartExtensions(MdsdConfig * config); /// /// Calls StartExtensions() in async. /// static void StartExtensionsAsync(MdsdConfig * config); /// /// Stop all extensions. /// Return true for success, false for any error. /// bool StopAllExtensions(); /// /// Defines SIGCHLD signal handler, which is from child extension process. It will /// - release child process resources. /// - change extension object status. /// - update the extension object in ExtensionList. /// void CatchSigChld(int signo); private: ExtensionMgmt(); /// /// Define semaphore to synchronize between stopped extensions /// (handled in SIGCHLD signal handler) and creating new ones (in main thread) /// sem_t _extsem; /// /// True if semaphore is initialized properly, false if any error. /// bool _extsemInitOK; /// /// Singleton instance. /// static ExtensionMgmt * _extInstance; /// /// Environment name for extension. Extension uses it to read the value defined in /// static constexpr const char* BODYENV = "MON_EXTENSION_BODY"; /// /// The grace period in number of seconds for the extension process to /// terminate itself before it is killed by force. Because mdsd service's /// grace period is 30-second, make it shorter than that. /// static const unsigned int EXT_TERMINATE_GRACE_SECONDS = 20; /// /// The maximum number of retries to start extension within /// given window seconds. /// static const unsigned int EXT_MAX_RETRIES = 3; /// /// Numbe of seconds to wait before retrying the extension /// static const unsigned int EXT_RETRY_WAIT_SECONDS = 5; /// /// Extension restart retry timeout in number of seconds. /// If the time difference is bigger than this window, reset /// extension's RetryCount to be 0. 
/// static const unsigned int EXT_RETRY_TIMEOUT_SECONDS = 60; /// /// Initialize semaphore. Return true if no error; return false for any error. /// bool InitSem(); /// /// Start all extensions defined in a config. It won't stop any extension. /// It will return the extension names defined in the config in extlistInConfig. /// Return true for success, false for any error. /// bool StartExtensionsFromConfig( MdsdConfig * config, std::set & extlistInConfig); /// /// Restart all extensions whose meta data were changed. /// Return true if success, false for any error. /// List of changed extensions. /// The meta data for the changed extensions. Key is old extension pid. /// bool RestartChangedExtensions(const std::vector & changedList, const std::map & newDataList); /// /// Wait until any extension's change status SIGCHLD caught, or until timed out /// after SEM_WAIT_SECONDS seconds. /// Return true if success, false if error or timed out. /// bool WaitForAnyExtStop(); /// /// Start all extensions whose meta data were changed. /// Return true for success, false for any error. /// bool StartAllChangedExts(const std::unordered_set changedPids, const std::map & newDataList) const; /// /// Start one extension instance whose meta data were changed. /// Return true for success, false for any error. /// bool StartOneChangedExt(pid_t changedPid, const ExtensionMetaData & metadata) const; /// /// Any extension that's not in given set is obsolete. /// For each obsolete extension, delete it from ExtensionList and Stop it. /// Return true if no error is found. Otherwise, return false. /// bool StopObsoleteExtensions(const std::set & extlistInConfig) const; /// /// Block or unblock a given signal to the process. /// bool MaskSignal(bool isBlock, int signum) const; /// /// Attempt to start a given extension process. /// Return true if it starts OK, return false for any error. /// If starting OK, the extensionInfo object will be added to ExtensionList. 
Its memory /// will be managed there. /// bool StartExtension( const std::string & extName, const std::string & cmdline, const std::string & body, const std::string & alterLocation) const; /// /// Start an extension process given its meta data. /// Return true if it starts OK, return false for any error. /// bool StartExtension(const ExtensionMetaData & metaData) const; /// /// Either create a new ExtensionInfo object or update existing one in /// the extension list. If an extension failed to be created, it should not be /// retried. /// bool UpdateExtensionList( ExtensionInfo * oldExtInfo, const std::string & extName, const std::string & cmdline, const std::string & body, const std::string & alterLocation, pid_t pid, bool extFailed) const; /// /// Stop an extension process. It won't remove the ExtensionInfo item from ExtensionList. /// bool StopExtension(ExtensionInfo * extObj) const; /// /// Update the information of extension given its pid. /// Return true for success, false for error. /// bool UpdateStoppedExtension(pid_t extpid); /// /// Handle extension that fails itself. Either retry it or delete it forever based on its status. /// Return true if success, false if any error. /// bool HandleExtensionFailure(ExtensionInfo * extObj); /// /// This is to unblock all signal mask. For example, in child process, child process /// may use this function to unblock signal mask inherited from parent process. /// Return errno. /// int UnblockSignals() const; /// /// Kill a process by sending it SIGKILL. It doesn't validate whether /// the process is actually killed or not. /// Return true if signal is sent out properly. /// Return false if signal is not sent out, or the operation is aborted. /// bool KillProcessByForce(pid_t pid, const boost::system::error_code& error) const; /// /// Send signal signum to process id pid. /// Return whether the process exists or not through pIsPsExist. /// Return true if signal is sent out properly, false if error. 
If process doesn't exist, /// also return true. /// process id /// signal number /// Return whether the process exists or not /// bool SendSignalToProcess(pid_t pid, int signum, bool *pIsPsExist) const; }; #endif // _EXTENSIONINFO_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/FileSink.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "FileSink.hh" #include #include #include "CanonicalEntity.hh" #include "MdsdConfig.hh" #include "Utility.hh" #include "RowIndex.hh" #include "Trace.hh" #include "MdsdMetrics.hh" #include "StoreType.hh" #include "Logger.hh" // FileSink uses the name of the sink as the pathname. If the path isn't absolute, we make it // relative to /tmp. // // By design, each invocation of this constructor creates an independent // instance with its own ostream. They're all opened in append mode, which should keep simultaneous // writes from being interleaved. If writes are large enough that they become interleaved, this // design will need to be revisited. Perhaps Batches should hold reference-counted pointers to their // sinks, so that the destruction of the last batch instance pointing to a sink causes the sink to // be destroyed. Add to that a map from filename to a weak pointer to the filesink; when the FileSink // destructor is called (when the last strong refcounted pointer goes away), the destructor removes the weak // pointer from the map. // FileSink::FileSink(const std::string &name) : IMdsSink(StoreType::Type::File), _name(name) { Trace trace(Trace::Local, "FileSink::Constructor"); // Construct _path based on default directory if (name[0] != '/') { _path = "/tmp/"; // Make a relative path into an absolute path } _path += name; // Do a quick sanity check to make sure the file can be opened. Allow any exception // from Open() to propagate upwards. 
Open(); Close(); } // When destroying, remove from the global list of file sinks. No need to close the file; the // destructor for ostream is defined as closing the file. FileSink::~FileSink() { Trace trace(Trace::Local, "FileSink::Destructor"); } void FileSink::Open() { if (! _file.is_open()) { _file.open(_path, std::ofstream::app); // Open for write in append mode if (!_file) { std::system_error e(errno, std::system_category(), "Failed to open " + _path + " for append"); Logger::LogError("Error: " + e.code().message() + " - " + e.what()); throw e; } } } // Write the row, in readable form, to the output file. Add a timestamp. Don't bother with // async disk file writes; the primary goal of the FileSink is testability, so stability and certainty // is more important than absolute performance. void FileSink::AddRow(const CanonicalEntity &row, const MdsTime &) { std::lock_guard lock(_mutex); #if BUFFER_ALL_DATA std::ostringstream msg; msg << MdsTime::Now() << "\t" << row << "\n"; items.push_back(std::move(msg.str())); #else try { Open(); // If you emit std::endl, that does a flush, which isn't what we want. _file << MdsTime::Now() << "\t" << row << "\n"; } catch (const std::exception&) { } #endif } void FileSink::Flush() { Trace trace(Trace::Local, "FileSink::Flush"); std::lock_guard lock(_mutex); #if BUFFER_ALL_DATA try { Open(); for (const auto& item : items) { _file << item; } } catch (const std::exception&) { } items.clear(); #endif Close(); } // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/FileSink.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #ifndef _FILESINK_HH_ #define _FILESINK_HH_ #include "IMdsSink.hh" #include #include #include #include #include #include #include #include "MdsTime.hh" #include "MdsEntityName.hh" #include "CanonicalEntity.hh" class FileSink : public IMdsSink { public: FileSink(const std::string&); // Private constructor; must be called with _mapMutex locked virtual ~FileSink(); virtual bool IsFile() const { return true; } virtual void AddRow(const CanonicalEntity&, const MdsTime&); virtual void Flush(); private: const std::string _name; std::string _path; std::ofstream _file; std::mutex _mutex; std::vector items; void Open(); void Close() { try { _file.close(); } catch (const std::exception&) { } } }; #endif // _FILESINK_HH_ // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/IMdsSink.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "IMdsSink.hh" #include "XTableSink.hh" #include "LocalSink.hh" #include "FileSink.hh" #include "XJsonBlobSink.hh" #include "Trace.hh" #include "Logger.hh" #include #include IMdsSink* IMdsSink::CreateSink(MdsdConfig * config, const MdsEntityName &target, const Credentials* creds) { Trace trace(Trace::ConfigLoad, "IMdsSink::CreateSink"); switch (target.GetStoreType()) { case StoreType::XTable: return new XTableSink(config, target, creds); case StoreType::Local: return new LocalSink(target.Basename()); case StoreType::File: return new FileSink(target.Basename()); case StoreType::XJsonBlob: return new XJsonBlobSink(config, target, creds); default: std::ostringstream msg; msg << "Attempt to create sink of unknown type for target " << target; Logger::LogError(msg.str()); trace.NOTE(msg.str()); throw std::logic_error(msg.str()); } } // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/IMdsSink.hh ================================================ // Copyright 
(c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _IMDSSINK_HH_ #define _IMDSSINK_HH_ #include #include "StoreType.hh" #include "MdsTime.hh" #include "MdsEntityName.hh" class CanonicalEntity; class Credentials; class MdsdConfig; class IMdsSink { public: virtual bool IsXTable() const { return false; } virtual bool IsBond() const { return false; } virtual bool IsXJsonBlob() const { return false; } virtual bool IsLocal() const { return false; } virtual bool IsFile() const { return false; } static IMdsSink* CreateSink(MdsdConfig *, const MdsEntityName &target, const Credentials*); virtual void AddRow(const CanonicalEntity&, const MdsTime&) = 0; // This is a pure virtual class virtual void Flush() = 0; virtual void ValidateAccess() {} // Throws if credentials cannot be used to access the target // May have desireable initialization side-effect(s) IMdsSink() = delete; // No default constructor IMdsSink(const IMdsSink&) = delete; // No copy constructor IMdsSink& operator=(const IMdsSink&) = delete; // No copy assignment IMdsSink(IMdsSink&&) = delete; // No Move constructor virtual IMdsSink& operator=(IMdsSink&&) = delete; // No Move assignment virtual ~IMdsSink() {} void SetRetentionPeriod(const MdsTime & period) { if (period > _retentionPeriod) _retentionPeriod = period; } const MdsTime RetentionPeriod() const { return _retentionPeriod; } time_t RetentionSeconds() const { return _retentionPeriod.to_time_t(); } StoreType::Type Type() const { return _type; } protected: IMdsSink(StoreType::Type t) : _type(t), _retentionPeriod(0) {} private: StoreType::Type _type; MdsTime _retentionPeriod; }; #endif // _IMDSSINK_HH_ // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/ITask.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "ITask.hh"
#include "Logger.hh"
#include "Trace.hh"
// NOTE(review): one stripped system include here; <cassert> matches the
// assert() use below (std::ostringstream comes via other headers) -- TODO confirm.
#include <cassert>

ITask::ITask(const MdsTime &interval)
	: _interval(interval), _timer(crossplat::threadpool::shared_instance().service()), _cancelled(false)
{
	assert(interval != MdsTime(0));
	Trace trace(Trace::Scheduler, "ITask Constructor");
}

ITask::~ITask()
{
}

// Schedule the first invocation of the task; reschedules itself thereafter.
void
ITask::start()
{
	using namespace boost::posix_time;
	Trace trace(Trace::Scheduler, "ITask::Start");

	// Call subclass on_start() method. Last minute initialization happens there, and the subclass
	// can call the whole thing off by returning false.
	if (! on_start()) {
		Logger::LogError("Task refused startup");
		return;
	}

	MdsTime start { initial_start() };
	time_t spanSeconds = _interval.to_time_t();
	_intervalStart = start.Round(spanSeconds) - _interval;

	if (trace.IsActive()) {
		std::ostringstream msg;
		msg << this << " requested start@" << start << " for interval beginning at " << _intervalStart;
		msg << " of size " << spanSeconds << " seconds";
		trace.NOTE(msg.str());
	}

	_nextTime = start.to_ptime();
	_timer.expires_at(_nextTime);
	_timer.async_wait(boost::bind(&ITask::DoWork, this, boost::asio::placeholders::error));
}

// Cancel the repeating task. Idempotent; on_cancel() is invoked without the
// mutex held so subclasses may call back into this object safely.
void
ITask::cancel()
{
	Trace trace(Trace::Scheduler, "ITask::Cancel");

	if (_cancelled) {
		trace.NOTE("Already cancelled; ignoring");
		return;
	} else {
		std::lock_guard<std::mutex> lock(_mutex);
		_cancelled = true;
		_timer.cancel();
	}
	on_cancel();	// Called with mutex NOT locked
}

// Timer completion handler: reschedule the timer, then run execute() for the
// interval that just completed.
void
ITask::DoWork(const boost::system::error_code& error)
{
	Trace trace(Trace::Scheduler, "ITask::DoWork");
	MdsTime start { _intervalStart };

	if (error == boost::asio::error::operation_aborted) {
		// If the timer was cancelled, we have to assume the entire configuration may have been
		// deleted; don't touch it. When an MdsdConfig object is told to self-destruct, it first
		// cancels all timer-driven actions, then it waits some period of time, then it actually
		// deletes the object. When the timers are cancelled, the handlers are called with the
		// cancellation message. The MdsdConfig object is *probably* still valid, and as long
		// as the timer isn't rescheduled, all should be well. But I'm playing it safe here
		// and assuming an explicit cancel operation means "the config is gone".
		//
		// Of course, if the MdsdConfig is deleted, all the associated objects, including this
		// very ITask object, get deleted as well. Thus, the "don't touch nothin'" rule.
		trace.NOTE("Timer cancelled");
		return;
	} else {
		std::lock_guard<std::mutex> lock(_mutex);
		if (error || _cancelled) {
			return;
		}
		trace.NOTE("Rescheduling");
		_intervalStart += _interval;
		_nextTime = _nextTime + _interval.to_duration();
		_timer.expires_at(_nextTime);
		_timer.async_wait(boost::bind(&ITask::DoWork, this, boost::asio::placeholders::error));
	}

	// Note that, as written, we do NOT hold the lock here; our use of the class instance
	// needs to be readonly. If that changes, revisit this locking pattern.
	execute(start);
}

// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/ITask.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _ITASK_HH_
#define _ITASK_HH_

// NOTE(review): four stripped system includes; these cover the std:: and
// boost names used below -- TODO confirm the original list.
#include <mutex>
#include <sstream>
#include <boost/asio.hpp>
#include <boost/bind.hpp>
#include "MdsTime.hh"

class MdsdConfig;

// Interface for regularly-scheduled tasks. When an ITask is created, the interval at which it should be
// executed is set. Once the ITask::start() method is invoked, a timer is set to cause the virtual ITask::on_start()
// method to be invoked at the requested frequency (every _interval seconds), until the ITask::cancel() method
// is invoked.
class ITask
{
public:
	// Task should run every _interval_ seconds
	ITask(const MdsTime &interval);
	// I want a move constructor...
	ITask(ITask &&orig);
	// But do not want a copy constructor nor a default constructor
	ITask(ITask &) = delete;
	ITask() = delete;

	virtual ~ITask();

	// Requests that this repeating task be scheduled for execution
	void start();

	// Requests that the task be stopped. Any execution already in progress (or for which the timer has already
	// tripped but execution still awaits scheduling on a thread) will take place, but the _cancelled boolean
	// can be observed.
	//
	// Once cancelled(), a task cannot be restarted; that is, you cannot call start() again. You must instead
	// create a new instance of the task object. This is due to the boost deadline timer not being restartable,
	// which itself arises from enabling cancellation in the first place, near as I can tell.
	void cancel();

	MdsTime interval() const { return _interval; }

protected:
	// Subclasses *must* override the execute() method, which is called to perform the actual
	// time-scheduled class.
	virtual void execute(const MdsTime&) = 0;

	// Subclass gets notified via this callout when start() is called. If the subclass returns false,
	// the start operation aborts. In this case, start() can be called again; a failed startup is different
	// from a successful start followed by a cancel().
	virtual bool on_start() { return true; }

	// Subclass gets notified when cancel() is called.
	virtual void on_cancel() { }

	// When start() is called, a time for the initial task invocation must be determined.
	// By default, wait 2-7 second; the randomness prevents all the tasks from being started
	// at the same time when running through all tasks scheduled for a given config. Any
	// derived class can override this function, e.g. if the task needs to run within 5 seconds
	// of the beginning of the next "interval".
	virtual MdsTime initial_start() { return MdsTime::Now() + MdsTime(2 + random()%5, random()%1000000); }

	// Subclass can check to see if cancellation has been requested
	bool is_cancelled() { return _cancelled; }

private:
	MdsTime _interval;		// Interval between invocations
	std::mutex _mutex;		// Protects reschedule vs cancel race
	boost::asio::deadline_timer _timer;
	boost::posix_time::ptime _nextTime;	// Absolute time of next invocation
	bool _cancelled;
	MdsTime _intervalStart;		// Beginning of the interval being processed

	void DoWork(const boost::system::error_code& error);
};

#endif // _ITASK_HH_
// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/IdentityColumns.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _IDENTITYCOLUMNS_HH_
#define _IDENTITYCOLUMNS_HH_

#include <string>
#include <utility>
#include <vector>

// NOTE(review): extraction stripped the std::pair template arguments; a
// name/value pair of strings is consistent with "identity columns", but this
// must be confirmed against the original source.
using ident_col_t = std::pair<std::string, std::string>;
using ident_vect_t = std::vector<ident_col_t>;

#endif // _IDENTITYCOLUMNS_HH_
// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/LADQuery.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "Logger.hh"
#include "Trace.hh"
#include "LADQuery.hh"
#include "CanonicalEntity.hh"
#include "Utility.hh"
// NOTE(review): three stripped system includes; these match the
// std::ostringstream / setw / isalpha uses in this file -- TODO confirm.
#include <sstream>
#include <iomanip>
#include <cctype>

namespace Pipe {

const std::string LADQuery::_name { "LADQuery" };

// Fold one observation into the running aggregate (total/min/max/last/count).
void
LADQuery::FullAggregate::Sample(double value)
{
	_total += value;
	_last = value;
	if (_count) {
		if (value > _maximum) _maximum = value;
		if (value < _minimum) _minimum = value;
	} else {
		// First sample establishes both extremes.
		_maximum = _minimum = value;
	}
	_count += 1;
}

// The core DerivedEvent task pulls entities from the source that fall within the just-completed
// time window (based on duration).
The LADQuery looks like this: // 1) Group by the value in the nameAttrName column; mark that column to be preserved // 2) Compute aggregate stats for the value in the valueAttrName column and pass the single aggregate row down the pipe // 3) Add a column with the specified partition key // 4) Send the CanonicalEntity down the pipe twice, once with each of the two distinct row keys as defined for the LAD query // // The strings are pass-by-value; the initializers use move semantics to move the copies into the member variables. // If the compiler can determine the actual parameters are temporaries, or about to go out of scope, it can optimize // the copy away, thus giving us the move semantics we actually want. Worst case, we're still doing only a single // copy (to prepare the passed values). LADQuery::LADQuery(std::string valueAN, std::string nameAN, std::string pkey, std::string uuid) : _valueAttrName(std::move(valueAN)), _nameAttrName(std::move(nameAN)), _pkey(std::move(pkey)), _uuid(std::move(uuid)), _lastSampleTime(0), _startOfSample(0) { } void LADQuery::Start(const MdsTime QIbase) { // Prepare to process all the rows in this sample period _lastSampleTime = QIbase; _startOfSample = QIbase; // Do whatever the base class needs PipeStage::Start(QIbase); } void LADQuery::Process(CanonicalEntity *item) { Trace trace(Trace::QueryPipe, "LADQuery::Process"); // Get the value of the nameAttrName column // Look in the savedStats map for the FullAggregate object associated with that name // if there is none, make one and then use it // Update the FullAggregate based on the value of the valueAttrName column MdsValue* value = item->Find(_valueAttrName); MdsValue* name = item->Find(_nameAttrName); if (!(value && name)) { trace.NOTE("Name or Value column missing; skipping entity"); } else if (! name->IsString()) { Logger::LogWarn("Name column is not a string"); } else if (! 
value->IsNumeric()) { Logger::LogWarn("Value column is not numeric"); } else { _savedStats[*(name->strval)].Sample(value->ToDouble()); _lastSampleTime = item->PreciseTime(); } delete item; // No longer needed; we've updated the correct aggregation object } void LADQuery::Done() { Trace trace(Trace::DerivedEvent, "LADQuery::Done"); // For each savedStats object in the map: // Build a new CE with the full set of stats // Add the _partitionKey to the CE // Dupe the CE // Put one of the LAD keys on the original CE; put the other key on the dupe // Send both rows to the successor pipe // // Call Done on the successor pipe std::string descendingTicks = MdsdUtil::ZeroFill(MdsTime::MaxDateTimeTicks - _startOfSample.to_DateTime(), 19); for (const auto & iter : _savedStats) { auto entity = new CanonicalEntity(10); entity->SetPreciseTime(MdsTime::Now()); // For the "time" field in Jsonblob entity->AddColumn(_nameAttrName, new MdsValue(iter.first)); entity->AddColumn("Total", new MdsValue(iter.second.Total())); entity->AddColumn("Minimum", new MdsValue(iter.second.Minimum())); entity->AddColumn("Maximum", new MdsValue(iter.second.Maximum())); entity->AddColumn("Average", new MdsValue(iter.second.Average())); entity->AddColumn("Count", new MdsValue(iter.second.Count())); entity->AddColumn("Last", new MdsValue(iter.second.Last())); entity->AddColumn("PartitionKey", _pkey); auto dupe = new CanonicalEntity(*entity); dupe->SetPreciseTime(entity->PreciseTime()); // For the "time" field in Jsonblob std::string metric = EncodeAndHash(iter.first, 256); std::ostringstream key1, key2; key1 << descendingTicks << "__" << metric; key2 << metric << "__" << descendingTicks; if (_uuid.length()) { key1 << "__" << _uuid; key2 << "__" << _uuid; } trace.NOTE("Aggregation rowkey " + key1.str()); entity->AddColumn("RowKey", key1.str()); PipeStage::Process(entity); trace.NOTE("Aggregation rowkey (dupe) " + key2.str()); dupe->AddColumn("RowKey", key2.str()); 
dupe->SetSourceType(CanonicalEntity::SourceType::Duplicated); PipeStage::Process(dupe); } PipeStage::Done(); // Pass the "done" signal to the next stage // Empty the map now to free memory, rather than waiting for the next Start() call _savedStats.clear(); } std::string LADQuery::EncodeAndHash(const std::string &name, size_t limit) { Trace trace(Trace::DerivedEvent, "LADQuery::EncodeAndHash"); trace.NOTE("EncodeAndHash(\"" + name + "\")"); std::string result; for (const char c : name) { if (isalpha(c) || isdigit(c)) { result.push_back(c); } else { std::ostringstream encoded; encoded << ":" << std::hex << std::uppercase << std::setw(4) << std::setfill('0') << (unsigned short)c; result.append(encoded.str()); } } if (result.size() > limit) { trace.NOTE("Hashing required..."); auto hash = MdsdUtil::MurmurHash64(result, 0); std::ostringstream hashstr; const size_t charcnt = sizeof(hash)*2; hashstr << "|" << std::hex << std::setw(charcnt) << std::setfill('0') << hash; result.replace(limit - (1 + charcnt), std::string::npos, hashstr.str()); } trace.NOTE("Encoded to \"" + result + "\""); return result; } // End of namespace } // vim: se ai sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/LADQuery.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _LADQUERY_HH_ #define _LADQUERY_HH_ #include "Pipeline.hh" #include "MdsEntityName.hh" #include #include #include // Pipe stages must implement the Process method. // Pipe stages that retain data must implement the Done method. // Pipe stages must implement a constructor, which can have any parameters that might be required. namespace Pipe { class LADQuery : public PipeStage { public: // Deliberately call-by-value. 
LADQuery(std::string valueAN, std::string nameAN, std::string pkey, std::string uuid); void Start(const MdsTime QIbase); void Process(CanonicalEntity *); const std::string& Name() const { return _name; } void Done(); private: static const std::string _name; const std::string _valueAttrName; const std::string _nameAttrName; const std::string _pkey; const std::string _uuid; MdsTime _lastSampleTime; MdsTime _startOfSample; std::string EncodeAndHash(const std::string &, size_t); // Contains aggregated stats on a counter during processing of a LADQuery class FullAggregate { public: FullAggregate() : _total(0.0), _minimum(DBL_MAX), _maximum(-DBL_MAX), _last(0.0), _count(0) {} void Sample(double value); double Total() const { return _total; } double Minimum() const { return _minimum; } double Maximum() const { return _maximum; } double Last() const { return _last; } long Count() const { return _count; } double Average() const { return _count?(_total / _count):0.0; } private: double _total; double _minimum; double _maximum; double _last; long _count; }; // Holds all the instances of aggregation stats during processing. // Cleared after each run. Bad things will happen if multiple threads // call LADQuery::Process, which really shouldn't happen. std::unordered_map _savedStats; }; } #endif // _LADQUERY_HH_ // vim: se ai sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/LinuxMdsConfig.xsd ================================================  ================================================ FILE: Diagnostic/mdsd/mdsd/Listener.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "Listener.hh"
#include "Logger.hh"
#include "Engine.hh"
#include "EventJSON.hh"
#include "Trace.hh"
#include "Utility.hh"
// NOTE(review): five stripped system includes plus one inside the extern "C"
// block; these match the std::/POSIX names used in this file -- TODO confirm.
#include <sstream>
#include <cstring>
#include <cctype>
#include <memory>
#include <boost/bind.hpp>
extern "C" {
#include "cJSON.h"
#include <unistd.h>
}

// Set default checkpoint time to 1/2 the default dupe-detection window. The default window
// is one hour.
unsigned int Listener::checkpointSeconds = 60 * 60 / 2;

// Thread startproc. If and when the specific ProcessLoop() method returns, cleanup and exit.
// The pthread interface requires this method to both accept and return a void*.
void *
Listener::handler(void * obj)
{
	Trace trace(Trace::EventIngest, "Listener::handler");

	// Create a shared_ptr to own the Listener object
	auto listener = std::shared_ptr<Listener>((Listener*)obj);

	trace.NOTE("Start timer for " + listener->Name());
	// Start the timer running
	listener->Timer().expires_from_now(boost::posix_time::seconds(checkpointSeconds));
	listener->Timer().async_wait(boost::bind(&Listener::timerhandler, listener, TimerTask::rotate));

	auto result = listener->ProcessLoop();
	trace.NOTE("Returned from ProcessLoop for " + listener->Name());
	listener->Shutdown();
	return result;
}

// Upon receiving an indication from the sender that the session is over, call this method to
// shutdown our end of it. This method should be called synchronously on the listening thread;
// the only expected race is against the timer handler. Once we set _finished, the timer handler
// will do its own cleanup the next time it runs. If the timer handler runs between the moment
// Shutdown sets _finished and the moment it calls _timer.cancel(), we'll be cancelling a timer
// that wasn't set, but that is perfectly fine.
void
Listener::Shutdown()
{
	Trace trace(Trace::EventIngest, "Listener::Shutdown");
	trace.NOTE("Shutting down " + Name());
	close(clientfd);
	_finished = true;
	_timer.cancel();
}

// Parse one or more objects out of a range of characters in the half-open range [start, end).
// Return a pointer to the character immediately following the last object successfully parsed
// (and skipping any trailing whitespace).
// This return value is guaranteed to be <= end. If no message was successfully parsed, a null
// pointer (0) will be returned.
// The parser assumes *end is a NUL byte so it can treat the range as a C string.
const char *
Listener::ParseBuffer(const char* start, const char* end)
{
	Trace trace(Trace::EventIngest, "Listener::ParseBuffer");
	const char * parse_end = 0;
	const char * lkg_parse_end = 0;		// last-known-good parse end
	cJSON * event;

	if (*end != '\0') {
		std::ostringstream msg;
		size_t n = end - start + 1;
		msg << "ParseBuffer " << Name() << " got a non-NUL terminated range, length = " << n << "\n";
		DumpBuffer(msg, start, end);
		throw Listener::exception(msg);
	}

	while ((start < end) && (event = cJSON_ParseWithOpts(start, &parse_end, 0))) {
		if (parse_end > end) {
			std::ostringstream msg;
			msg << "ParseBuffer found an object longer than the input buffer. Start " << (void *)start << ", end ";
			msg << (void *)end << ", parse_end " << (void *)parse_end << "\n";
			if (*end != '\0') {
				msg << "Range is no longer NUL-terminated.\n";
			}
			DumpBuffer(msg, start, end);
			throw Listener::exception(msg);
		}

		bool status = TryParseEvent(event) || TryParseEcho(event);
		if (!status) {
			LogBadJSON(event, Name() + " ignored unknown JSON message");
		}

		// Free the parsed event
		cJSON_Delete(event);

		// Advance past the object we just parsed, skip trailing whitespace.
		// I don't really have to do this; cJSON handles leading whitespace. But it's better
		// if I can consume a full buffer; that reduces copying of useless characters.
		while ((parse_end < end) && (isspace(*parse_end))) {
			parse_end++;
		}
		start = lkg_parse_end = parse_end;
	}
	if (lkg_parse_end != parse_end) {
		TRACEINFO(trace, "parse_end (" << (void*)parse_end << ") != lkg (" << (void*)lkg_parse_end << ")");
	}
	return lkg_parse_end;
}

// Handle a JSON object carrying an event: { "TAG": ..., "SOURCE": ..., "DATA": [...] }.
// Returns false if the object has no string TAG (i.e. it is not an event at all).
bool
Listener::TryParseEvent(cJSON* event)
{
	Trace trace(Trace::EventIngest, "Listener::TryParseEvent");

	cJSON* jsTAG = cJSON_GetObjectItem(event, "TAG");
	if (!jsTAG || jsTAG->type != cJSON_String) {
		return false;
	}

	cJSON* jsSOURCE = cJSON_GetObjectItem(event, "SOURCE");
	cJSON* jsDATA = cJSON_GetObjectItem(event, "DATA");
	if ((jsSOURCE && jsSOURCE->type == cJSON_String) && (jsDATA && jsDATA->type == cJSON_Array)) {
		// That's plenty of validation for now.
		if (trace.IsActive()) {
			char *rendering = cJSON_Print(event);
			auto len = strlen(rendering);
			TRACEINFO(trace, "Got event from source " << jsSOURCE->valuestring << " of total size " << len);
			if (trace.IsAlsoActive(Trace::IngestContents)) {
				std::ostringstream msg;
				std::string body(rendering, (len>1024?1024:len));
				msg << Name() << " received JSON event " << body;
				if (len > 1024) {
					msg << " ... }";
				}
				trace.NOTE(msg.str());
			}
			free(rendering);
		}
		if (IsNewTag(jsTAG)) {
			// Process the event...
			EventJSON evt(event);
			Engine::GetEngine()->ProcessEvent(evt);
		}
		// Inform the client we've processed the event
		EchoTag(jsTAG->valuestring);
	} else {
		LogBadJSON(event, Name() + " received incomplete JSON-encoded event");
	}
	return true;
}

// Handle a bare { "ECHO": "tag" } request: echo the tag back without processing.
bool
Listener::TryParseEcho(cJSON* event)
{
	Trace trace(Trace::EventIngest, "Listener::TryParseEcho");

	cJSON* jsECHO = cJSON_GetObjectItem(event, "ECHO");
	if (jsECHO && jsECHO->type == cJSON_String) {
		EchoTag(jsECHO->valuestring);
		return true;
	}
	return false;
}

// Log the (rendered) JSON object with an explanatory prefix.
void
Listener::LogBadJSON(cJSON* event, const std::string& prefix)
{
	char *rendering = cJSON_Print(event);
	Logger::LogError(prefix + " {" + rendering + "}");
	free(rendering);
}

// Echo the tag, followed by a newline, back to the client.
void Listener::EchoTag(char * tagptr) { try { MdsdUtil::WriteBufferAndNewline(clientfd, tagptr); } catch (const MdsdUtil::would_block& e) { std::ostringstream msg; msg << "Event source tag-reader is slow; dropping tag " << tagptr; Logger::LogWarn(msg); } catch (const std::system_error& e) { if (EPIPE == e.code().value()) { throw Listener::exception(std::string("Event sender closed connection: ") + e.what()); } else { Logger::LogError(std::string("Listener failed to echo TAG: ") + e.what()); } } catch (const std::runtime_error& e) { Logger::LogError(std::string("Listener failed to echo TAG: ") + e.what()); } } Listener::Listener(int fd) : clientfd(fd), tagsAgedOut(0), tagsOldest(new tag_set()), tagsOld(new tag_set()), tagsCurr(new tag_set()), _timer(crossplat::threadpool::shared_instance().service()), _finished(false) { Trace trace(Trace::EventIngest, "Listener::Listener"); std::ostringstream msg; msg << this; _name = msg.str(); trace.NOTE("Constructed Listener " + Name()); } Listener::~Listener() { Trace trace(Trace::EventIngest, "Listener::~Listener"); Logger::LogWarn("Closing fd in ~Listener()"); trace.NOTE("Destroying Listener " + Name()); close(clientfd); if (tagsAgedOut) { delete tagsAgedOut; tagsAgedOut = 0; } if (tagsOldest) { delete tagsOldest; tagsOldest = nullptr; } if (tagsOld) { delete tagsOld; tagsOld = nullptr; } if (tagsCurr) { delete tagsCurr; tagsCurr = nullptr; } } bool Listener::IsNewTag(cJSON* jsTAG) { Trace trace(Trace::EventIngest, "Listener::IsNewTag"); if (nullptr == jsTAG) { trace.NOTE("Got a NULL JSON object pointer"); return false; } else if (nullptr == jsTAG->valuestring) { trace.NOTE("JSON object had NULL valuestring"); return false; } else if (0 == *(jsTAG->valuestring)) { trace.NOTE("JSON object had zero-length valuestring"); return false; } trace.NOTE("Checking tag \"" + std::string(jsTAG->valuestring) + "\""); bool isNewTag = true; std::string tagstr(jsTAG->valuestring); // Capture the tag sets to check. 
We're racing against the timer handler which will // rotate the grandparent to great-grandparent, parent to grandparent, current to // parent, and an empty set into current. When we capture current/parent/grand during // the rotation operation, we might wind up with empty/current/parent or // current/current/parent or current/parent/parent, but since we're checking right // on the "rotation" time, the relevant time window really does encompass just current // and parent at the instant we start looking. We might wind up checking a tag set // twice, but we won't segfault and we won't miss checking a relevant tag set. auto currentSet = tagsCurr; auto parentSet = tagsOld; auto grandparentSet = tagsOldest; if (currentSet->end() != currentSet->find(tagstr) || parentSet->end() != parentSet->find(tagstr) || grandparentSet->end() != grandparentSet->find(tagstr)) { isNewTag = false; trace.NOTE("Tag is a duplicate"); } else { // Yes, I really mean tagsCurr. If rotation happened between the time this // thread grabbed the set pointers and now, currentSet points to tagsOld, so // putting this tag into currentSet might leave the tag active for just a hair // less than the guaranteed interval. Better too long than too short. 
tagsCurr->insert(tagstr); trace.NOTE("Tag is new"); } return isNewTag; } void Listener::RotateTagSets() { Trace trace(Trace::EventIngest, "Listener::RotateTagSets"); if (trace.IsActive()) { std::ostringstream msg; msg << Name() << " Tagset sizes: Curr=" << tagsCurr->size() << "; Old=" << tagsOld->size(); msg << "; Oldest=" << tagsOldest->size(); trace.NOTE(msg.str()); } tagsAgedOut = tagsOldest; tagsOldest = tagsOld; tagsOld = tagsCurr; tagsCurr = new tag_set(); } void Listener::ScrubTagSets() { Trace trace(Trace::EventIngest, "Listener::ScrubTagSets"); if (tagsAgedOut) { if (trace.IsActive()) { std::ostringstream msg; msg << Name() << " releasing " << tagsAgedOut->size() << " tags"; trace.NOTE(msg.str()); } delete tagsAgedOut; tagsAgedOut = 0; } } void Listener::timerhandler(std::shared_ptr listener, Listener::TimerTask job) { Trace trace(Trace::EventIngest, "Listener::timerhandler"); if (listener->IsFinished()) { // Do nothing; especially, do not reschedule the timer. The deadline_timer code // will allow its copy of the shared_ptr for this instance to go out of scope, // triggering a safe delete of the Listener class instance. If we were cancelled, // we'll want to do exactly the same thing, and Listener::handler is careful to set // _socketClosed before it tries to cancel the timer. As a result, there's no need // to check to see if we're being cancelled or not; if _socketClosed is set, // just return. 
trace.NOTE(listener->Name() + " IsFinished is true"); return; } switch (job) { case TimerTask::rotate: listener->RotateTagSets(); listener->Timer().expires_from_now(boost::posix_time::seconds(15)); listener->Timer().async_wait(boost::bind(&Listener::timerhandler, listener, TimerTask::cleanup)); break; case TimerTask::cleanup: listener->ScrubTagSets(); listener->Timer().expires_from_now(boost::posix_time::seconds(checkpointSeconds - 15)); listener->Timer().async_wait(boost::bind(&Listener::timerhandler, listener, TimerTask::rotate)); break; default: Logger::LogError("Listener::timerhandler saw unexpected state " + std::to_string(job)); listener->Timer().expires_from_now(boost::posix_time::seconds(checkpointSeconds)); listener->Timer().async_wait(boost::bind(&Listener::timerhandler, listener, TimerTask::rotate)); break; } } void Listener::DumpBuffer(std::ostream& os, const char* start, const char* end) { size_t n = end - start + 1; if (n < 1024*1024) { os << "Buffer contents [" << std::string(start, n) << "]"; } else { os << "Partial buffer contents [" << std::string(start, 1024*1024) << "]"; } } // vim: set ai sw=8 expandtab : ================================================ FILE: Diagnostic/mdsd/mdsd/Listener.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #ifndef _LISTENER_HH_ #define _LISTENER_HH_ #include "Logger.hh" //#include "PoolMgmt.hh" #include #include #include #include #include #include #include #include extern "C" { #include "cJSON.h" } // Instances of Listener (and derived classes) *must* be referenced via shared_ptr. The thread startproc // and timerhandler functions race to be the last one with a pointer to the instance once ProcessLoop() // returns, and it's not even the timerhandler that holds the last pointer; it's boost::deadline_timer that // is often the last holder. 
We need to ensure the _timer object remains valid until deadline_timer lets // go of it. class Listener { private: typedef std::unordered_set tag_set; typedef enum { rotate, cleanup } TimerTask; Listener(const Listener&) = delete; // Do not define; copy construction forbidden Listener& operator=(const Listener &) = delete; // Ditto for assignment void Shutdown(); void LogBadJSON(cJSON* event, const std::string&); bool IsNewTag(cJSON* jsTAG); void EchoTag(char* tag); void DumpBuffer(std::ostream& os, const char* start, const char* end); void RotateTagSets(); void ScrubTagSets(); bool TryParseEvent(cJSON* event); bool TryParseEcho(cJSON* event); int clientfd; tag_set *tagsAgedOut; tag_set *tagsOldest; tag_set *tagsOld; tag_set *tagsCurr; static unsigned int checkpointSeconds; boost::asio::deadline_timer _timer; boost::asio::deadline_timer& Timer() { return _timer; } static void timerhandler(std::shared_ptr, TimerTask); bool _finished; std::string _name; protected: const char * ParseBuffer(const char* start, const char* end); int fd() const { return clientfd; } public: Listener(int fd); virtual ~Listener(); virtual void * ProcessLoop() { Logger::LogError("Listener::ProcessLoop() was called"); return 0; } static void * handler(void *); // Thread proc for all listeners bool IsFinished() const { return _finished; } const std::string& Name() const { return _name; } static void setDupeWindow(unsigned long seconds) { checkpointSeconds = seconds / 2; } class exception : public std::exception { public: exception(const std::string & msg) : std::exception(), _what(msg) {} exception(const std::ostringstream &msg) : std::exception(), _what(msg.str()) {} exception(const char * msg) : std::exception(), _what(msg) {} virtual const char * what() const noexcept { return _what.c_str(); } private: std::string _what; }; }; // vim: set ai sw=8: #endif // _LISTENER_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/LocalSink.cc 
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "LocalSink.hh"
// NOTE(review): three system-header names were lost in extraction.
#include #include #include
#include "CanonicalEntity.hh"
#include "MdsdConfig.hh"
#include "Utility.hh"
#include "RowIndex.hh"
#include "Trace.hh"
#include "MdsdMetrics.hh"
#include "StoreType.hh"
#include "SchemaCache.hh"
#include "Logger.hh"
#include "EventHubUploaderId.hh"
#include "EventHubType.hh"
#include "EventHubUploaderMgr.hh"

// Class statics
//
// Table of local tables and a mutex to protect it. The map is altered only
// while loading configurations, but that can happen in parallel with incoming events.
// It may be that the global table is referenced only during config load, in which case
// the mutex won't be needed.
//
// These are on the heap because there's no way to control the order of destruction of
// global static objects declared in separate compilation units. The Batch class contains
// a pointer to a sink; Batch instances in the global static MdsdConfig::_localBatches
// BatchSet all point to LocalSink objects. If the LocalSink::_localTables map were a
// global static, it might be destroyed at program-exit before the static _localBatches
// was destroyed. In that case, when the Batch destructor deletes its LocalSink, the
// LocalSink destructor tries to remove the object from the _localTables map which has
// already been destroyed.
// Heap-allocated class statics; see the destruction-order rationale above.
// Both are created exactly once by Initialize().
std::mutex* LocalSink::_ltMutex { nullptr };
// NOTE(review): the map's template arguments were stripped by extraction;
// tokens kept as found.
std::map* LocalSink::_localTables { nullptr };

// Idempotent one-time setup of the statics. Not itself thread-safe; callers
// presumably invoke it before any LocalSink exists — TODO confirm.
void LocalSink::Initialize()
{
        if (_ltMutex == nullptr) {
                _ltMutex = new std::mutex;
                _localTables = new std::map;
        }
}

// Construct and register this sink in the global name->sink table; a duplicate
// name throws. The lock is explicitly released before the throw.
LocalSink::LocalSink(const std::string &name)
        : IMdsSink(StoreType::Type::Local), _name(name), _schemaId(0)
{
        Trace trace(Trace::Local, "LocalSink::Constructor");
        std::unique_lock lock(*_ltMutex);
        auto result = _localTables->insert(std::pair(_name, this));
        lock.unlock();
        if (!(result.second)) {
                throw std::invalid_argument("Duplicate local table name");
        }
}

// Deregister from the global table.
LocalSink::~LocalSink()
{
        Trace trace(Trace::Local, "LocalSink::Destructor");
        std::lock_guard lock(*_ltMutex);
        _localTables->erase(_name);
}

// Fetch a fresh schema id from the process-wide schema cache.
void LocalSink::AllocateSchemaId()
{
        _schemaId = SchemaCache::Get().GetId();
}

// Find a sink by name in the global table; returns nullptr if absent.
LocalSink* LocalSink::Lookup(const std::string& name)
{
        Trace trace(Trace::Local, "LocalSink::Lookup");
        trace.NOTE("Looking for LocalSink " + name);
        std::lock_guard lock(*_ltMutex);
        auto iter = _localTables->find(name);
        if (iter == _localTables->end()) {
                trace.NOTE("Not found");
                return nullptr;
        } else {
                trace.NOTE("Found it");
                return iter->second;
        }
}

// Copy the CE before adding it
void LocalSink::AddRow(const CanonicalEntity &row, const MdsTime& )
{
        Trace trace(Trace::Local, "LocalSink::AddRow(CE)");
        std::shared_ptr item;
        try {
                item.reset(new CanonicalEntity(row));
        } catch (const std::exception& ex) {
                // Best-effort: log and drop the row rather than propagate.
                Logger::LogError("Exception copying item to insert into LocalSink " + _name + ": " + ex.what());
                return;
        }
        AddRow(item);
}

// This version of AddRow assumes it can share the CE.
void LocalSink::AddRow(std::shared_ptr item)
{
        Trace trace(Trace::Local, "LocalSink::AddRow(shared CE)");
        size_t nEvents = 0;
        try {
                // Add row to event collection, ordered by the PreciseTime() in the item.
                // If retention period is zero, there are no downstream consumers; don't even bother
                // adding the item to the list.
                // This behavior should change when local sinks are persisted;
                // the item should be written to the disk. If some fraction of a local sink is retained in
                // memory (as a performance optimization), that should not happen if RetentionPeriod() == 0
                // (leading "//" on the first sentence restored; it was lost in extraction)
                if (RetentionPeriod()) {
                        std::lock_guard lock(_mutex);
                        _events.emplace_hint(_events.end(), item->PreciseTime(), item);
                        nEvents = _events.size();
                }
                // Ingested events may additionally be forwarded to Event Hubs.
                if (!_ehpubMonikers.empty() && CanonicalEntity::SourceType::Ingested == item->GetSourceType()) {
                        SendToEventPub(item);
                }
        } catch (const std::exception& ex) {
                Logger::LogError("Exception adding item to LocalSink " + _name + ": " + ex.what());
                return;
        }
        TRACEINFO(trace, "LocalSink " << _name << " now has " << nEvents << " rows");
}

// Copy the value (shared_ptr) from the map elements in the range. This increases the
// refcount on all the shared pointers; it doesn't actually copy the CanonicalEntity objects.
// *** Must be called with _mutex already held ***
LocalSink::vector_type LocalSink::ExtractRange(LocalSink::iterator start, LocalSink::iterator end)
{
        LocalSink::vector_type extract;
        typedef LocalSink::iterator::value_type value_type;
        if (start != end) {
                try {
                        auto count = std::distance(start, end);
                        extract.reserve(count);
                        std::for_each(start, end, [&extract](value_type& val){extract.push_back(val.second);});
                } catch (const std::exception& ex) {
                        // On failure, return whatever was gathered so far.
                        Logger::LogError("Exception in ExtractRange on " + _name + ": " + ex.what());
                }
        }
        return extract;
}

// Parameterless flush: drop anything older than twice the retention period.
void LocalSink::Flush()
{
        Trace trace(Trace::Local, "LocalSink::Flush");
        // The instance knows the longest timespan we'll ever be asked for (gap between
        // Foreach()'s begin and delta parameters. Just call Flush(now - span).
        // We actually double the span for safety's sake.
        Flush(MdsTime::Now() - RetentionPeriod() - RetentionPeriod());
}

// Remove (and, if no one else holds them, destroy) all events older than "when".
void LocalSink::Flush(const MdsTime& when)
{
        Trace trace(Trace::Local, "LocalSink::Flush(when)");
        TRACEINFO(trace, "Flushing items older than " << when << " from LocalSink " << _name);
        LocalSink::vector_type scrubList;
        try {
                std::lock_guard lock(_mutex);
                iterator rangeEnd = _events.lower_bound(when);
                if (rangeEnd == _events.begin()) {
                        TRACEINFO(trace, "Nothing to remove from LocalSink " << _name);
                        return;
                }
                scrubList = ExtractRange(_events.begin(), rangeEnd);
                // Erase all the entries from the multimap (won't destroy the CEs) and release the lock
                TRACEINFO(trace, "Removing " << scrubList.size() << " items from " << _name);
                _events.erase(_events.begin(), rangeEnd);
        } catch (const std::exception& ex) {
                Logger::LogError("Exception while removing range from " + _name + ": " + ex.what());
        }
        // Now we can delete these without blocking everyone else waiting on the sink. It is very
        // likely the shared_ptrs in this list have a refcount of 1 and will thus the CEs will be destructed.
        // By explicitly clearing the scrubList, we can determine how much real time is required
        // to destroy all those objects (time between this trace message and the "Leaving" message).
        TRACEINFO(trace, "Destroying " << scrubList.size() << " items removed from " << _name);
        scrubList.clear();
}

// Extract each event in the [begin, begin+delta) range, then invoke the function on each extracted event.
// Release the shared ptr for the extracted events as we go, amortizing heap operations over time. A large
// extract may be the last holder of a reference to a CE, so releasing as-we-go could make memory available sooner.
// Apply fn to every event in [begin, begin+delta); the lock is held only while
// the range is snapshotted, not while fn runs.
void LocalSink::Foreach(const MdsTime &begin, const MdsTime &delta, const std::function& fn)
{
        Trace trace(Trace::Local, "LocalSink::Foreach");
        TRACEINFO(trace, "begin at " << begin << ", delta " << delta);
        LocalSink::vector_type matchedEvents;
        try {
                std::lock_guard lock(_mutex);
                matchedEvents = ExtractRange(_events.lower_bound(begin), _events.lower_bound(begin + delta));
        } catch (const std::exception& ex) {
                Logger::LogError("Exception while extracting range from " + _name + ": " + ex.what());
                return;
        }
        TRACEINFO(trace, "Extracted " << matchedEvents.size() << " events from " << _name);
        for (auto& eventPtr : matchedEvents) {
                fn(*eventPtr);
                eventPtr.reset();       // Done with this item; if we're the last user, let it go
        }
}

// Record the Event Hubs publishing parameters for this sink. At least one
// moniker is required; the remaining fields annotate the published JSON rows.
void LocalSink::SetEventPublishInfo(
        const std::unordered_set & monikers,
        std::string eventDuration,
        std::string tenant,
        std::string role,
        std::string roleInstance
)
{
        if (monikers.empty()) {
                throw std::invalid_argument("SetEventPublishInfo(): moniker cannot be empty.");
        }
        _ehpubMonikers = monikers;
        _eventDuration = std::move(eventDuration);
        _tenant = std::move(tenant);
        _role = std::move(role);
        _roleInstance = std::move(roleInstance);
}

// Serialize one row to JSON and queue it for upload to every configured
// Event Hubs moniker.
void LocalSink::SendToEventPub(std::shared_ptr item)
{
        Trace trace(Trace::Local, "LocalSink::SendToEventPub");
        if (!item) {
                throw std::invalid_argument("LocalSink::SendToEventPub(): CanonicalEntity cannot be nullptr");
        }
        auto jsonData = item->GetJsonRow(_eventDuration, _tenant, _role, _roleInstance);
        if (jsonData.empty()) {
                throw std::runtime_error("LocalSink::SendToEventPub(): failed to get data to publish.");
        }
        mdsd::EventDataT ehdata;
        ehdata.SetData(jsonData);
        auto ehtype = mdsd::EventHubType::Publish;
        // NOTE(review): ehdata is std::move'd on every iteration; with more
        // than one moniker the later uploads receive a moved-from object.
        // Confirm whether a per-moniker copy was intended.
        for (const auto & moniker : _ehpubMonikers) {
                mdsd::EventHubUploaderMgr::GetInstance().AddMessageToUpload(
                        mdsd::EventHubUploaderId(ehtype, moniker, _name),
                        std::move(ehdata));
                TRACEINFO(trace, "LocalSink::SendToEventPub: moniker=" << moniker << "; sinkName=" << _name);
        }
}

// vim: se sw=8 :
================================================ FILE: Diagnostic/mdsd/mdsd/LocalSink.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _LOCALSINK_HH_
#define _LOCALSINK_HH_

#include "IMdsSink.hh"
// NOTE(review): system-header names lost in extraction.
#include #include #include #include #include #include
#include "MdsTime.hh"
#include "MdsEntityName.hh"
#include "CanonicalEntity.hh"
#include "SchemaCache.hh"

// An in-memory, time-ordered event store (a "local table"); can optionally
// forward ingested events to Event Hubs (see SetEventPublishInfo()).
class LocalSink : public IMdsSink {
public:
        // NOTE(review): template arguments in these typedefs were stripped by
        // extraction (multimap key/value and vector element types).
        typedef std::multimap> map_type;
        typedef std::vector> vector_type;
        typedef map_type::iterator iterator;

        LocalSink(const std::string&);
        virtual ~LocalSink();
        virtual bool IsLocal() const { return true; }
        virtual void AddRow(const CanonicalEntity&, const MdsTime&);
        virtual void Flush();
        // An ingested event goes to precisely one LocalSink; this method
        // lets us avoid copying the CE upon ingest
        void AddRow(std::shared_ptr);
        void Flush(const MdsTime &when);
        void Foreach(const MdsTime &when, const MdsTime &delta, const std::function&);
        void AllocateSchemaId();
        SchemaCache::IdType SchemaId() { return _schemaId; }
        static LocalSink * Lookup(const std::string& name);
        static void Initialize();
        void SetEventPublishInfo(const std::unordered_set & monikers, std::string eventDuration,
                std::string tenant, std::string role, std::string roleInstance);

private:
        vector_type ExtractRange(iterator start, iterator end);
        void SendToEventPub(std::shared_ptr item);

        map_type _events;               // Events ordered by precise timestamp
        const std::string _name;
        // Applies only to local sinks which directly receive json external data; derived
        // local tables will have a 0 _schemaId, and so will sinks that receive BOND and dynamic json external data.
        SchemaCache::IdType _schemaId;
        std::mutex _mutex;              // Guards _events
        // Global name->sink registry (heap-allocated; rationale in LocalSink.cc).
        static std::map* _localTables;
        static std::mutex* _ltMutex;
        // event publishing information
        std::unordered_set _ehpubMonikers;
        std::string _eventDuration;
        std::string _tenant;
        std::string _role;
        std::string _roleInstance;
};

#endif // _LOCALSINK_HH_

// vim: se sw=8 :

================================================ FILE: Diagnostic/mdsd/mdsd/MdsBlobOutputter.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#ifndef _MDSBLOBOUTPUTTER_HH
#define _MDSBLOBOUTPUTTER_HH

// NOTE(review): system-header names lost in extraction.
#include
#include "Crypto.hh"
#include "Trace.hh"
#include "Logger.hh"
#include "Utility.hh"
#include #include #include

// Fixed-capacity binary serializer for MDS Bond blobs. Owns its buffer;
// every Write() bounds-checks and throws std::overflow_error on overflow.
class MdsBlobOutputter {
public:
    // Allocate a buffer of maxbytes; zero leaves the object empty.
    MdsBlobOutputter(size_t maxbytes) : _buffer(0), _end(0), _current(0)
    {
        if (maxbytes) {
            _current = _buffer = new unsigned char [maxbytes];
            _end = _buffer + maxbytes;
        }
    }

    ~MdsBlobOutputter() { if (_buffer) delete [] _buffer; }

    // Bytes written so far.
    size_t size() const { return (_buffer) ?
        (_current - _buffer) : 0; }

    // Release the buffer; the outputter cannot be written to after this.
    void clear()
    {
        if (_buffer) {
            delete [] _buffer;
            _buffer = nullptr;
        }
    }

    unsigned char * data() { return _buffer; }

    // Write an arithmetic scalar in host representation.
    // NOTE(review): the template header and the enable_if/cast arguments lost
    // their angle-bracketed parameters in extraction; tokens kept as found.
    template typename std::enable_if::value, void>::type
    Write(const T& value)
    {
        Trace trace(Trace::BondDetails, "Write");
        TRACEINFO(trace, sizeof(T) << " bytes");
        if (_current + sizeof(T) > _end) {
            throw std::overflow_error("Bond blob buffer overflow");
        }
        * reinterpret_cast(_current) = value;
        _current += sizeof(T);
    }

    // Write a narrow string: uint32 byte count, then the raw bytes.
    void Write(const std::string& value)
    {
        Trace trace(Trace::BondDetails, "Write");
        size_t bytecount = value.size();
        size_t totalbytes = bytecount + sizeof(uint32_t);
        TRACEINFO(trace, value.size() << " characters, " << bytecount << " bytes (" << totalbytes << " total)");
        if ((_current + totalbytes) > _end) {
            throw std::overflow_error("Bond blob buffer overflow");
        }
        * reinterpret_cast(_current) = bytecount;
        ::memcpy(_current + sizeof(uint32_t), value.data(), bytecount);
        _current += totalbytes;
    }

    // Write a UTF-16 string: uint32 byte count, then the raw bytes.
    void Write(const std::u16string& value)
    {
        Trace trace(Trace::BondDetails, "Write");
        size_t bytecount = sizeof(std::u16string::value_type) * value.size();
        size_t totalbytes = bytecount + sizeof(uint32_t);
        TRACEINFO(trace, value.size() << " characters, " << bytecount << " bytes (" << totalbytes << " total)");
        if ((_current + totalbytes) > _end) {
            throw std::overflow_error("Bond blob buffer overflow");
        }
        * reinterpret_cast(_current) = bytecount;
        ::memcpy(_current + sizeof(uint32_t), value.data(), bytecount);
        _current += totalbytes;
    }

    // Like Write(u16string) but with a uint16 length prefix.
    void WriteShort(const std::u16string& value)
    {
        Trace trace(Trace::BondDetails, "WriteShort");
        size_t bytecount = sizeof(std::u16string::value_type) * value.size();
        size_t totalbytes = bytecount + sizeof(uint16_t);
        TRACEINFO(trace, value.size() << " characters, " << bytecount << " bytes (" << totalbytes << " total)");
        if ((_current + totalbytes) > _end) {
            throw std::overflow_error("Bond blob buffer overflow");
        }
        * reinterpret_cast(_current) = static_cast(bytecount);
        ::memcpy(_current +
            sizeof(uint16_t), value.data(), bytecount);
        _current += totalbytes;
    }

    // Write the raw bytes of an MD5 digest (no length prefix).
    void Write(const Crypto::MD5Hash& value)
    {
        Trace trace(Trace::BondDetails, "Write");
        size_t len = Crypto::MD5Hash::DIGEST_LENGTH;
        TRACEINFO(trace, len << " bytes");
        if (_current + len > _end) {
            throw std::overflow_error("Bond blob buffer overflow");
        }
        ::memcpy(_current, value.GetBuffer(), len);
        _current += len;
    }

    // Write a raw char array (no length prefix). Zero length only warns;
    // a null pointer with non-zero length throws.
    void Write(const char * array, size_t len)
    {
        Trace trace(Trace::BondDetails, "Write");
        if (len && !array) {
            throw std::invalid_argument("Attempt to write non-zero length char* array from NULL pointer");
        }
        if (!len) {
            Logger::LogWarn("Blob writer asked to write zero-length char array");
            return;
        }
        TRACEINFO(trace, len << " bytes to be written");
        if ((_current + len) > _end) {
            throw std::overflow_error("Bond blob buffer overflow");
        }
        ::memcpy(_current, array, len);
        _current += len;
    }

    // Unsigned-char twin of the overload above.
    void Write(const unsigned char * array, size_t len)
    {
        Trace trace(Trace::BondDetails, "Write");
        if (len && !array) {
            throw std::invalid_argument("Attempt to write non-zero length unsigned char* array from NULL pointer");
        }
        if (!len) {
            Logger::LogWarn("Blob writer asked to write zero-length unsigned char array");
            return;
        }
        TRACEINFO(trace, len << " bytes to be written");
        if ((_current + len) > _end) {
            throw std::overflow_error("Bond blob buffer overflow");
        }
        ::memcpy(_current, array, len);
        _current += len;
    }

    // Trailing sentinel marker for the blob.
    void WriteSuffix() { Write(0xdeadc0dedeadc0de); }

private:
    unsigned char* _buffer;     // Start of owned buffer (nullptr when empty/cleared)
    unsigned char* _end;        // One past the last usable byte
    unsigned char* _current;    // Next write position

    // Debug helper: dump the three pointers.
    // NOTE(review): the static_cast target types were lost in extraction.
    void dumpstate(std::ostream& strm)
    {
        strm << "_buffer=" << static_cast(_buffer);
        strm << " _current=" << static_cast(_current);
        strm << " _end=" << static_cast(_end);
    }
};

#endif // _MDSBLOBOUTPUTTER_HH

// vim: se expandtab sw=4 :

================================================ FILE: Diagnostic/mdsd/mdsd/MdsEntityName.cc ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "MdsEntityName.hh"
#include "MdsdConfig.hh"
#include "Credentials.hh"
#include "Crypto.hh"
#include "Utility.hh"
#include "Logger.hh"
#include "Trace.hh"
// NOTE(review): one system-header name lost in extraction.
#include

using std::string;

// MdsEntityName for SchemasTable in the account identified by these creds
MdsEntityName::MdsEntityName(const MdsdConfig *config, const Credentials *creds)
        : _creds(creds)
{
        Trace trace(Trace::EntityName, "MdsEntityName constructor for SchemasTable");
        if (!config) {
                throw std::invalid_argument("Internal error: null config ptr");
        } else if (!creds) {
                throw std::invalid_argument("Internal error: null credentials");
        }
        _storeType = StoreType::XTable;
        _physTableName = _basename = "SchemasTable";
        _isConstant = true;     // SchemasTable never gets a per-NDay suffix
        _isSchemasTable = true;
}

// Constructor for arbitrary table in some store (local or remote) accessed via a specific moniker.
MdsEntityName::MdsEntityName(const std::string &eventName, bool noPerNDay, const MdsdConfig *config,
        const std::string &acct, StoreType::Type sinkType, bool isFullName)
        : _basename(eventName), _isConstant(true), _isSchemasTable(false), _storeType(sinkType), _creds(nullptr),
          _physTableName(eventName), _eventName(eventName), _eventVersion(config->EventVersion())
{
        Trace trace(Trace::EntityName, "MdsEntityName constructor");
        // NOTE(review): config is already dereferenced in the initializer list
        // above (config->EventVersion()) before the null-config check further
        // down runs — a null config would crash here first.
        if (eventName.empty()) {
                throw std::invalid_argument("eventName must not be empty");
        }
        auto maxNameLength = StoreType::max_name_length(_storeType);
        if (sinkType == StoreType::Type::Local || sinkType == StoreType::Type::File) {
                // Local table names never get encoded/shortened.
                // Also, they need no credentials and no MdsdConfig
                // (leading "//" restored; it was lost in extraction)
                if (_basename.length() > maxNameLength) {
                        std::ostringstream msg;
                        msg << "Event name \"" << _basename << "\" is too long for requested storeType (max "
                            << maxNameLength << " bytes)";
                        throw std::invalid_argument(msg.str());
                }
                if (trace.IsActive()) {
                        std::ostringstream msg;
                        msg << "Local/File EventName \"" << eventName << "\" yields basename \"" << _basename
                            << "\" and _isConstant=" << _isConstant;
                        trace.NOTE(msg.str());
                }
                return;
        }
        if (!config) {
                throw std::invalid_argument("Internal error: null config ptr");
        }
        // Resolve credentials: the explicit moniker if given, else the default.
        if (acct.empty()) {
                if (! (_creds = config->GetDefaultCredentials())) {
                        throw std::invalid_argument("No default credentials were defined");
                }
        } else {
                if (! (_creds = config->GetCredentials(acct))) {
                        throw std::invalid_argument("No definition found for account moniker " + acct);
                }
        }
        // The access credentials can influence how the actual name of the entity is computed, so
        // we have to look inside.
        if (isFullName && noPerNDay) {
                _isConstant = true;
                trace.NOTE("Marked as isFullName without NDay suffix");
        } else if (_creds->accessAnyTable()) {
                std::ostringstream augmentedName;
                if (isFullName) {
                        augmentedName << eventName;
                        trace.NOTE("Marked as isFullName and gets NDay suffix");
                } else {
                        // Namespace + event + "Ver<N>v0" versioned form.
                        augmentedName << config->Namespace() << eventName << "Ver" << config->EventVersion() << "v0";
                }
                _basename = _physTableName = augmentedName.str();
                _isConstant = noPerNDay;        // This name might vary
                // The basename plus perNDay suffix (if any) must fit within the maximum entity name size
                // for MDS. If it doesn't, replace the basename with "T" followed by the MD5 hash of the
                // basename (without perNDay suffix), which is always short enough.
// See Windows MA source NetTransport.cpp:GetNDayEventName() size_t limit = maxNameLength - (_isConstant?0:8); if (_basename.size() > limit) { trace.NOTE("Basename " + _basename + " too long; using MD5 hash"); _basename = "T" + Crypto::MD5HashString(_basename).to_string(); } } else if (auto SAScreds = dynamic_cast(_creds)) { if (!isFullName) { std::ostringstream augmentedName; augmentedName << config->Namespace() << eventName << "Ver" << config->EventVersion() << "v0"; _physTableName = augmentedName.str(); } // SAS (non-account SAS) includes the tablename; extract it from there. Even if isFullName is set, we have to try this std::map qry; MdsdUtil::ParseQueryString(SAScreds->Token(), qry); auto item = qry.find("tn"); if (item != qry.end()) { _basename = item->second; } else if (!SAScreds->IsAccountSas()) { // We'll just use what we were given; it'll probably fail later, too. Logger::LogError("Table SAS lacks [tn=]: " + SAScreds->Token()); } } if (trace.IsActive()) { std::ostringstream msg; msg << "EventName \"" << eventName << "\" yields basename \"" << _basename<< "\", physTableName \""; msg << _physTableName << "\", and _isConstant=" << _isConstant; trace.NOTE(msg.str()); } } std::string MdsEntityName::Name() const { Trace trace(Trace::EntityName, "MdsEntityName::Name"); if (_isConstant) { trace.NOTE("Using " + _basename); return _basename; } std::string fullname = _basename + MdsdUtil::GetTenDaySuffix(); trace.NOTE("Computed table name " + fullname); return fullname; } std::ostream& operator<<(std::ostream &str, const MdsEntityName &target) { switch(target._storeType) { case StoreType::None: str << "[None]"; break; case StoreType::XTable: str << "[XTable]"; break; case StoreType::Local: str << "[Local]"; break; case StoreType::File: str << "[File]"; break; default: str << "[unknown]"; break; } str << target._basename; if (! 
            target._isConstant) {
                str << "*";
        }
        return str;
}

// vim: se sw=8 :

================================================ FILE: Diagnostic/mdsd/mdsd/MdsEntityName.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _MDSENTITYNAME_HH_
#define _MDSENTITYNAME_HH_

// NOTE(review): system-header names lost in extraction.
#include #include
#include "MdsTime.hh"
#include "StoreType.hh"

class MdsdConfig;
class Credentials;

// Repository of metadata about the MDS target we're writing. Might be a server-side XTable
// or Bond blob; might be a local table. This object knows the name of that thing, the kind of
// thing it is, and has a pointer to the access credentials (if any) needed to talk to it.
class MdsEntityName {
        friend std::ostream& operator<<(std::ostream &str, const MdsEntityName &target);
public:
        // SchemasTable accessible with these creds
        MdsEntityName(const MdsdConfig *config, const Credentials *creds);
        // This arbitrary MDS entity (table, blob, whatever)
        MdsEntityName(const std::string &eventName, bool noPerNDay, const MdsdConfig *config,
                const std::string &acct, StoreType::Type sinkType, bool isFullName=false);
        // Require autogenerated move-assignment and copy/move constructor
        MdsEntityName& operator=(MdsEntityName &&orig) = default;
        MdsEntityName(const MdsEntityName&) = default;
        MdsEntityName(MdsEntityName&&) = default;

        // Compute the XStore table name to be written to right now, at this instant.
        std::string Name() const;
        // The XStore table "family" name, i.e. without 10day suffix.
        std::string Basename() const { return _basename; }
        // The full-length table name, without 10day suffix, as it would appear in various
        // MDS tools. This can be longer than the 64-char max for XStore table names.
        std::string PhysicalTableName() const { return _physTableName; }
        // Get the original Eventname
        std::string EventName() const { return _eventName; }
        /// Get the original EventVersion
        int EventVersion() const { return _eventVersion; }
        // True if the table name never changes (e.g. no 10day suffix).
        bool IsConstant() const { return _isConstant; }
        bool IsSchemasTable() const { return _isSchemasTable; }
        StoreType::Type GetStoreType() const { return _storeType; }
        const Credentials* GetCredentials() const { return _creds; }

private:
        // The tablename, with version suffix but without the 10-day suffix, as used when writing
        // to XStore. If the name is "too long", this is the MD5-hashed name.
        std::string _basename;
        bool _isConstant;
        bool _isSchemasTable;
        StoreType::Type _storeType;
        const Credentials* _creds;
        // This form of the name is used in the PhysicalTableName column of SchemasTable. It
        // is identical to _basename except
        // when the name is too long, _basename is hashed, _physTableName is the unhashed,
        // very long form of the name.
        // Despite what the column (and this variable) are called,
        // this name is not the name of the actual physical table.
        std::string _physTableName;
        std::string _eventName;         // save the original event name
        int _eventVersion;              // save the original event version
        // const size_t MaxEntityNameLength = 63;
};

std::ostream& operator<<(std::ostream &str, const MdsEntityName &target);

#endif // _MDSENTITYNAME_HH_

// vim: se sw=8 :

================================================ FILE: Diagnostic/mdsd/mdsd/MdsSchemaMetadata.cc ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "MdsSchemaMetadata.hh"
#include "Crypto.hh"
#include "IdentityColumns.hh"
#include "TableSchema.hh"
#include "Trace.hh"
#include "MdsEntityName.hh"
#include "CanonicalEntity.hh"
// NOTE(review): five system-header names lost in extraction.
#include #include #include #include #include

using std::string;
using std::vector;

// Process-wide cache of computed schema metadata, keyed by the canonicalized
// schema string. NOTE(review): the map's template arguments were stripped by
// extraction; tokens kept as found.
std::map MdsSchemaMetadata::_cache;
std::mutex MdsSchemaMetadata::_mutex;

#define STRINGPAIR(a,b) std::make_pair(string(a),string(b))

typedef std::pair coldata_t;

// Columns handled specially; excluded from the per-event data columns below.
const std::unordered_set MdsSchemaMetadata::MetadataColumns {
        "TIMESTAMP", "PreciseTimeStamp", "PartitionKey", "RowKey", "N", "RowIndex"
};

// Given a set of destination metadata and a CanonicalEntity, build the metadata MDS needs
// to interpret the destination object (table, Bond blob, etc.)
MdsSchemaMetadata* MdsSchemaMetadata::GetOrMake(const MdsEntityName &target, const CanonicalEntity* ce)
{
        if (!ce) {
                return nullptr;
        }
        vector unsortedSchema;
        unsortedSchema.reserve(ce->size() + 6); // data columns plus up to 6 specials
        // First, the timestamps...
        unsortedSchema.push_back(STRINGPAIR("TIMESTAMP", "mt:utc"));
        unsortedSchema.push_back(STRINGPAIR("PreciseTimeStamp", "mt:utc"));
        // Next, the data and identity columns (the identity columns are expected to have
        // already been added by this point). Ignore any of the "special" columns.
        for (const auto & col : *ce) {
                if (! MetadataColumns.count(col.first)) {
                        unsortedSchema.push_back(STRINGPAIR(col.first, col.second->TypeToString()));
                }
        }
        // XTable targets get some extra metadata
        if (target.GetStoreType() == StoreType::Type::XTable) {
                unsortedSchema.push_back(STRINGPAIR("PartitionKey", "mt:wstr"));
                unsortedSchema.push_back(STRINGPAIR("RowKey", "mt:wstr"));
                unsortedSchema.push_back(STRINGPAIR("N", "mt:wstr"));
                unsortedSchema.push_back(STRINGPAIR("RowIndex", "mt:wstr"));
        }
        return GetOrMake(unsortedSchema);
}

// Given a vector of pairs,
// build the MDS table metadata (XML-format schema and MD5 hash of canonicalized schema).
MdsSchemaMetadata* MdsSchemaMetadata::GetOrMake(vector& schema)
{
        // NOTE(review): the XML fragments inside the string literals below
        // lost their angle-bracketed tag text in extraction (the "elements"
        // append once emitted a Column element with name/type attributes, and
        // "xmldata" wrapped them in a root element); tokens kept as found.
        string elements;
        for (auto it = schema.cbegin(); it != schema.cend(); ++it) {
                elements += "first + "\" type=\"" + it->second + "\">";
        }
        // Canonicalize: sort columns by name, then join name,type pairs with commas.
        std::sort(schema.begin(), schema.end(),
                [](coldata_t left, coldata_t right) -> bool { return (left.first.compare(right.first) < 0); }
        );
        int columnCount = schema.size();
        string schemaForMD5;
        for (int i = 0; i < columnCount; ++i) {
                schemaForMD5 += schema[i].first + "," + schema[i].second;
                if (i < (columnCount-1)) {
                        schemaForMD5 += ",";
                }
        }
        string md5 = Crypto::MD5HashString(schemaForMD5).to_string();

        std::lock_guard lock(_mutex);   // Take lock on _cache; lock is released at function return
        auto it = _cache.find(schemaForMD5);
        if (it != _cache.end()) {
                return it->second;
        }
        // Lock contention is rare, hits are common, and this string can
        // get moderately large. Deferring assembly until needed should save time in the long run.
        string xmldata = "";
        xmldata += elements;
        xmldata += "";
        _cache[schemaForMD5] = new MdsSchemaMetadata(move(xmldata), move(md5), columnCount);
        return _cache[schemaForMD5];    // Be sure to return the address of the object in the cache
}

#ifdef DOING_MEMCHECK
// Remove everything from the cache.
void MdsSchemaMetadata::ClearCache()
{
        Trace trace(Trace::ConfigLoad, "MdsSchemaMetadata::ClearCache");
        std::lock_guard lock(_mutex);
        size_t count = 0;
        for (auto entry : _cache) {
                delete entry.second;
                count++;
        }
        _cache.clear();
        trace.NOTE("Deleted " + std::to_string(count) + " MdsSchemaMetadata objects from cache");
}
#endif

// vim: se sw=8 :

================================================ FILE: Diagnostic/mdsd/mdsd/MdsSchemaMetadata.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once #ifndef _MDSSCHEMAMETADATA_HH_ #define _MDSSCHEMAMETADATA_HH_ #include #include #include #include #include #include #include "Crypto.hh" #include "IdentityColumns.hh" #include "MdsEntityName.hh" class TableSchema; class CanonicalEntity; class MdsSchemaMetadata { public: typedef std::pair coldata_t; static const std::unordered_set MetadataColumns; // Check cache for schema; if it exists, return pointer. Otherwise, create it, add it to cache, and return pointer. static MdsSchemaMetadata* GetOrMake(const MdsEntityName &target, const CanonicalEntity* ce); const std::string& GetXML() const { return _xmldata; } const std::string& GetMD5() const { return _md5; } size_t GetSize() const { return _size; } #ifdef DOING_MEMCHECK static void ClearCache(); #endif private: const std::string _xmldata; // The MDS SchemasTable "Schema" column representation const std::string _md5; // The MD5 checksum of the canonicalized schema const size_t _size; // The number of columns, including identity columns and everything else MdsSchemaMetadata(std::string&& x, std::string&& m, size_t s) : _xmldata(x), _md5(m), _size(s) {} MdsSchemaMetadata() = delete; // No default constructor static MdsSchemaMetadata* GetOrMake(std::vector&); // Maps from canonical name/type list to the object static std::map _cache; static std::mutex _mutex; // Ensures access to the cache is serialized }; #endif // _MDSSCHEMAMETADATA_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/MdsValue.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "MdsValue.hh"
// NOTE(review): system/library header names on the next line were lost in
// extraction — TODO recover from the original MdsValue.cc.
#include #include #include #include #include #include #include #include #include #include "Utility.hh"
#include "cpprest/json.h"
// Copy constructor
// Deep-copies the discriminated union; only the mt_wstr arm owns heap storage.
MdsValue::MdsValue(const MdsValue& src) : type(src.type) {
    switch(type) {
    case mt_bool: bval = src.bval; break;
    case mt_int32: lval = src.lval; break;
    case mt_int64: llval = src.llval; break;
    case mt_float64: dval = src.dval; break;
    case mt_wstr: strval = new std::string(*(src.strval)); break; // own a fresh copy of the string
    case mt_utc: datetimeval = src.datetimeval; break;
    default: throw std::logic_error("Attempt to copy MdsValue of unknown type");
    }
}
// Constructor for MdsTime
MdsValue::MdsValue(const MdsTime& val) { type = mt_utc; datetimeval = val.to_pplx_datetime(); }
// Constructor for mi::Datetime
MdsValue::MdsValue(const mi::Datetime& x) { *this = MdsValue(MdsTime(x)); }
// Move assignment operator
// Steals the string pointer on the mt_wstr arm and nulls it in the source so
// the source's destructor won't free it.
// NOTE(review): if *this already holds an mt_wstr, the previous strval is not
// deleted before being overwritten — potential leak; confirm callers only
// move-assign into freshly-constructed/non-string values.
MdsValue& MdsValue::operator=(MdsValue&& src) {
    type = src.type;
    switch(type) {
    case mt_bool: bval = src.bval; break;
    case mt_int32: lval = src.lval; break;
    case mt_int64: llval = src.llval; break;
    case mt_float64: dval = src.dval; break;
    case mt_wstr: strval = src.strval; src.strval = nullptr; break;
    case mt_utc: datetimeval = src.datetimeval; break;
    default: throw std::logic_error("Attempt to move-assign MdsValue of unknown type");
    }
    return *this;
}
// Convert a JSON integer holding whole seconds-since-epoch to a new mt_utc
// MdsValue. Returns 0 (null) if the node is not a number or out of range.
MdsValue* MdsValue::time_t_to_utc(cJSON* src) {
    if (src->type != cJSON_Number) return 0;
    if (src->valueint > LONG_MAX) return 0;
    return new MdsValue(MdsTime(src->valueint, 0));
}
// Convert a JSON double holding fractional seconds-since-epoch to mt_utc.
// Splits into whole seconds plus microseconds; rejects negatives and overflow.
MdsValue* MdsValue::double_time_t_to_utc(cJSON* src) {
    if (src->type != cJSON_Number) return 0;
    if (src->valuedouble > double(LONG_MAX) || src->valuedouble < 0.) return 0;
    long sec = int(floor(src->valuedouble));
    long fraction = int(floor(1000000. * (src->valuedouble - floor(src->valuedouble))));
    return MdsValue::sec_usec_to_utc(sec, fraction);
}
// Convert a JSON string holding an RFC 3339 datetime to mt_utc; returns 0 on
// any parse failure.
MdsValue* MdsValue::rfc3339_to_utc(cJSON* src) {
    if (src->type != cJSON_String) return 0;
    size_t n = strlen(src->valuestring);
    if (n < 19) return 0; // Minimum legal length of an RFC 3339 datetime string
    long tv_sec = 0, tv_usec = 0;
    if (!MdsdUtil::TimeValFromIso8601Restricted(src->valuestring, tv_sec, tv_usec)) return 0;
    return MdsValue::sec_usec_to_utc(tv_sec, tv_usec);
}
// In-place multiply a numeric value by factor, promoting integer arms to
// mt_float64; silently a no-op for non-numeric types.
void MdsValue::scale(double factor) {
    switch(type) {
    case mt_bool: case mt_wstr: case mt_utc: default: break; // non-numeric: no-op
    case mt_int32: dval = factor * ((double)lval); type = mt_float64; break;
    case mt_int64: dval = factor * ((double)llval); type = mt_float64; break;
    case mt_float64: dval = factor * dval; break;
    }
}
// The OMI conversions are mostly mechanical, but templatizing them is pretty ugly due to the discriminated
// unions in the MI_Value and MdsValue objects. It's
// easy enough to use a macro to generate the common case:
//
// case MI_BOOLEAN:
// type = mt_bool;
// bval = (bool) value.boolean;
// break;
// case MITYPE:
// type = MTTYPE;
// MEMBER = (CTYPE)(value.UNIONARM);
#define CVTUNARY(MITYPE, MTTYPE, MEMBER, CTYPE, UNIONARM) case MITYPE: type = MTTYPE; MEMBER = (CTYPE)(value.UNIONARM); break;
// Arrays are a bit easier via macro; the MTTYPE, MEMBER, and CTYPE always correspond to strings.
// Render an OMI array union arm as a comma-separated string; the caller owns
// the returned heap string.
// NOTE(review): the template parameter list was lost in extraction
// (presumably template <typename ARRTYPE>).
template static std::string * OMIarray2string(ARRTYPE arm) {
    std::ostringstream result;
    for (MI_Uint32 idx = 0; idx < arm.size; idx++) {
        auto val = arm.data[idx];
        if (idx) { result << ", "; }
        result << val;
    }
    return new std::string(result.str());
}
#define CVTARRAY(MITYPE, TYPE, UNIONARM) case MITYPE: type=mt_wstr; strval=OMIarray2string(value.UNIONARM); break;
// And there are some exceptions to the pattern that need to be handled explicitly.
// Construct an MdsValue from an OMI MI_Value discriminated by fieldtype.
// Scalar arms map to bool/int32/int64/float64; array arms are flattened into
// a comma-separated mt_wstr; instance/reference types are rejected.
MdsValue::MdsValue(const MI_Value& value, MI_Type fieldtype) {
    switch(fieldtype) {
    CVTUNARY(MI_BOOLEAN, mt_bool, bval, bool, boolean)
    CVTUNARY(MI_SINT8, mt_int32, lval, long, sint8)
    CVTUNARY(MI_UINT8, mt_int32, lval, long, uint8)
    CVTUNARY(MI_SINT16, mt_int32, lval, long, sint16)
    CVTUNARY(MI_UINT16, mt_int32, lval, long, uint16)
    CVTUNARY(MI_SINT32, mt_int32, lval, long, sint32)
    CVTUNARY(MI_UINT32, mt_int64, llval, long long, uint32) // uint32 widened to int64 to avoid overflow
    CVTUNARY(MI_SINT64, mt_int64, llval, long long, sint64)
    CVTUNARY(MI_UINT64, mt_int64, llval, long long, uint64)
    CVTUNARY(MI_REAL32, mt_float64, dval, double, real32)
    CVTUNARY(MI_REAL64, mt_float64, dval, double, real64)
    CVTUNARY(MI_CHAR16, mt_int32, lval, long, char16)
    CVTARRAY(MI_BOOLEANA, MI_BooleanA, booleana)
    CVTARRAY(MI_SINT8A, MI_Sint8A, sint8a)
    CVTARRAY(MI_UINT8A, MI_Uint8A, uint8a)
    CVTARRAY(MI_SINT16A, MI_Sint16A, sint16a)
    CVTARRAY(MI_UINT16A, MI_Uint16A, uint16a)
    CVTARRAY(MI_SINT32A, MI_Sint32A, sint32a)
    CVTARRAY(MI_UINT32A, MI_Uint32A, uint32a)
    CVTARRAY(MI_SINT64A, MI_Sint64A, sint64a)
    CVTARRAY(MI_UINT64A, MI_Uint64A, uint64a)
    CVTARRAY(MI_REAL32A, MI_Real32A, real32a)
    CVTARRAY(MI_REAL64A, MI_Real64A, real64a)
    CVTARRAY(MI_CHAR16A, MI_Char16A, char16a)
    case MI_DATETIME: *this = MdsValue(MdsTime(value.datetime)); break;
    case MI_STRING: type = mt_wstr; strval = new std::string(value.string); break;
    // Datetime arrays: comma-separated MdsTime renderings.
    case MI_DATETIMEA: { type = mt_wstr; std::ostringstream result; for (MI_Uint32 idx = 0; idx < value.datetimea.size; idx++) { if (idx) { result << ", "; } result << MdsTime(value.datetimea.data[idx]); } strval = new std::string(result.str()); break; }
    // String arrays: comma-separated copies of each element.
    case MI_STRINGA: { type = mt_wstr; std::ostringstream result; for (MI_Uint32 idx = 0; idx < value.stringa.size; idx++) { if (idx) { result << ", "; } result << std::string(value.stringa.data[idx]); } strval = new std::string(result.str()); break; }
    case MI_INSTANCE: case MI_REFERENCE: case MI_INSTANCEA: case MI_REFERENCEA: throw std::runtime_error("MdsValue asked to convert instance/reference");
    default: throw std::runtime_error("MdsValue asked to convert unknown MI_Type");
    }
}
#if 0
std::string MdsValue::omi_time_to_string(const mi::Datetime& x) {
    MI_Uint32 y,mon,d,h,min,s,us;
    MI_Sint32 utc;
    x.Get(y,mon,d,h,min,s,us,utc);
    struct tm t;
    t.tm_year = y-1900;
    t.tm_mon = mon-1;
    t.tm_mday = d;
    t.tm_hour = h;
    t.tm_min = min;
    t.tm_sec = s;
    t.tm_isdst = -1; // let mktime() to decide daylight saving adjustment
    time_t time1 = mktime(&t);
    long sec = (long)(time1 + 60 * utc);
    long usec = (long)us;
    return sec_usec_to_string(sec, usec);
}
#endif
// Render this value via operator<< (type-tagged debug form).
std::string MdsValue::ToString() const { std::ostringstream s; s << *this; return s.str(); }
// Coerce a numeric (or numeric-string) value to double; throws
// std::domain_error for types with no numeric interpretation.
// NOTE(review): the lexical_cast template argument was lost in extraction
// (presumably boost::lexical_cast<double>).
double MdsValue::ToDouble() const {
    switch(type) {
    case mt_int32: return (double) lval;
    case mt_int64: return (double) llval;
    case mt_float64: return dval;
    case mt_wstr:
        try { return boost::lexical_cast(*strval); }
        catch(const boost::bad_lexical_cast &) { throw std::domain_error("Value is a string which is not a valid floating-point number"); }
    case mt_utc: case mt_bool: default: throw std::domain_error("Value is not a type which can be converted to float");
    }
}
// Debug stream renderer: each arm is prefixed with its type tag.
std::ostream& operator<<(std::ostream& os, const MdsValue& mv) {
    switch(mv.type) {
    case MdsValue::MdsType::mt_bool: if (mv.bval) { os << "true"; } else { os << "false"; } break;
    case MdsValue::MdsType::mt_int32: os << "(int32)" << mv.lval; break;
    case MdsValue::MdsType::mt_int64: os << "(int64)" << mv.llval; break;
    case MdsValue::MdsType::mt_float64: os << "(float64)" << mv.dval; break;
    case MdsValue::MdsType::mt_wstr: os << "(wstr)\"" << *(mv.strval) << "\""; break;
    case MdsValue::MdsType::mt_utc: os << "(utc)[" << mv.datetimeval.to_string(utility::datetime::ISO_8601) << "]"; break;
    default: os << "(no type)"; break;
    }
    return os;
}
// Serialize the value as a JSON scalar (cpprest web::json); int64 is passed
// through as an integer, utc as an ISO-8601 string.
std::string MdsValue::ToJsonSerializedString() const {
    web::json::value jsonValue;
    switch(type) {
    case MdsValue::MdsType::mt_bool: jsonValue = web::json::value(bval); break;
    case MdsValue::MdsType::mt_int32: jsonValue = web::json::value(lval); break;
    case MdsValue::MdsType::mt_int64: jsonValue = web::json::value((int64_t)llval); break;
    case MdsValue::MdsType::mt_float64: jsonValue = web::json::value(dval); break;
    case MdsValue::MdsType::mt_wstr: jsonValue = web::json::value(*strval); break;
    case MdsValue::MdsType::mt_utc: jsonValue = web::json::value(datetimeval.to_string(utility::datetime::ISO_8601)); break;
    default: throw std::logic_error("Attempt to get JSON value string of unknown type");
    }
    return jsonValue.serialize();
}
// Render the discriminator as the MDS "mt:*" type-name string.
std::string MdsValue::TypeToString() const {
    switch(type) {
    case mt_bool: return "mt:bool";
    case mt_int32: return "mt:int32";
    case mt_int64: return "mt:int64";
    case mt_float64: return "mt:float64";
    case mt_wstr: return "mt:wstr";
    case mt_utc: return "mt:utc";
    }
    throw std::logic_error("Attempt to convert unknown MDS type to string");
}
// Render an mi::Array as a comma-separated heap string; caller owns result.
// NOTE(review): template parameter lists lost in extraction (presumably
// template <typename T> ... const mi::Array<T>&).
template std::string * MdsValue::Array2Str(const mi::Array& arr) {
    auto str = new std::string;
    for (MI_Uint32 i = 0; i < arr.GetSize(); i++) {
        T x = arr[i];
        if (i) { str->append(", "); }
        str->append(std::to_string(x));
    }
    return str;
}
// vim: se sw=8 :
// ================================================ FILE: Diagnostic/mdsd/mdsd/MdsValue.hh ================================================
// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license.
#pragma once
#ifndef _MDSVALUE_HH_
#define _MDSVALUE_HH_
// NOTE(review): system-header names on the next line were lost in extraction.
#include #include #include #include #include #include #include
//#include
extern "C" {
#include "cJSON.h"
}
#include "MdsTime.hh"
// Tagged-union value holding one MDS cell: bool, int32, int64, float64,
// string (heap-owned), or UTC timestamp.
class MdsValue {
    friend std::ostream& operator<<(std::ostream& os, const MdsValue& mv);
public:
    enum MdsType { mt_bool, mt_wstr, mt_float64, mt_int32, mt_int64, mt_utc };
    MdsType type; // discriminator for the union below
    union { bool bval; long lval; long long llval; double dval; utility::datetime datetimeval; const std::string * strval; };
    ~MdsValue() { if ((type == mt_wstr) && strval) { delete strval; } } // only the mt_wstr arm owns heap memory
    // Type converters.
    // These all return a new MdsValue, copied from the original input,
    // which the caller will be expected to delete.
    MdsValue(bool v) { type = mt_bool; bval = v; }
    MdsValue(long v) { type = mt_int32; lval = v; }
    MdsValue(long long v) { type = mt_int64; llval = v; }
    MdsValue(double v) { type = mt_float64; dval = v; }
    MdsValue(utility::datetime v) { type = mt_utc; datetimeval = v; }
    MdsValue(const std::string& v) { type = mt_wstr; strval = new std::string(v); }
    MdsValue(std::string&& v) { type = mt_wstr; strval = new std::string(std::move(v)); }
    MdsValue(const char * v) { type = mt_wstr; strval = new std::string(v); }
    MdsValue(const std::ostringstream & str) { type = mt_wstr; strval = new std::string(str.str()); }
    MdsValue(const MdsTime&);
    MdsValue(const mi::Datetime&);
    MdsValue(const MI_Value&, mi::Type);
    MdsValue(const MdsValue&); // Copy constructor
    MdsValue(MdsValue&&) = delete; // No move-constructor
    MdsValue* operator=(const MdsValue&) = delete; // No copy-assignment
    MdsValue& operator=(MdsValue&&); // Move assignment
    // cJSON -> MdsValue time converters; each returns a new heap object or 0
    // on conversion failure (see MdsValue.cc).
    static MdsValue* time_t_to_utc(cJSON* src);
    static MdsValue* double_time_t_to_utc(cJSON* src);
    static MdsValue* sec_usec_to_utc(long sec, long fraction) { return new MdsValue(MdsTime(sec, fraction)); }
    static MdsValue* rfc3339_to_utc(cJSON* src);
    // In-place, apply a scale factor to the numeric value. Silently do nothing if the
    // value is non-numeric.
    void scale(double);
    bool IsString() const { return (type == mt_wstr); }
    bool IsNumeric() const { return (type == mt_float64 || type == mt_int32 || type == mt_int64); }
    std::string ToString() const;
    std::string ToJsonSerializedString() const;
    double ToDouble() const;
    std::string TypeToString() const;
private:
    MdsValue(); // No void constructor (no "NULL" objects)
    //static std::string omi_time_to_string(const mi::Datetime& x);
    //static std::string sec_usec_to_string(long sec, long fraction);
    // NOTE(review): template parameter lists lost in extraction.
    template static std::string * Array2Str(const mi::Array&);
};
// Signature of a cJSON-to-MdsValue converter function.
// NOTE(review): std::function template argument lost in extraction.
typedef std::function typeconverter_t;
std::ostream& operator<<(std::ostream& os, const MdsValue& mv);
#endif //_MDSVALUE_HH_
// vim: se sw=8 :
// ================================================ FILE: Diagnostic/mdsd/mdsd/MdsdConfig.cc ================================================
// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license.
// NOTE(review): system-header names in the include lines below were lost in extraction.
#include #include #include "MdsdConfig.hh"
#include "CfgCtxRoot.hh"
#include "ConfigParser.hh"
#include "TableSchema.hh"
#include "Subscription.hh"
#include "Batch.hh"
#include "Credentials.hh"
#include "OmiTask.hh"
#include "MdsdExtension.hh"
#include "ITask.hh"
#include "Crypto.hh"
#include "Logger.hh"
#include "Utility.hh"
#include "Trace.hh"
#include "EventHubCmd.hh"
#include "ConfigUpdateCmd.hh"
#include "CmdXmlCommon.hh"
#include "EventHubUploaderId.hh"
#include "EventHubUploaderMgr.hh"
#include "EventHubType.hh"
#include "EventPubCfg.hh"
#include "MdsdEventCfg.hh"
#include "LocalSink.hh"
#include "EventType.hh"
#include #include #include #include #include #include #include #include #include
extern "C" {
#include
}
using std::string; using std::vector; using std::pair; using std::make_pair;
// The set of batches that aren't associated with any particular config instance. (Thus the
// nullptr initializer.)
//
// This global static could be associated with the BatchSet class just as easily as the
// MdsdConfig class.
BatchSet MdsdConfig::_localBatches { nullptr };
// Construct a configuration from the given XML file path. Member
// initializers wire the batch set and the two asio timers to the shared
// threadpool's io_service, then LoadFromConfigFile() parses the file.
// NOTE(review): make_shared template arguments were lost in extraction.
MdsdConfig::MdsdConfig(string path, string autokeyConfigPath) : configFilePath(path), _autokeyConfigFilePath(autokeyConfigPath), eventVersion(1), _isUseful(false), _defaultCreds(nullptr), _batchSet(this), _batchFlushTimer(crossplat::threadpool::shared_instance().service()), _agentIdentity(MdsdUtil::GetHostname()), _autoKeyReloadTimer(crossplat::threadpool::shared_instance().service()), _monitoringManagementSeen(false), _hasAutoKey(false), _mdsdEventCfg(std::make_shared()), _eventPubCfg(std::make_shared(_mdsdEventCfg)) { LoadFromConfigFile(path); }
// Post-parse activation: set up Event Hub publishing and kick off the
// periodic batch-flush cycle.
void MdsdConfig::Initialize() {
    Trace trace(Trace::ConfigLoad, "MdsdConfig Initialize");
    InitEventHubPub();
    FlushBatches(boost::system::error_code()); // Also schedules the next flush
}
// No autokey support.
bool MdsdConfig::LoadAutokey(const boost::system::error_code &e) { Trace trace(Trace::Credentials, "LoadAutoKey"); return false; }
// there could be multiple monikers pointing to different storage accounts
// pair: first=moniker, second=container SAS
std::vector> MdsdConfig::ExtractCmdContainerAutoKeys() {
    Trace trace(Trace::Credentials, "GetContainerCred");
    auto rootContainer = mdsd::CmdXmlCommon::GetRootContainerName();
    std::vector> keylist;
    std::unique_lock lock(_ehMapMutex);
    // Collect every (moniker, SAS) whose container matches the command root container.
    for (const auto & iter : _autoKeyMap) { if (rootContainer == iter.first.second) { keylist.push_back(std::make_pair(iter.first.first, iter.second)); } }
    lock.unlock();
    // Get default account to use: either the default credential or the first credential
    // NOTE(review): credentials.begin() is dereferenced without an empty-map
    // check — undefined behavior if no credentials were configured; confirm
    // callers guarantee at least one credential.
    Credentials* cred = _defaultCreds;
    if (!cred) { cred = credentials.begin()->second; }
    if (!cred) { TRACEWARN(trace, "No default account is found. No way to do config auto update."); }
    else {
        // Remember the container SAS belonging to the chosen account, for
        // downloading the config command blob later.
        for (const auto & iter : keylist) { auto moniker = iter.first; if (moniker == cred->Moniker()) { cmdContainerSas = iter.second; break; } }
        if (!cmdContainerSas.empty()) { TRACEINFO(trace, "Found container SAS to download config command blob: " << cmdContainerSas); }
    }
    return keylist;
}
// Propagate moniker remappings discovered in EventHub cmd XML into the event
// configuration (original moniker -> per-event mapped moniker).
void MdsdConfig::SetMappedMoniker( const EventHubSasInfo_t & ehmap ) {
    Trace trace(Trace::Credentials, "SetMappedMoniker");
    for (const auto & ehEntry : ehmap) {
        auto & origMoniker = ehEntry.first;
        auto & itemsMap = ehEntry.second;
        for (const auto & item : (*itemsMap)) {
            auto & eventName = item.first;
            auto & newMoniker = item.second.moniker;
            _mdsdEventCfg->UpdateMoniker(eventName, origMoniker, newMoniker);
        }
    }
}
// For each (moniker, container SAS) pair, download and parse the EventHub cmd
// XML and record the notice/publisher SAS tables; then apply any moniker
// remapping the XML specified.
void MdsdConfig::LoadEventHubKeys( const std::vector>& keylist ) {
    Trace trace(Trace::Credentials, "LoadEventHubKeys");
    for (const auto & iter : keylist) {
        auto & moniker = iter.first; // this is what's in mdsd.xml
        auto & containerSas = iter.second;
        trace.NOTE("Get EventHub cmd XML for moniker " + moniker + ", containerSas " + containerSas);
        if(!_mdsdEventCfg->IsEventHubEnabled(moniker)) { trace.NOTE("Moniker " + moniker + " does not have EventHub"); continue; }
        mdsd::EventHubCmd ehCmd(Namespace(), EventVersion(), containerSas);
        ehCmd.ProcessCmdXml();
        _ehNoticeItemsMap[moniker] = ehCmd.GetNoticeXmlItemsTable();
        _ehPubItemsMap[moniker] = ehCmd.GetPublisherXmlItemsTable();
        trace.NOTE("Successfully get EventHub cmd XML items (that include SAS keys) for moniker " + moniker);
        DumpEventPublisherInfo();
    }
    SetMappedMoniker(_ehNoticeItemsMap);
    SetMappedMoniker(_ehPubItemsMap);
}
// Look up the EventNotice SAS/endpoint items for (moniker, event).
mdsd::EhCmdXmlItems MdsdConfig::GetEventNoticeCmdXmlItems( const std::string & moniker, const std::string & eventName ) {
    Trace trace(Trace::Credentials, "MdsdConfig::GetEventNoticeCmdXmlItems");
    return GetEventHubCmdXmlItems(_ehNoticeItemsMap, moniker, eventName, "EventNotice");
}
// Look up the EventPublish SAS/endpoint items for (moniker, event).
mdsd::EhCmdXmlItems MdsdConfig::GetEventPublishCmdXmlItems( const std::string & moniker, const std::string & eventName ) {
    Trace trace(Trace::Credentials, "MdsdConfig::GetEventPublishCmdXmlItems");
    return GetEventHubCmdXmlItems(_ehPubItemsMap, moniker, eventName, "EventPublish");
}
// Shared lookup helper: returns the EhCmdXmlItems for (moniker, event) from
// the given map, or a default-constructed object (after logging) when absent.
mdsd::EhCmdXmlItems MdsdConfig::GetEventHubCmdXmlItems( EventHubItemsMap_t& ehmap, const std::string & moniker, const std::string & eventName, const std::string & eventType ) {
    Trace trace(Trace::Credentials, "MdsdConfig::GetEventHubCmdXmlItems");
    std::lock_guard lock(_ehMapMutex);
    auto iter = ehmap.find(moniker);
    if (iter == ehmap.end()) { std::ostringstream strm; strm << "Failed to find " << eventType << " SAS & endpoint for moniker=" << moniker; Logger::LogError(strm.str()); return mdsd::EhCmdXmlItems(); }
    auto xmlItemsMap = iter->second;
    auto xmlItemsIter = xmlItemsMap->find(eventName);
    if (xmlItemsIter == xmlItemsMap->end()) { std::ostringstream strm; strm << "Failed to find " << eventType << " SAS & endpoint for event=" << eventName << " (moniker=" << moniker << ")."; Logger::LogError(strm.str()); return mdsd::EhCmdXmlItems(); }
    TRACEINFO(trace, "Found " << eventType << " (SAS & endpoint) for moniker=" << moniker << ", event=" << eventName << ": " << xmlItemsIter->second);
    return xmlItemsIter->second;
}
// Flush the batch set and schedule the next flush. This should be explicitly called
// only once; the method is also the timer-pop handler and thus arranges for itself
// to be called again. The "cancel()" call is a safety measure in case the method is
// called explicitly after loading.
// Timer-pop handler and one-shot kickoff: flush stale batches and re-arm the
// one-minute flush timer unless the timer was cancelled.
void MdsdConfig::FlushBatches(const boost::system::error_code &e) {
    Trace trace(Trace::Scheduler, "MdsdConfig::FlushBatches");
    if (e == boost::asio::error::operation_aborted) { trace.NOTE("Timer cancelled"); }
    else {
        _batchSet.FlushIfStale();
        _batchFlushTimer.expires_from_now(boost::posix_time::minutes(1));
        _batchFlushTimer.async_wait(boost::bind(&MdsdConfig::FlushBatches, this, boost::asio::placeholders::error));
    }
}
// Stop timers that are not related to scheduled tasks:
// _batchFlushTimer, _autoKeyReloadTimer
void MdsdConfig::StopAllTimers() {
    Trace trace(Trace::Scheduler, "MdsdConfig::StopAllTimers");
    _batchFlushTimer.cancel();
    _autoKeyReloadTimer.cancel();
}
// Destructor: stop timers, then delete every owned object group in turn
// (messages, schemas, credentials, OMI tasks, scheduled tasks, extensions),
// tracing counts as it goes.
MdsdConfig::~MdsdConfig() {
    Trace trace(Trace::ConfigLoad, "MdsdConfig Destructor");
    StopAllTimers();
    // Configuration load/parse messages
    size_t count = 0;
    for (Message* msgptr : messages) { delete msgptr; count++; }
    trace.NOTE("Removed " + std::to_string(count) + " messages");
    messages.clear();
    // Configured table schemas (distinct from cached MDS-ready forms of those schemas)
    count = 0;
    for (auto iter : schemas) { count++; std::ostringstream msg; msg << "Deleting TableSchema \"" << iter.first << "\" at address " << iter.second; trace.NOTE(msg.str()); delete iter.second; }
    trace.NOTE("Removed " + std::to_string(count) + " TableSchemas");
    schemas.clear();
    // Credentials
    count = 0;
    for (auto iter : credentials) { count++; std::ostringstream msg; msg << "Deleting Credentials \"" << iter.first << "\" at address " << iter.second; trace.NOTE(msg.str()); delete iter.second; }
    trace.NOTE("Removed " + std::to_string(count) + " Credentials");
    credentials.clear();
    // Event sources
    // Just map source names to TableSchema*, and I've already deleted all the TableSchema objects.
    trace.NOTE("Clearing all source entries");
    sources.clear();
    // OmiTask
    count = 0;
    for (OmiTask* taskptr : _omiTasks) { count++; std::ostringstream msg; msg << "Deleting OmiTask at address " << taskptr; trace.NOTE(msg.str()); taskptr->Cancel(); delete taskptr; }
    trace.NOTE("Removed " + std::to_string(count) + " OmiTask object(s)");
    _omiTasks.clear();
    // ITask
    count = 0;
    for (ITask* taskptr : _tasks) { count++; std::ostringstream msg; msg << "Deleting ITask at address " << taskptr; trace.NOTE(msg.str()); taskptr->cancel(); delete taskptr; }
    trace.NOTE("Removed " + std::to_string(count) + " ITask object(s)");
    _tasks.clear();
    // Mdsd Extensions
    count = 0;
    for (auto & iter : extensions) { count++; std::ostringstream msg; msg << "Deleting MdsdExtension \"" << iter.first << "\" at address " << iter.second; trace.NOTE(msg.str()); delete iter.second; }
    trace.NOTE("Removed " + std::to_string(count) + " MdsdExtension");
    extensions.clear();
    // BatchSet() - gets destroyed when this destructor completes
    // No need to flush; the BatchSet destructor will do that
    // Autokey map contains no pointers so it gets cleaned up correctly when this destructor completes
    trace.NOTE("Clearing autokey map");
    _autoKeyMap.clear();
    _defaultCreds = 0; // Already deleted it while clearing the credentials vector
}
// Parse the config file at path, line by line, feeding each line to the
// ConfigParser; parse/IO problems are recorded via AddMessage. Saves and
// restores currentPath/currentLine so nested loads report correct positions.
void MdsdConfig::LoadFromConfigFile(string path) {
    // Create an appropriate root document context
    CfgCtxRoot root(this);
    // Instantiate a new parser with the context
    ConfigParser parser(&root, this);
    // Open the path
    std::ifstream infile(path);
    if (!infile) { AddMessage(error, "Failed to open config file " + path + " for reading"); return; }
    // Remember where we were when we were asked to load this file
    string previousPath(currentPath);
    long previousLine(currentLine);
    currentPath = path;
    currentLine = 0;
    // Read one line at a time, hand it to the parser's parse_chunk() method
    string line;
    while (std::getline(infile, line)) { NextLine(); parser.ParseChunk(line); }
    if (!infile.eof()) {
        if (infile.bad()) { AddMessage(error, "Corrupted stream"); }
        else if (infile.fail()) { AddMessage(error, "IO operation failed"); }
        else { AddMessage(error, "std::getline returned 0 for unknown reason"); }
    }
    currentPath = previousPath;
    currentLine = previousLine;
}
// Record a parse/load message tagged with the current file position.
void MdsdConfig::AddMessage(severity_t s, const std::string& msg) {
    Message* newmsg = new MdsdConfig::Message(currentPath, currentLine, s, msg);
    messages.push_back(newmsg);
}
// True if any recorded message's severity matches the given bitmask.
bool MdsdConfig::GotMessages(int mask) const {
    for (const auto& msg : messages) { if (msg->severity & mask) { return true; } }
    return false;
}
// Write every mask-matching message to the stream as "file(line) Sev: text".
void MdsdConfig::MessagesToStream(std::ostream& output, int mask) const {
    for (const auto& msg : messages) { if (msg->severity & mask) { output << msg->filename << "(" << msg->line << ") " << SeverityToString(msg->severity) << ": " << msg->msg << "\n"; } }
    output << std::flush;
}
// File scope constants
// NOTE(review): this declaration's terminating ';' begins the next source
// chunk line — statement is intentionally left open here.
static const std::string _str_fatal = "Fatal", _str_error = "Error", _str_warning = "Warning", _str_info = "Info", _str_unknown = "?"
;
// Map a severity_t to its display string (file-scope constants above).
const std::string& MdsdConfig::SeverityToString(MdsdConfig::severity_t severity) const {
    switch (severity) {
    case MdsdConfig::info: return _str_info;
    case MdsdConfig::warning: return _str_warning;
    case MdsdConfig::error: return _str_error;
    case MdsdConfig::fatal: return _str_fatal;
    default: return _str_unknown; // Should never happen
    }
}
// Register a table schema by name; duplicates are logged and deleted (this
// object takes ownership either way).
void MdsdConfig::AddSchema(TableSchema* schema) {
    if (schemas.count(schema->Name())) { AddMessage(error, "Duplicate schema " + schema->Name() + " ignored"); delete schema; }
    else { schemas[schema->Name()] = schema; }
}
// Register credentials by moniker, optionally marking them as the default;
// duplicates and a second "default" are logged as errors.
void MdsdConfig::AddCredentials(Credentials* creds, bool makeDefault) {
    if (credentials.count(creds->Moniker())) { AddMessage(error, "Duplicate creds " + creds->Moniker() + " ignored"); delete creds; return; }
    credentials[creds->Moniker()] = creds;
    if (makeDefault) {
        if (_defaultCreds) { AddMessage(error, "Cannot make " + creds->Moniker() + " default; another is already set"); }
        else { _defaultCreds = creds; }
    }
}
// Map an event source name to a previously-registered schema.
void MdsdConfig::AddSource(const string& source, const string& schema) {
    if (schema.length() > 0 && schemas.count(schema) == 0) { AddMessage(error, "Undefined schema " + schema + " referenced"); }
    else if (sources.count(source)) { AddMessage(error, "Source " + source + " already mapped to a schema; ignored"); }
    else { sources[source] = schemas[schema]; }
}
// Register a source whose schema is determined dynamically at event time.
void MdsdConfig::AddDynamicSchemaSource(const string& source) {
    if (_dynamic_sources.count(source)) { AddMessage(error, "Dynamic Schema Source " + source + " has already been configured; ignored"); }
    else { _dynamic_sources.insert(source); }
}
// Append an identity (name, value) column; returns false on duplicate name.
bool MdsdConfig::AddIdentityColumn(const string& colname, const string& colval) {
    for (auto iter = identityColumns.begin(); iter != identityColumns.end(); ++iter) { if (iter->first == colname) { AddMessage(error, "Ignoring duplicate identity column " + colname); return false; } }
    identityColumns.push_back(make_pair(colname, colval));
    return true;
}
// Copy all identity (name, value) pairs into the caller's container.
// NOTE(review): back_insert_iterator template arguments lost in extraction.
void MdsdConfig::GetIdentityColumnValues(std::back_insert_iterator > > destination) { std::copy(identityColumns.begin(), identityColumns.end(), destination); }
// Emit (name, "mt:wstr") for each identity column — identity values are
// always strings.
void MdsdConfig::GetIdentityColumnTypes(std::back_insert_iterator > > destination) {
    for (auto iter = identityColumns.begin(); iter != identityColumns.end(); ++iter) { destination = make_pair(iter->first, "mt:wstr"); }
}
// Extract the tenant/role/roleInstance identity values by their aliases.
void MdsdConfig::GetIdentityValues(std::string & tenant, std::string& role, std::string& roleInstance) {
    ident_vect_t identityColumns;
    GetIdentityColumnValues(std::back_inserter(identityColumns));
    for (const auto & col : identityColumns) {
        if (col.first.compare(TenantAlias()) == 0) { tenant = col.second; }
        else if (col.first.compare(RoleAlias()) == 0) { role = col.second; }
        else if (col.first.compare(RoleInstanceAlias()) == 0) { roleInstance = col.second; }
    }
}
// Append an envelope column; throws if the name is already present.
void MdsdConfig::AddEnvelopeColumn(std::string && name, std::string && value) {
    for (const EnvelopeColumn & column : _envelopeColumns) { if (column.first == name) { throw std::runtime_error("Column already in envelope"); } }
    _envelopeColumns.emplace_back(name, value);
}
// Invoke process on each envelope column in insertion order.
void MdsdConfig::ForeachEnvelopeColumn(const std::function& process) {
    for (const EnvelopeColumn & column : _envelopeColumns) { process(column); }
}
// Return the schema mapped to a source, or 0 when the source is unknown.
TableSchema* MdsdConfig::GetSchema(const string& source) const {
    const auto &iter = sources.find(source);
    if (iter == sources.end()) { return 0; }
    return iter->second;
}
// Return the credentials for a moniker, or 0 when the moniker is unknown.
Credentials* MdsdConfig::GetCredentials(const string& moniker) const {
    const auto &iter = credentials.find(moniker);
    if (iter == credentials.end()) { return 0; }
    return iter->second;
}
// Thread-safe lookup of the autokey for (moniker, full table name); empty
// string when absent.
std::string MdsdConfig::GetAutokey(const std::string& moniker, const std::string& fullTableName) {
    std::lock_guard lock(_aKMmutex);
    auto iter = _autoKeyMap.find(std::make_pair(moniker, fullTableName));
    if (iter == _autoKeyMap.end()) { return std::string(); }
    return iter->second;
}
// Debug dump of the autokey map keys.
// NOTE(review): the "Dump format: " literal appears truncated by extraction.
void MdsdConfig::DumpAutokeyTable(std::ostream &os) {
    os << "Dump format: " << std::endl;
    for (const auto & iter : _autoKeyMap) { os << "<" << iter.first.first << "," << iter.first.second << ">" << std::endl; }
}
// True when a quota with the given name exists and current exceeds it;
// an unset quota never blocks.
bool MdsdConfig::IsQuotaExceeded(const std::string &name, unsigned long current) const {
    Trace trace(Trace::ConfigUse, "MdsdConfig:IsQuotaExceeded");
    auto iter = _quotas.find(name);
    if (iter == _quotas.end()) { trace.NOTE("Check against unset quota " + name); return false; }
    return (current > iter->second);
}
// Register an OMI polling task (batch creation deferred until start).
void MdsdConfig::AddOmiTask(OmiTask *task) {
    // Defer the creation of the batch; autokey data might not yet be loaded.
    // The task will create the batch when an attempt is made to start it
    _omiTasks.push_back(task);
    _isUseful = true;
}
void MdsdConfig::ForeachOmiTask(const std::function& fn) { std::for_each(_omiTasks.begin(), _omiTasks.end(), fn); }
// Register a generic scheduled task.
void MdsdConfig::AddTask(ITask *task) {
    Trace trace(Trace::Scheduler, "MdsdConfig::AddTask");
    if (trace.IsActive()) { std::ostringstream msg; msg << "Adding task " << task; trace.NOTE(msg.str()); }
    _tasks.push_back(task);
    _isUseful = true;
}
void MdsdConfig::ForeachTask(const std::function& fn) {
    Trace trace(Trace::Scheduler, "MdsdConfig::ForeachTask");
    trace.NOTE("Invoking function on " + std::to_string(_tasks.size()) + " task(s)");
    std::for_each(_tasks.begin(), _tasks.end(), fn);
}
// Register an extension by name; duplicates are logged and deleted (this
// object takes ownership either way). Null pointers are ignored.
void MdsdConfig::AddExtension(MdsdExtension * extension) {
    Trace trace (Trace::ConfigUse, "MdsdConfig::AddExtension");
    if (!extension) { return; }
    const std::string& extname = extension->Name();
    if (extensions.count(extname)) { AddMessage(error, "Duplicate Extension " + extname + " ignored."); delete extension; extension = nullptr; }
    else { extensions[extname] = extension; _isUseful = true; }
}
void MdsdConfig::ForeachExtension(const std::function& fn) {
    Trace trace (Trace::ConfigUse, "MdsdConfig::ForeachExtension");
    for (const auto & kv : extensions) { trace.NOTE("Walking MdsdExtension with name='" + kv.first + "'"); fn(kv.second); }
}
// Start every registered OMI task and scheduled task.
void MdsdConfig::StartScheduledTasks() {
    Trace trace(Trace::Scheduler, "MdsdConfig::StartScheduledTasks");
    ForeachOmiTask([](OmiTask *job) { job->Start(); });
    ForeachTask([](ITask *task) { task->start(); });
}
// Cancel every registered OMI task and scheduled task.
void MdsdConfig::StopScheduledTasks() {
    Trace trace(Trace::Scheduler, "MdsdConfig::StopScheduledTasks");
    ForeachOmiTask([](OmiTask *job) { job->Cancel(); });
    ForeachTask([](ITask *task) { task->cancel(); });
}
// Tells this configuration to remove itself in the future. The config takes
// steps immediately to stop generating work for itself, then schedules the
// final cleanup action to take place after the requested delay.
void MdsdConfig::SelfDestruct(int seconds) {
    Trace trace(Trace::ConfigUse, "MdsdConfig::SelfDestruct");
    StopScheduledTasks();
    StopAllTimers();
    // Flush any data we're still holding on to. Don't use FlushBatches; that
    // will restart the autoflush timer, and we just stopped that. One last
    // flush will happen when the Destroyer calls delete.
    _batchSet.Flush();
    // Create a deadline_timer on the heap; when it expires, call our Destroyer helper
    auto timer = new boost::asio::deadline_timer(crossplat::threadpool::shared_instance().service());
    timer->expires_from_now(boost::posix_time::seconds(seconds));
    timer->async_wait(boost::bind(MdsdConfig::Destroyer, this, timer));
}
// This static private method does the final delete. Also deletes the heap timer.
void MdsdConfig::Destroyer(MdsdConfig *config, boost::asio::deadline_timer *timer) {
    Trace trace(Trace::ConfigUse, "MdsdConfig:Destroyer");
    std::ostringstream msg;
    msg << "Deleting MdsdConfig at " << config;
    trace.NOTE(msg.str());
    delete config;
    delete timer;
}
// Create a batch for a given target. If one has already been created for that target,
// return the one we're already using.
// Route batch creation: Local-store targets use the config-independent
// _localBatches set; everything else uses this config's batch set.
Batch* MdsdConfig::GetBatch(const MdsEntityName &target, int interval) {
    if (target.GetStoreType() == StoreType::Local) { return _localBatches.GetBatch(target, interval); }
    else { return _batchSet.GetBatch(target, interval); }
}
// Report whether the loaded configuration is usable. Fatal/error messages
// fail validation (returning false); warnings and a non-useful config are
// merely logged.
bool MdsdConfig::ValidateConfig( bool isStartupConfig ) const {
    Trace trace(Trace::ConfigUse, "MdsdConfig::ValidateConfig");
    if (!IsUseful()) {
        std::ostringstream msg;
        msg << "No productive configuration resulted from loading config file(s): " << configFilePath << ".";
        if (!isStartupConfig) { msg << " New configuration ignored.\n"; }
        msg << "Warnings detected:\n";
        MessagesToStream(msg, MdsdConfig::warning);
        Logger::LogWarn(msg);
    }
    if (GotMessages(MdsdConfig::fatal)) {
        std::ostringstream msg;
        msg << "Fatal errors while loading configuration " << configFilePath << ":" << std::endl;
        MessagesToStream(msg, MdsdConfig::fatal);
        if (!isStartupConfig) { msg << "\nNew configuration ignored; using previous configuration"; }
        Logger::LogError(msg);
        return false;
    }
    if (GotMessages(MdsdConfig::error)) {
        std::ostringstream msg;
        msg << "Config file " << configFilePath << " parsing errors:\n";
        MessagesToStream(msg, MdsdConfig::error);
        Logger::LogError(msg);
        return false;
    }
    if (GotMessages(MdsdConfig::warning)) {
        std::ostringstream msg;
        // NOTE(review): missing space before "parsing warnings" in the log
        // message below — cosmetic defect, left as-is in this doc pass.
        msg << "Config file " << configFilePath << "parsing warnings:\n";
        MessagesToStream(msg, MdsdConfig::warning);
        Logger::LogWarn(msg);
    }
    return true;
}
// Trace-only dump of the EventPublisher SAS/endpoint map; no-op when the
// ConfigLoad trace channel is inactive.
void MdsdConfig::DumpEventPublisherInfo() {
    Trace trace(Trace::ConfigLoad, "MdsdConfig::DumpEventPublisherInfo");
    if (!trace.IsActive()) { return; }
    if (_ehPubItemsMap.empty()) { TRACEINFO(trace, "EventPublisher map is empty"); }
    else {
        for (const auto & iter : _ehPubItemsMap) {
            auto moniker = iter.first;
            auto itemsmap = iter.second;
            if (itemsmap->empty()) { TRACEINFO(trace, "Moniker='" << moniker << "'; Event: N/A."); }
            else {
                for (const auto& item : (*itemsmap)) {
                    auto eventname = item.first;
                    auto ehinfo = item.second;
                    TRACEINFO(trace, "Moniker='" << moniker << "'; EventName='" << eventname << "'; EHInfo: " << ehinfo);
                }
            }
        }
    }
}
// Return the default credential's moniker; throws when none is configured.
std::string MdsdConfig::GetDefaultMoniker() const {
    auto defaultCreds = GetDefaultCredentials();
    if (!defaultCreds) { throw std::runtime_error("No default credential is found."); }
    return defaultCreds->Moniker();
}
// Record sink configuration for an event; an empty moniker falls back to the
// default moniker, and any failure is recorded as a fatal config message.
void MdsdConfig::AddMonikerEventInfo( const std::string & moniker, const std::string & eventName, StoreType::Type type, const std::string & sourceName, mdsd::EventType eventType ) {
    Trace trace(Trace::ConfigLoad, "AddMonikerEventInfo");
    try {
        auto monikerToUse = moniker.empty()? GetDefaultMoniker() : moniker;
        _mdsdEventCfg->AddEventSinkCfgInfoItem({eventName, monikerToUse, type, sourceName, eventType });
        TRACEINFO(trace, "Saved event=" << eventName << " moniker=" << monikerToUse);
    }
    catch(const std::exception& ex) { AddMessage(fatal, std::string("AddMonikerEventInfo() failed: ") + ex.what()); }
}
// Store an OboDirect partition field; "resourceId" is additionally cached in
// its own member.
void MdsdConfig::SetOboDirectPartitionFieldNameValue(std::string&& name, std::string&& value) {
    _oboDirectPartitionFieldsMap.emplace(name, value);
    if (name == "resourceId") { _resourceId = value; }
}
// Look up an OboDirect partition field by name; warns and returns "" when
// the field was not configured. Throws on an empty name.
std::string MdsdConfig::GetOboDirectPartitionFieldValue(const std::string& name) const {
    if (name.empty()) { throw std::invalid_argument("MdsdConfig::GetOboDirectPartitionFieldValue(name): name cannot be empty"); }
    std::string value;
    auto it = _oboDirectPartitionFieldsMap.find(name);
    if (it != _oboDirectPartitionFieldsMap.end()) { value = it->second; }
    else {
        Logger::LogWarn("OboDirectPartitionField with name='" + name + "' not found. Make sure the mdsd.xml includes the corresponding " "Management/OboDirectPartitionField element. Returning an empty string " "as the result value.");
    }
    return value;
}
// Run all event-level validations; any thrown failure becomes an error message.
void MdsdConfig::ValidateEvents() {
    Trace trace(Trace::ConfigLoad, "MdsdConfig::ValidateEvents");
    try {
        ValidateAnnotations();
        ValidateEventHubPubKeys();
        ValidateEventHubPubSinks();
    }
    catch(const std::exception & ex) { AddMessage(error, std::string("MdsdConfig::ValidateEvents() failed: ") + ex.what()); }
}
// Flag every EventStreamingAnnotation that names an unknown event.
void MdsdConfig::ValidateAnnotations() {
    for (const auto & name : _mdsdEventCfg->GetInvalidAnnotations()) { AddMessage(MdsdConfig::error, "Unknown name '" + name + "' in EventStreamingAnnotation"); }
}
// Flag every event publisher lacking a SAS key.
void MdsdConfig::ValidateEventHubPubKeys() {
    for (const auto & publisherName : _eventPubCfg->CheckForInconsistencies(_hasAutoKey)) { AddMessage(MdsdConfig::error, "Failed to find event publisher SAS key for item '" + publisherName + "'"); }
}
// Verify each event publisher has a LocalSink; any valid publisher makes the
// config "useful".
void MdsdConfig::ValidateEventHubPubSinks() {
    for (const auto & publisherName: _mdsdEventCfg->GetEventPublishers()) {
        if (!LocalSink::Lookup(publisherName)) { AddMessage(error, "failed to find LocalSink object for Event Publisher " + publisherName); }
        else { _isUseful = true; } // Found a valid event publisher
    }
}
// Wire Event Hub publishing: attach publishers to local sinks, create the
// uploaders, then install any SAS keys embedded in the config.
void MdsdConfig::InitEventHubPub() {
    Trace trace(Trace::ConfigUse, "MdsdConfig::InitEventHubPub");
    SetEventHubPubForLocalSinks();
    // create uploaders first before setting SAS key
    mdsd::EventHubUploaderMgr::GetInstance().CreateUploaders(mdsd::EventHubType::Publish, _eventPubCfg->GetNameMonikers());
    SetupEventHubPubEmbeddedKeys();
}
// Install config-embedded SAS keys into the Publish uploaders.
// NOTE(review): this function is truncated at the end of this source chunk;
// the remainder of its body is not visible here.
void MdsdConfig::SetupEventHubPubEmbeddedKeys() {
    Trace trace(Trace::ConfigUse, "MdsdConfig::SetupEventHubPubEmbeddedKeys");
    auto& ehUploaderMgr = mdsd::EventHubUploaderMgr::GetInstance();
    auto ehtype = mdsd::EventHubType::Publish;
    for (const auto & item : _eventPubCfg->GetEmbeddedSasData()) {
        auto & publisherName = item.first;
        auto & monikerSasMap = item.second;
        for (const auto & keyItem : monikerSasMap) {
            auto & moniker = keyItem.first;
            auto & saskey = keyItem.second;
// (continuation of MdsdConfig::SetupEventHubPubEmbeddedKeys)
                        // Start this publisher's uploader with its embedded SAS key.
                        ehUploaderMgr.SetSasAndStart(mdsd::EventHubUploaderId(ehtype, moniker, publisherName), saskey);
                }
        }
}

// Attach EventHub publishing info (monikers, duration, agent identity) to the
// LocalSink behind every configured event publisher; throws if a publisher
// has no LocalSink object.
void
MdsdConfig::SetEventHubPubForLocalSinks()
{
        Trace trace(Trace::ConfigUse, "MdsdConfig::SetEventHubPubForLocalSinks");

        std::string tenant, role, roleInstance;
        GetIdentityValues(tenant, role, roleInstance);

        for (const auto & item : _eventPubCfg->GetNameMonikers()) {
                auto & publisherName = item.first;
                auto sinkObj = LocalSink::Lookup(publisherName);
                if (!sinkObj) {
                        throw std::runtime_error("SetEventHubPubForLocalSinks(): failed to find LocalSink object for " + publisherName);
                }
                else {
                        std::string duration = GetDurationForEventName(publisherName);
                        auto & monikers = item.second;
                        sinkObj->SetEventPublishInfo(monikers, std::move(duration), tenant, role, roleInstance);
                }
        }
}
// vim: sw=8

// ================================================ FILE: Diagnostic/mdsd/mdsd/MdsdConfig.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _MDSDCONFIG_HH_
#define _MDSDCONFIG_HH_

#include "TableSchema.hh"
#include "Batch.hh"
#include "IdentityColumns.hh"
#include "Priority.hh"
#include "EventHubCmd.hh"
#include "CfgEventAnnotationType.hh"
// NOTE(review): the targets of the twelve system #includes below were lost
// during extraction (angle-bracketed names stripped); restore from the
// original header before building.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

class Credentials;
class OmiTask;
class ITask;
class MdsdExtension;

namespace mdsd {
        struct OboDirectConfig;
        class EventPubCfg;
        class MdsdEventCfg;
        enum class EventType;
}

// In-memory representation of one mdsd configuration (parsed from mdsd.xml
// plus an optional autokey config), including schemas, sources, credentials,
// tasks, extensions, batches and EventHub publishing state.
class MdsdConfig
{
public:
        /// Create an MdsdConfiguration from a configuration file.
        /// path: Pathname of the config file to load
        MdsdConfig(std::string path, std::string autokeyConfigPath);
        ~MdsdConfig();

        /// Initialize configuring activities, including loading autokey,
        /// flushing batches etc.
// --- MdsdConfig.hh (continued): public API of class MdsdConfig ---
        void Initialize();

        /// Load a configuration file into an existing MdsdConfiguration.
        /// path: Pathname of the config file to load
        void LoadFromConfigFile(std::string path);

        //////////// Parser Warnings and Errors //////////

        // Bit-maskable message severities; anySeverity is the union of all bits.
        typedef enum { anySeverity=15, info = 8, warning = 4, error = 2, fatal = 1 } severity_t;

        /// Return the readable name of a severity code
        const std::string& SeverityToString(severity_t severity) const;

        /// Record a message (warning, error, fatal error, etc) for this location in the parse of the file.
        /// This method always adds a newline (\n) to the end of each message.
        void AddMessage(severity_t severity, const std::string& msg);

        /// True if messages were recorded via MdsdConfig::AddMessage()
        bool GotMessages(int mask) const;

        /// Write all recorded messages to a stream
        void MessagesToStream(std::ostream& output, int mask) const;

        // One recorded parse message together with its source location.
        class Message {
        public:
                Message(const std::string& f, long l, severity_t s, const std::string& m)
                        : filename(f), line(l), severity(s), msg(m) {}
                ~Message() {}
                std::string filename;
                long line;
                severity_t severity;
                std::string msg;
        };

        ///////////// Configuration Settings //////////////

        /// Indicates if some useful/productive settings have made it into the config
        bool IsUseful() const { return _isUseful; }

        /// True if element has been loaded once
        bool MonitoringManagementSeen() const { return _monitoringManagementSeen; }
        void MonitoringManagementSeen(bool state) { _monitoringManagementSeen = state; }

        /// Prefix for all event names.
        const std::string& Namespace() const { return nameSpace; }
        void Namespace(const std::string& name) { nameSpace = name; }

        /// Version suffix for event names. "5" yields a suffix of "Ver5v0".
        int EventVersion() const { return eventVersion; }
        void EventVersion(int ver) { eventVersion = ver; }

        /// Config file timestamp.
        const std::string& Timestamp() const { return timeStamp; }
        void Timestamp(const std::string& ts) { timeStamp = ts; }

        /// Number of partitions to spread across in MDS tables
        unsigned int PartitionCount() const { return _partitionCount; }
        void PartitionCount(unsigned int count) { _partitionCount = count; }

        /// How long to keep data in the agent
        unsigned long DefaultRetention() const { return _defaultRetention; }
        void DefaultRetention(unsigned long count) { _defaultRetention = count; }

        //////////// Identity ///////

        // Add an identity column to the set
        bool AddIdentityColumn(const std::string& colname, const std::string& colval);

        // Push name/value or name/type pairs into destination containers.
        // NOTE(review): the back_insert_iterator template arguments were lost
        // in extraction; restore from the original header.
        void GetIdentityColumnValues(std::back_insert_iterator);
        void GetIdentityColumnTypes(std::back_insert_iterator);

        // Get Tenant/Role/RoleInstance values. Return related Alias values if alias is used.
        void GetIdentityValues(std::string & tenant, std::string& role, std::string& roleInstance);

        // Aliases for the special Tenant, Role, and RoleInstance identity elements
        void SetTenantAlias(const std::string& name) { _tenantNameAlias = name; }
        void SetRoleAlias(const std::string& name) { _roleNameAlias = name; }
        void SetRoleInstanceAlias(const std::string& name) { _roleInstanceNameAlias = name; }
        const std::string& TenantAlias() const { return _tenantNameAlias; }
        const std::string& RoleAlias() const { return _roleNameAlias; }
        const std::string& RoleInstanceAlias() const { return _roleInstanceNameAlias; }
        const ident_vect_t * GetIdentityVector() { return &identityColumns; }
        const std::string & AgentIdentity() const { return _agentIdentity; }

        //////////// Envelope ///////
        // NOTE(review): the std::pair template arguments were lost in extraction.
        using EnvelopeColumn = std::pair;
        void AddEnvelopeColumn(std::string && name, std::string && value);
        void ForeachEnvelopeColumn(const std::function&);

        /////////// Table Schemas and Event Sources //////////

        /// Add a schema to configuration. Once invoked, the caller no longer owns the schema object.
        void AddSchema(TableSchema* schema);

        /// Add a source to the configuration. The source name can be mapped to an already-known schema.
        void AddSource(const std::string& source, const std::string& schema);
        bool IsValidSource(const std::string& source) { return (sources.count(source) > 0); }

        /// Add a source to the configuration. The source name will only be valid for dynamic schema input protocols.
        void AddDynamicSchemaSource(const std::string& source);
        bool IsValidDynamicSchemaSource(const std::string& source) { return (_dynamic_sources.count(source) > 0); }

        /// Get the table schema for a source; return 0 if the source is unknown
        TableSchema* GetSchema(const std::string& source) const;

        //////////// OMI Tasks //////////////
        void AddOmiTask(OmiTask *task);
        void AddOmiTask(const std::string &ev, Priority pri, Credentials *creds, bool noNPD,
                const std::string &nmsp, const std::string &qry, time_t rate);
        void ForeachOmiTask(const std::function&);

        //////////// Arbitrary Tasks //////////////
        void AddTask(ITask *task);
        void ForeachTask(const std::function&);

        //////////// Extensions //////////////

        /// Add an extension object to configuration. Once invoked, the caller
        /// no longer owns the extension object.
        void AddExtension(MdsdExtension * extension);
        size_t GetNumExtensions() const { return extensions.size(); }
        void ForeachExtension(const std::function&);

        //////////// Credentials //////////////

        /// Add a Credential to configuration. Once invoked, the caller no longer owns the creds object.
        /// Once handed to AddCredentials, the caller no longer owns the pointer.
        /// makeDefault: True if these should be the default credentials for this configuration
        void AddCredentials(Credentials* creds, bool makeDefault);

        /// Get the credentials for a moniker; return 0 if the moniker is unknown
        Credentials* GetCredentials(const std::string& moniker) const;

        /// Get the default credentials. Returns 0 if there is no default.
        Credentials* GetDefaultCredentials() const { return _defaultCreds; }

        /// Get default moniker. Throw exception if no default is found.
        std::string GetDefaultMoniker() const;

        /// Get the autokey URI, if any, for a [moniker,tablename] pair.
        std::string GetAutokey(const std::string& moniker, const std::string& fullTableName);

        /// Get EventHub cmd XML items (currently SAS and MDS endpoint ID)
        /// for the moniker/eventName combination
        mdsd::EhCmdXmlItems GetEventNoticeCmdXmlItems(const std::string & moniker, const std::string & eventName);
        mdsd::EhCmdXmlItems GetEventPublishCmdXmlItems(const std::string & moniker, const std::string & eventName);

        ///////////// Quotas /////////////
        void AddQuota(const std::string &name, unsigned long limit) { _quotas[name] = limit; }
        bool IsQuotaExceeded(const std::string &name, unsigned long current) const;

        // Record moniker, eventname, storetype, source name information.
        // If the input 'moniker' is empty, use the default one.
        // sourceName can be empty, e.g. OMIQuery.
// --- MdsdConfig.hh (continued): remaining public API and private state ---
        void AddMonikerEventInfo(const std::string & moniker, const std::string & eventName,
                StoreType::Type type, const std::string & sourceName, mdsd::EventType eventType);

        // Validate the events in configuration xml
        void ValidateEvents();

        ///////// OboDirectConfig (XJsonBlob) //////////
        // NOTE(review): the shared_ptr template arguments in this section were
        // lost in extraction; restore from the original header.
        void AddOboDirectConfig(const std::string& eventName, std::shared_ptr&& oboDirectConfig)
        {
                _oboDirectConfigsMap.emplace(eventName, std::move(oboDirectConfig));
        }

        // Caller should catch the std::out_of_range exception if the map doesn't contain a key
        // matching eventName.
        std::shared_ptr GetOboDirectConfig(const std::string& eventName) const
        {
                return _oboDirectConfigsMap.at(eventName);
        }

        ///////////// Helpers /////////////

        // Return a reference to the set of batches associated with the current config
        BatchSet& GetBatchSet() { return _batchSet; }

        // Return a reference to the batches that correspond to "local" storageType. These
        // survive config reloads but require considerably more care in terms of resource
        // management.
        static BatchSet& GetLocalBatchSet() { return _localBatches; }

        // Given MdsEntityName and autoflush interval, find an existing
        // batch (in the appropriate batch set) or make one.
        Batch* GetBatch(const MdsEntityName &, int interval);

        void StopAllTimers();
        void StartScheduledTasks();
        void StopScheduledTasks();
        void SelfDestruct(int seconds);

        bool ValidateConfig(bool isStartupConfig) const;

        std::string GetCmdContainerSas() const { return cmdContainerSas; }

        // key: moniker name; value: a map of key=EventName; value: EventHub cmd XML items (currently SAS and MDS endpoint)
        using EventHubSasInfo_t = std::unordered_map>>;

        void SetOboDirectPartitionFieldNameValue(std::string&& name, std::string&& value);
        std::string GetOboDirectPartitionFieldValue(const std::string& name) const;

        // Currently the VM resource ID is obtained from (and set on) Management/OboDirectParititionField element with name="resourceId"
        // Change this later if another better methods becomes available or if the logic needs to be changed.
        std::string GetResourceId() const { return _resourceId; }

        // Below is for metric event Json object construction purpose
        // Currently, only DerivedEvent's duration attributes are stored
        // (because currently they are the only metric events available for Azure Monitor Json blob sink)
        void SetDurationForEventName(const std::string& eventName, const std::string& duration)
        {
                _eventNamesDurationsMap[eventName] = duration;
        }

        std::string GetDurationForEventName(const std::string& eventName) const
        {
                auto it = _eventNamesDurationsMap.find(eventName);
                return it == _eventNamesDurationsMap.end() ? std::string() : it->second;
        }

        std::shared_ptr& GetMdsdEventCfg() { return _mdsdEventCfg; }
        std::shared_ptr& GetEventPubCfg() { return _eventPubCfg; }

private:
        MdsdConfig();   // Disallow empty constructor

        // this is a record of the config file path. The file path can be renamed/moved during
        // 'this' MdsdConfig file time.
        std::string configFilePath;
        std::deque messages;
        std::string currentPath;
        long currentLine;
        int msgMask;
        std::string _autokeyConfigFilePath;

        /// Prefix for all event names.
        std::string nameSpace;
        /// Version suffix for event names. "5" yields a suffix of "Ver5v0".
        int eventVersion;
        /// Timestamp of the config file.
        std::string timeStamp;

        // NOTE(review): container template arguments throughout this private
        // section were lost in extraction; restore from the original header.
        std::map schemas;
        std::map credentials;
        std::map sources;
        std::unordered_set _dynamic_sources;
        ident_vect_t identityColumns;
        std::vector _envelopeColumns;
        std::string _tenantNameAlias;
        std::string _roleNameAlias;
        std::string _roleInstanceNameAlias;
        std::vector _omiTasks;
        std::vector _tasks;
        std::map extensions;
        std::string _resourceId;
        unsigned int _partitionCount;
        unsigned long _defaultRetention;
        bool _isUseful;
        Credentials* _defaultCreds;
        BatchSet _batchSet;
        boost::asio::deadline_timer _batchFlushTimer;
        void FlushBatches(const boost::system::error_code &);
        static BatchSet _localBatches;
        std::string _agentIdentity;
        std::map, std::string> _autoKeyMap;
        std::mutex _aKMmutex;
        boost::asio::deadline_timer _autoKeyReloadTimer;
        bool LoadAutokey(const boost::system::error_code &);
        void DumpAutokeyTable(std::ostream &os);

        std::mutex _ehMapMutex;
        // key: (original, not mapped) moniker;
        // value: a map of key=EventName; value: EventHub cmd XML items (currently SAS and MDS endpoint)
        using EventHubItemsMap_t = std::unordered_map>>;
        EventHubItemsMap_t _ehNoticeItemsMap;
        EventHubItemsMap_t _ehPubItemsMap;
        mdsd::EhCmdXmlItems GetEventHubCmdXmlItems(EventHubItemsMap_t& ehmap,
                const std::string & moniker, const std::string & eventName, const std::string & eventType);
        void LoadEventHubKeys(const std::vector>& keylist);
        void DumpEventPublisherInfo();
        void SetMappedMoniker(const EventHubSasInfo_t & ehmap);

        // For EventHub notice, create uploaders in EH manager,
        // then set the SAS key and start the uploaders.
        void InitEventHubNotice();
        // Initialize EventHub publishers, set SAS keys
        void InitEventHubPub();
        // Make sure each annotated event exists
        void ValidateAnnotations();
        // Make sure each event publisher has some SAS key, either embedded, or from AutoKey
        void ValidateEventHubPubKeys();
        // Make sure each event publisher has a LocalSink object that'll publish data for it.
        void ValidateEventHubPubSinks();
        void SetupEventHubPubEmbeddedKeys();
        void SetEventHubPubForLocalSinks();

        // key: event name; value: mdsd::OboDirectConfig
        std::unordered_map> _oboDirectConfigsMap;
        // key: OboDirect partition field name (e.g., "resourceId"); value: value (e.g., "SUBSCRIPTIONS/91D12660-3DEC-467A-BE2A-213B5544DDC0/RESOURCEGROUPS/RMANDASHOERG/PROVIDERS/MICROSOFT.DEVICES/IOTHUBS/SHOEHUBSCUS3")
        std::unordered_map _oboDirectPartitionFieldsMap;
        // key: eventName (e.g., "WADMetricsPT1MP10DV2S"); value: duration (e.g., "PT1M")
        // Currently only DerivedEvent's durations are stored.
        std::unordered_map _eventNamesDurationsMap;

        static void Destroyer(MdsdConfig *, boost::asio::deadline_timer *);

        std::map _quotas;
        bool _monitoringManagementSeen;
        std::string cmdContainerSas;

        // true if autokey is used in any account; false otherwise.
        bool _hasAutoKey;

        std::shared_ptr _mdsdEventCfg;
        std::shared_ptr _eventPubCfg;

        /// Set the line number of the file being parsed. Used when recording messages generated during parsing.
        void SetLineNumber(long num) { currentLine = num; }

        /// Increment the line number indicator for the next chunk to be parsed
        void NextLine() { currentLine++; }

        std::vector> ExtractCmdContainerAutoKeys();
};

#endif //_MDSDCONFIG_HH_
// vim: se sw=8 :

// ================================================ FILE: Diagnostic/mdsd/mdsd/MdsdExtension.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _MDSDEXTENSION_HH_
#define _MDSDEXTENSION_HH_

// NOTE(review): #include target lost in extraction (presumably <string>; confirm).
#include

// Configuration for one external extension process: its command line, body,
// alternate location and resource-throttling limits.
class MdsdExtension
{
public:
        /// Construct new MdsdExtension object given extension name.
// --- MdsdExtension.hh (continued): constructor and accessors ---
        /// name: Extension name
        MdsdExtension(const std::string & name) :
            _name(name),
            _cpuPercentUsage(0),
            _isCpuThrottling(false),
            _memoryLimitInMB(0),
            _isMemoryThrottling(false),
            _ioReadLimitInKBPerSecond(0),
            _ioReadThrottling(false),
            _ioWriteLimitInKBPerSecond(0),
            _ioWriteThrottling(false)
        {
        }

        ~MdsdExtension() { }

        const std::string & Name() const { return _name; }

        const std::string & GetCmdLine() const { return _cmdline; }
        void SetCmdLine(const std::string & cmdline) { _cmdline = cmdline; }

        const std::string & GetBody() const { return _body; }
        void SetBody(const std::string & body) { _body = body; }

        const std::string & GetAlterLocation() const { return _alterLocation; }
        void SetAlterLocation(const std::string & alterLocation) { _alterLocation = alterLocation; }

        float GetCpuPercentUsage() const { return _cpuPercentUsage; }
        void SetCpuPercentUsage(float cpuPercentUsage) { _cpuPercentUsage = cpuPercentUsage; }

        bool GetIsCpuThrottling() const { return _isCpuThrottling; }
        void SetIsCpuThrottling(bool isCpuThrottling) { _isCpuThrottling = isCpuThrottling; }

        unsigned long long GetMemoryLimitInMB() const { return _memoryLimitInMB; }
        void SetMemoryLimitInMB(unsigned long long memoryLimitInMB) { _memoryLimitInMB = memoryLimitInMB; }

        bool GetIsMemoryThrottling() const { return _isMemoryThrottling; }
        void SetIsMemoryThrottling(bool isMemoryThrottling) { _isMemoryThrottling = isMemoryThrottling; }

        unsigned long long GetIOReadLimitInKBPerSecond() const { return _ioReadLimitInKBPerSecond; }
        void SetIOReadLimitInKBPerSecond(unsigned long long n) { _ioReadLimitInKBPerSecond = n; }

        bool GetIsIOReadThrottling() const { return _ioReadThrottling; }
        void SetIsIOReadThrottling(bool isThrottling) { _ioReadThrottling = isThrottling; }

        unsigned long long GetIOWriteLimitInKBPerSecond() const { return _ioWriteLimitInKBPerSecond; }
        void SetIOWriteLimitInKBPerSecond(unsigned long long n) { _ioWriteLimitInKBPerSecond = n; }

        bool GetIsIOWriteThrottling() const { return _ioWriteThrottling; }
        void SetIsIOWriteThrottling(bool isThrottling) { _ioWriteThrottling = isThrottling; }

private:
        MdsdExtension() = delete;

        const std::string _name;
        // Define command line to be std::string because we need to execute it.
        std::string _cmdline;
        std::string _body;
        // Define alternative location path to be std::string because we need to use the path for execute.
        std::string _alterLocation;
        float _cpuPercentUsage;
        bool _isCpuThrottling;
        unsigned long long _memoryLimitInMB;
        bool _isMemoryThrottling;
        unsigned long long _ioReadLimitInKBPerSecond;
        bool _ioReadThrottling;
        unsigned long long _ioWriteLimitInKBPerSecond;
        bool _ioWriteThrottling;
};

#endif // _MDSDEXTENSION_HH_

// ================================================ FILE: Diagnostic/mdsd/mdsd/MdsdMetrics.cc ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "MdsdMetrics.hh"

// Per-thread instance pointer plus a global registry of all instances.
// NOTE(review): container template arguments below were lost in extraction.
thread_local MdsdMetrics *MdsdMetrics::_instance = nullptr;
std::unordered_set MdsdMetrics::_instances;
std::mutex MdsdMetrics::_setLock;

// Sum every metric across all per-thread instances.
// NOTE(review): iterates _instances without taking _setLock — confirm callers
// guarantee no concurrent instance registration.
std::map
MdsdMetrics::AggregateAll()
{
        std::map totals;

        for (const MdsdMetrics * pinstance : _instances) {
                for (const auto & item : pinstance->_metrics) {
                        totals[item.first] += item.second;
                }
        }

        return totals;
}

// Sum a single named metric across all per-thread instances.
unsigned long long
MdsdMetrics::AggregateMetric(const std::string &metric)
{
        unsigned long long total = 0;

        for (const MdsdMetrics * pinstance : _instances) {
                auto & inst_map = pinstance->_metrics;
                auto iter = inst_map.find(metric);
                if (iter != inst_map.end()) {
                        total += iter->second;
                }
        }

        return total;
}

#ifdef DOING_MEMCHECK
bool MdsdMetrics::_allFree = false;

// Free every per-thread instance and empty the registry (memcheck builds only).
void
MdsdMetrics::ClearMetrics()
{
        std::lock_guard lock(_setLock);
        _allFree = true;
        for (MdsdMetrics *item : _instances) {
                delete item;
        }
        _instances.clear();
}
#endif

// vim: se sw=8 :

// ================================================ FILE: Diagnostic/mdsd/mdsd/MdsdMetrics.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _MDSDMETRICS_HH_
#define _MDSDMETRICS_HH_

// NOTE(review): #include targets lost in extraction; restore from original.
#include
#include
#include
#include

#ifdef DOING_MEMCHECK
#define QUITABORT if (_allFree) return
#else
#define QUITABORT
#endif

// Per-thread metric counters with lock-free in-thread updates and global
// aggregation across all threads.
class MdsdMetrics
{
public:
        // Lazily create (and globally register) the calling thread's instance.
        static MdsdMetrics &GetInstance()
        {
                if (_instance == nullptr) {
                        std::lock_guard lock(_setLock);
                        _instance = new MdsdMetrics();
                        _instances.insert(_instance);
                }
                return *_instance;
        }

        static void Count(const std::string &metric) { QUITABORT; GetInstance().CountThis(metric); }
        static void Count(const std::string &metric, unsigned long long delta)
        {
                QUITABORT;
                GetInstance().CountThis(metric, delta);
        }

        void CountThis(const std::string &metric) { QUITABORT; _metrics[metric]++; }
        void CountThis(const std::string &metric, unsigned long long delta) { QUITABORT; _metrics[metric] += delta; }

        static std::map AggregateAll();
        static unsigned long long AggregateMetric(const std::string &metric);

private:
        // One instance in each thread; that makes access within a thread lock-free
        static thread_local MdsdMetrics * _instance;
        // One global list of all per-thread instances...
        static std::unordered_set _instances;
        // One lock protects the global list
        static std::mutex _setLock;

        std::map _metrics;

#ifdef DOING_MEMCHECK
public:
        void ClearMetrics();
        static bool _allFree;
private:
#endif

        MdsdMetrics() {}
};

#endif // _MDSDMETRICS_HH_
// vim: se sw=8 :

// ================================================ FILE: Diagnostic/mdsd/mdsd/Memcheck.cc ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#ifdef DOING_MEMCHECK

#include "MdsSchemaMetadata.hh"
#include "Engine.hh"
#include "Logger.hh"
#include "Trace.hh"
// NOTE(review): #include target lost in extraction (valgrind header; confirm).
#include

// Only compiled when DOING_MEMCHECK is added as a -D on the compile line
// Otherwise, it won't even compile, much less link.
// Final cleanup pass for valgrind runs: drop every global cache so a leak
// check at exit reports only true leaks, then trigger the check.
extern "C" void RunFinalCleanup()
{
        Trace trace(Trace::SignalHandlers, "RunFinalCleanup");

        trace.NOTE("Clear Schema Metadata Cache");
        MdsSchemaMetadata::ClearCache();

        trace.NOTE("Clear Extension object cache");
        CleanupExtensions();

        Engine* engine = Engine::GetEngine();
        trace.NOTE("Clear SchemasTable cache");
        engine->ClearPushedCache();
        trace.NOTE("Cleanup MdsdConfig");
        engine->ClearConfiguration();
        engine = nullptr;

        // Must be last
        trace.NOTE("Closing all logs");
        Logger::CloseAllLogs();

        VALGRIND_DO_LEAK_CHECK;
}

#endif

// vim: se sw=8 :

// ================================================ FILE: Diagnostic/mdsd/mdsd/OMIQuery.cc ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "OMIQuery.hh"
#include "OmiTask.hh"
#include "MdsSchemaMetadata.hh"
#include "MdsValue.hh"
#include "Trace.hh"
#include "Logger.hh"
#include "CanonicalEntity.hh"
#include "Engine.hh"
#include "MdsdConfig.hh"
#include "MdsEntityName.hh"
#include "Credentials.hh"
#include "Batch.hh"
#include "Utility.hh"
// NOTE(review): the following #include targets were lost in extraction.
#include
#include
#include
#include
#include
#include
#include
#include
extern "C" {
#include
}

using std::vector;
using std::pair;

// Construct an OMI query pipeline stage. Validates the namespace and query
// expression and extracts the class name from the query's FROM clause;
// throws std::invalid_argument on malformed input.
OMIQuery::OMIQuery(PipeStage *head, const std::string& ns, const std::string& qry, bool doUpload)
        : _pipeHead(head), _name_space(ns), _queryexpr(qry), _uploadData(doUpload), _connTimeoutMS(90000)
{
        Trace trace(Trace::OMIIngest, "OMIQuery::Constructor");

        if (MdsdUtil::NotValidName(_name_space)) {
                throw std::invalid_argument("OMIQuery namespace must not be empty");
        } else if (MdsdUtil::IsEmptyOrWhiteSpace(_queryexpr)) {
                throw std::invalid_argument("OMIQuery query expression must not be empty");
        }

        const std::string from = "from";
        const std::string& whitespace = " \t";
        // Search for "from" case-insensitively by lowercasing a copy of the query.
        std::string queryLower = _queryexpr;
        std::transform(queryLower.begin(), queryLower.end(), queryLower.begin(), ::tolower);
        size_t frompos = queryLower.find(from);
        if (std::string::npos == frompos) {
                throw std::invalid_argument("Invalid syntax in OMI query expression (invalid class name specification)");
        }
        // The class name is the first whitespace-delimited word after "from".
        std::string substr1 = _queryexpr.substr(frompos + from.length());
        auto strBegin = substr1.find_first_not_of(whitespace);
        if (std::string::npos == strBegin) {
                throw std::invalid_argument("Invalid syntax in OMI query expression (invalid class name specification)");
        }
        auto strEnd = substr1.find_first_of(whitespace, strBegin);
        _classname = substr1.substr(strBegin, strEnd-strBegin);
        if (MdsdUtil::NotValidName(_classname)) {
                throw std::invalid_argument("OMIQuery class must not be empty");
        }

        _schemaId = OmiTask::SchemaId(ns, qry);
        if (0 == _schemaId) {
                throw std::invalid_argument("No schemaID has been allocated for this namespace and query");
        }

        trace.NOTE("Query namespace(" + _name_space + ") class(" + _classname + ") queryexp(" + _queryexpr + ")");
}

OMIQuery::~OMIQuery()
{
        Trace trace(Trace::OMIIngest, "OMIQuery::Destructor");
        // Leave the processing pipeline alone; the OmiTask will have cleaned it up.
}

// Set the OMI connection timeout, in milliseconds.
void
OMIQuery::SetConnTimeout(unsigned int milliSeconds)
{
        Trace trace(Trace::OMIIngest, "OMIQuery::SetConnTimeout");
        _connTimeoutMS = milliSeconds;
        trace.NOTE("Set OMI connection timeout(MS)=" + std::to_string(_connTimeoutMS));
}

// Connect a new OMI client over the local socket; returns an empty pointer on
// any failure. NOTE(review): the unique_ptr template argument (mi::Client)
// was lost in extraction; restore from the original file.
std::unique_ptr
OMIQuery::CreateNewClient()
{
        Trace trace(Trace::OMIIngest, "OMIQuery::CreateNewClient");
        bool resultOk = true;
        std::unique_ptr client;         // Points to nothing
        try {
                client.reset(new mi::Client()); // Make it own the newly allocated object
                mi::String locator = SCX_SOCKET_VAL;
                if (trace.IsActive()) {
                        std::ostringstream msg;
                        msg << "locator='" << locator.Str() << "'; Timeout(MS)=" << _connTimeoutMS;
                        trace.NOTE(msg.str());
                }
                resultOk = client->Connect(locator, "", "", _connTimeoutMS*1000);
                if (!resultOk) {
                        LogError("Error: Unable to connect to OMI service. (Is OMI installed and started?)");
                        client.reset();
                }
        }
        catch(...) {
                // NOTE(review): this literal had been split by a stray line break;
                // rejoined here.
                LogError("Error: Exception thrown while connecting to OMI service. (Is OMI functional?)");
                client.reset();         // Deletes what it pointed to, if anything
                resultOk = false;
        }
        trace.NOTE("ResultStatus=" + std::to_string(resultOk));
        return client;
}

// Connectivity probe: connect and issue an OMI NoOp. Returns true on success.
bool
OMIQuery::NoOp()
{
        Trace trace(Trace::OMIIngest, "OMIQuery::NoOp");
        bool resultOK = true;
        try {
                auto client = CreateNewClient();
                if (client) {
                        resultOK = client->NoOp(_connTimeoutMS*1000);
                        if (!resultOK) {
                                LogError("Error: OMI NoOp() failed. Is OMI functional?");
                        } else {
                                trace.NOTE("NoOp finished Successfully.");
                        }
                        client->Disconnect();
                } else {
                        resultOK = false;
                }
        }
        catch(...) {
                LogError("Error: Exception thrown while performing OMI NoOp" );
                resultOK = false;
        }
        return resultOK;
}

// Execute the query; put the results in CanonicalEntity instances and
// pass them into the processing pipeline associated with this query.
bool
OMIQuery::RunQuery(const MdsTime& qibase)
{
        Trace trace(Trace::OMIIngest, "OMIQuery::RunQuery");
        trace.NOTE("\nrun query: " + _name_space + " : " + _queryexpr);

        bool resultOK = true;
        MdsTime queryTime;      // Default constructor sets this to the current time
        mi::Array instanceList;
        mi::Result result = MI_RESULT_OK;
        try {
                auto client = CreateNewClient();
                if (! client) {
                        return false;
                }
                resultOK = client->EnumerateInstances(_name_space.c_str(), _classname.c_str(),
                        true, _connTimeoutMS*1000, instanceList, QUERYLANG, _queryexpr.c_str(), result);
                if (!resultOK || (result != MI_RESULT_OK)) {
                        LogError("Error: OMI EnumerateInstances failed");
                        resultOK = false;
                }
                client->Disconnect();
        }
        catch(const std::exception& e) {
                resultOK = false;
                LogError("Error: OMI RunQuery() unexpected exception: " + std::string(e.what()));
        }
        catch(...)
// (continuation of OMIQuery::RunQuery)
        {
                resultOK = false;
                LogError(std::string("Error: OMI RunQuery() unexpected exception:"));
        }

        // Feed each returned OMI instance through the pipeline as a CanonicalEntity.
        if (resultOK) {
                MI_Uint32 count = instanceList.GetSize();
                trace.NOTE("Found instances count=" + std::to_string(count));
                _pipeHead->Start(qibase);
                for (MI_Uint32 i = 0; i < count; i++) {
                        CanonicalEntity * ce = new CanonicalEntity(instanceList[i].Count());
                        resultOK = PopulateEntity(ce, instanceList[i]);
                        if (resultOK) {
                                // Suppress a CanonicalEntity with zero columns; could happen, means
                                // nothing is wrong, just no data
                                if (ce->size()) {
                                        ce->SetPreciseTime(queryTime);
                                        ce->SetSchemaId(_schemaId);
                                        _pipeHead->Process(ce);
                                } else {
                                        delete ce;
                                }
                        } else {
                                Logger::LogWarn("Problem(s) detected with this OMI instance; dropping it");
                                delete ce;
                        }
                }
                _pipeHead->Done();
        }
        trace.NOTE("RunQuery finished with resultOK=" + std::to_string(resultOK));
        return resultOK;
}

// Copy every property of an OMI instance into the CanonicalEntity; nested
// instances/references are flattened recursively into the same entity.
// Returns false (after logging) on any failure.
bool
OMIQuery::PopulateEntity(CanonicalEntity *ce, const mi::DInstance& item)
{
        Trace trace(Trace::OMIIngest, "OMIQuery::PopulateEntity");
        mi::Uint32 count = item.Count();
        trace.NOTE("Instance has #items=" + std::to_string(count));
        try {
                for (mi::Uint32 i = 0; i < count; i++) {
                        mi::String name;
                        if (!item.GetName(i, name)) {
                                LogError("While processing OMI results, failed to get name of column " + std::to_string(i));
                                return false;
                        }
                        std::string namestr (name.Str());

                        mi::Type type;
                        MI_Value value;
                        bool isNull = false;
                        bool isKey = false;
                        if (!item.GetValue(name, &value, type, isNull, isKey)) {
                                LogError("While processing OMI results, failed to get value for column " + std::to_string(i));
                                return false;
                        }

                        if (isNull) {
                                // Null properties are stored as the literal marker "[NULL]".
                                ce->AddColumn(namestr, "[NULL]");
                                if (trace.IsActive()) {
                                        std::ostringstream msg;
                                        msg << "Item[" << i << "]: " << namestr << " (OMI type " << type << ") is NULL";
                                        trace.NOTE(msg.str());
                                }
                        } else if (type == MI_INSTANCE || type == MI_REFERENCE) {
                                trace.NOTE("Item[" + std::to_string(i) + "] is an Instance/Reference");
                                mi::DInstance subitem;
                                bool resultOK;
                                if (type == MI_INSTANCE) {
                                        resultOK = item.GetInstance(name, subitem);
                                } else {
                                        resultOK = item.GetReference(name, subitem);
                                }
                                resultOK = resultOK && PopulateEntity(ce, subitem);
                                if (!resultOK) {
                                        LogError("While processing OMI results, failed to unpack instance/reference");
                                        return false;
                                }
                        } else {
                                std::ostringstream msg;
                                bool resultOK = true;
                                msg << "Item[" << i << "]: " << namestr << " (MI_Type " << type << ")";
                                try {
                                        MdsValue * mdsvalue = new MdsValue { value, type };
                                        ce->AddColumn(namestr, mdsvalue);
                                        msg << " " << mdsvalue->TypeToString() << " " << *mdsvalue;
                                }
                                catch (std::exception & e) {
                                        resultOK = false;
                                        msg << " failed type conversion (" << e.what() << ")";
                                }
                                trace.NOTE(msg.str());
                                if (!resultOK) return false;
                        }
                }
                return true;
        }
        catch (...) {
                LogError("Unknown exception caught in OMIQuery::PopulateEntity");
        }
        return false;
}

// Make MI_Result usable as an unordered_map key.
// NOTE(review): hash's template argument (MI_Result) was lost in extraction.
namespace std {
        template <> struct hash {
                size_t operator()(const MI_Result & res) const {
                        return static_cast(res);
                }
        };
}

// Map an MI_Result code to its symbolic name; unknown codes are rendered as
// "MI_ERROR_CODE_" followed by the numeric value.
std::string
OMIQuery::Result_ToString(MI_Result result) const
{
        static std::unordered_map resultCodes = {
                { MI_RESULT_OK, "MI_RESULT_OK" },
                { MI_RESULT_FAILED, "MI_RESULT_FAILED" },
                { MI_RESULT_ACCESS_DENIED, "MI_RESULT_ACCESS_DENIED" },
                { MI_RESULT_INVALID_NAMESPACE, "MI_RESULT_INVALID_NAMESPACE" },
                { MI_RESULT_INVALID_PARAMETER, "MI_RESULT_INVALID_PARAMETER" },
                { MI_RESULT_INVALID_CLASS, "MI_RESULT_INVALID_CLASS" },
                { MI_RESULT_NOT_FOUND, "MI_RESULT_NOT_FOUND" },
                { MI_RESULT_NOT_SUPPORTED, "MI_RESULT_NOT_SUPPORTED" },
                { MI_RESULT_CLASS_HAS_CHILDREN, "MI_RESULT_CLASS_HAS_CHILDREN" },
                { MI_RESULT_CLASS_HAS_INSTANCES, "MI_RESULT_CLASS_HAS_INSTANCES" },
                { MI_RESULT_INVALID_SUPERCLASS, "MI_RESULT_INVALID_SUPERCLASS" },
                { MI_RESULT_ALREADY_EXISTS, "MI_RESULT_ALREADY_EXISTS" },
                { MI_RESULT_NO_SUCH_PROPERTY, "MI_RESULT_NO_SUCH_PROPERTY" },
                { MI_RESULT_TYPE_MISMATCH, "MI_RESULT_TYPE_MISMATCH" },
                { MI_RESULT_QUERY_LANGUAGE_NOT_SUPPORTED, "MI_RESULT_QUERY_LANGUAGE_NOT_SUPPORTED" },
                { MI_RESULT_INVALID_QUERY, "MI_RESULT_INVALID_QUERY" },
                { MI_RESULT_METHOD_NOT_AVAILABLE, "MI_RESULT_METHOD_NOT_AVAILABLE" },
                { MI_RESULT_METHOD_NOT_FOUND, "MI_RESULT_METHOD_NOT_FOUND" },
                { MI_RESULT_NAMESPACE_NOT_EMPTY, "MI_RESULT_NAMESPACE_NOT_EMPTY" },
                { MI_RESULT_INVALID_ENUMERATION_CONTEXT, "MI_RESULT_INVALID_ENUMERATION_CONTEXT" },
                { MI_RESULT_INVALID_OPERATION_TIMEOUT, "MI_RESULT_INVALID_OPERATION_TIMEOUT" },
                { MI_RESULT_PULL_HAS_BEEN_ABANDONED, "MI_RESULT_PULL_HAS_BEEN_ABANDONED" },
                { MI_RESULT_PULL_CANNOT_BE_ABANDONED, "MI_RESULT_PULL_CANNOT_BE_ABANDONED" },
                { MI_RESULT_FILTERED_ENUMERATION_NOT_SUPPORTED, "MI_RESULT_FILTERED_ENUMERATION_NOT_SUPPORTED" },
                { MI_RESULT_CONTINUATION_ON_ERROR_NOT_SUPPORTED, "MI_RESULT_CONTINUATION_ON_ERROR_NOT_SUPPORTED" },
                { MI_RESULT_SERVER_LIMITS_EXCEEDED, "MI_RESULT_SERVER_LIMITS_EXCEEDED" },
                { MI_RESULT_SERVER_IS_SHUTTING_DOWN, "MI_RESULT_SERVER_IS_SHUTTING_DOWN" },
                { MI_RESULT_CANCELED, "MI_RESULT_CANCELED" },
                { MI_RESULT_OPEN_FAILED, "MI_RESULT_OPEN_FAILED" },
                { MI_RESULT_INVALID_CLASS_HIERARCHY, "MI_RESULT_INVALID_CLASS_HIERARCHY" },
                { MI_RESULT_WOULD_BLOCK, "MI_RESULT_WOULD_BLOCK" },
                { MI_RESULT_TIME_OUT, "MI_RESULT_TIME_OUT" }
        };

        auto const & iter = resultCodes.find(result);
        if (iter != resultCodes.end()) {
                return std::string(iter->second);
        }
        /* Not found! */
        return std::string("MI_ERROR_CODE_") + std::to_string(result);
}

// vim: set sw=8 :

// ================================================ FILE: Diagnostic/mdsd/mdsd/OMIQuery.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _OMIQUERY_HH_
#define _OMIQUERY_HH_

#include "MI.h"
#include "omiclient/client.h"
#include "Logger.hh"
#include "MdsEntityName.hh"
#include "Pipeline.hh"
#include "SchemaCache.hh"
// NOTE(review): #include targets lost in extraction; restore from original.
#include
#include
#include
#include
#include
#include

class MdsdConfig;
class MdsValue;
class CanonicalEntity;
class Credentials;
class Batch;

/* OMIQuery provides the APIs to query OMI providers (example: SCX) and upload the results to MDS.
Example usage: OMIQuery* q = new OMIQuery(parameters); bool isOK1 = q.RunQuery(...) // query 1 bool isOK2 = q.RunQuery(...) // query 2 To run in multi-threading mode, create multiple OMIQuery objects. */ typedef std::vector> omi_schemalist_t; typedef std::unordered_map omi_datatable_t; class OMIQuery { public: // Create an OMIQuery object. If uploadData is true, the data will be // uploaded to MDS azure tables. if uploadData is false, data won't be uploaded. OMIQuery(PipeStage * head, const std::string& name_space, const std::string& queryexpr, bool uploadData = true); // Release OMI server connection resources ~OMIQuery(); // disable copy and move contructors OMIQuery(OMIQuery&& h) = delete; OMIQuery& operator=(OMIQuery&& h) = delete; OMIQuery(const OMIQuery&) = delete; OMIQuery& operator=(const OMIQuery &) = delete; // Run a noop query. This can be used to test the connection to server. // Return true if success; return false for any failure. bool NoOp(); // Run an OMI query in given namespace with given query expression. // Example: name_space = "root/scx", queryexpr = "select Name from SCX_UnixProcess" // Return true if success; return false for any failure. // Puts the results into CanonicalEntity objects, which it passes to the head of // the processing pipeline (_pipehead). bool RunQuery(const MdsTime&); // Set the connection timeout value in milliSeconds. void SetConnTimeout(unsigned int milliSeconds); // Enable/disable uploading of data to MDS void EnableUpload(bool flag) { _uploadData = flag; } private: void LogError(const std::string &msg) const { Logger::LogError(msg); } std::unique_ptr CreateNewClient(); // Given an OMI instance, add its columns to a CanonicalEntity. The function will // recursively add the columns of any instances or references included within the instance. 
bool PopulateEntity(CanonicalEntity *, const mi::DInstance&); std::string GetClassNameFromQuery(const std::string& queryexpr) const; std::string Result_ToString(MI_Result result) const; // The top stage of the processing pipeline. All information about the ultimate destination // of each OMI record we capture is embedded in the various stages of the pipeline, which was // constructed when the config file was loaded. PipeStage * _pipeHead; std::string _name_space; // OMI namespace std::string _queryexpr; // OMI query (written in CQL) std::string _classname; // Name of the OMI class from which the query pulls data bool _uploadData; // If false, run query but don't upload data. Good for testing query itself. unsigned int _connTimeoutMS; // timeout in milli-seconds to connect to OMI server for queries. SchemaCache::IdType _schemaId; // Identifies the schema for this query // Because same queries are going to be run again and again, use cache to save the schemas. // key=querynamespace+queryexpr; value: bool. If the query exists in the table, the // schema shouldn't be uploaded any long. std::mutex tablemutex; std::mutex enginemutex; const char * SCX_SOCKET_KEY = "socketfile"; const char * SCX_SOCKET_VAL = "/var/opt/omi/run/omiserver.sock"; const char * QUERYLANG = "CQL"; }; #endif // vim: se sw=8: ================================================ FILE: Diagnostic/mdsd/mdsd/OmiTask.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "OmiTask.hh" #include "OMIQuery.hh" #include "Batch.hh" #include "Logger.hh" #include "Trace.hh" #include "SchemaCache.hh" #include class MdsdConfig; std::map OmiTask::_qryToSchemaId; OmiTask::OmiTask(MdsdConfig *config, const MdsEntityName &target, Priority prio, const std::string& nmspc, const std::string& qry, time_t sampleRate) : _config(config), _target(target), _priority(prio), _namespace(nmspc), _query(qry), _sampleRate(sampleRate?sampleRate:prio.Duration()), _retryCount(0), _timer(crossplat::threadpool::shared_instance().service()), _cancelled(false), _omiConn(nullptr), _head(nullptr), _tail(nullptr) { Trace trace(Trace::OMIIngest, "OmiTask Constructor"); if (nmspc.empty() || qry.empty()) { throw std::invalid_argument("Missing at least one required attribute (omiNamespace, cqlQuery)"); } // Allocated a schemaId for this namespace+query, if necessary std::string mapkey = _namespace + _query; if (0 == _qryToSchemaId.count(mapkey)) { // Not found - insert _qryToSchemaId[mapkey] = SchemaCache::Get().GetId(); } } OmiTask::~OmiTask() { // Cleanup the query object if (_omiConn != nullptr) { delete _omiConn; } // Cleanup the processing pipeline for query results. Cleanup is recursive; each stage // deletes its successor before completing its own cleanup. if (_head) { delete _head; _head = nullptr; } } SchemaCache::IdType OmiTask::SchemaId(const std::string & ns, const std::string & qry) { const auto & iter = _qryToSchemaId.find(ns+qry); if (iter == _qryToSchemaId.end()) { return 0; // Not found } else { return iter->second; } } void OmiTask::AddStage(PipeStage *stage) { Trace trace(Trace::QueryPipe, "OmiTask::AddStage"); if (trace.IsActive()) { std::ostringstream msg; msg << "OmiTask " << this << " adding stage " << stage->Name(); trace.NOTE(msg.str()); } if (! 
_tail) { // This is the first stage in the pipeline; set the head to point here _head = stage; } else { // There's already a pipeline; make the old tail point to the newly-added stage _tail->AddSuccessor(stage); } // Either way, we have a new tail in the pipeline _tail = stage; } void OmiTask::Start() { using namespace boost::posix_time; Trace trace(Trace::OMIIngest, "OmiTask::Start"); trace.NOTE(_query); if (!_head) { Logger::LogError("No processing pipeline for event; should never happen"); return; } // The OMIQuery object does all the retrieval work try { _omiConn = new OMIQuery(_head, _namespace, _query, true); } catch (const std::exception& ex) { std::ostringstream msg; msg << "Query task (" << _query << ") not started because OMIQuery creation failed: " << ex.what(); Logger::LogError(msg.str()); return; } _firstTimeTaskStartTried.Touch(); TryToStartAndRetryIfFailed(boost::system::error_code()); } void OmiTask::TryToStartAndRetryIfFailed(const boost::system::error_code& error) { using namespace boost::posix_time; Trace trace(Trace::OMIIngest, "OmiTask::TryToStartAndRetryIfFailed"); if (error == boost::asio::error::operation_aborted) { // Same comments as in OmiTask::DoWork() applies here as well. 
trace.NOTE("Timer cancelled"); return; } if (_omiConn->NoOp()) { // Add some randomness to when we start regular queries MdsTime target { MdsTime::Now() + MdsTime(2 + random()%5, random()%1000000) }; _qibase = target.Round(_sampleRate); _nextTime = target.to_ptime(); _timer.expires_at(_nextTime); _timer.async_wait(boost::bind(&OmiTask::DoWork, this, boost::asio::placeholders::error)); if (_retryCount > 0) { Logger::LogInfo("Query task(" + _query + ") started after " + std::to_string(_retryCount) + " retries"); } return; } // OMI noop query failed const time_t maxRetryTimeSec = 30 * 60; // Retry up to 30 minutes if (MdsTime::Now() > _firstTimeTaskStartTried + maxRetryTimeSec) { Logger::LogError(std::string("Can't connect to OMI server for more than ") .append(std::to_string(maxRetryTimeSec / 60)).append(" minutes. Giving up.")); return; } // Keep retrying yet with exponential back-off delays const time_t retryIntervalSec = 10 * (1 << _retryCount); // Exponential back-off delay (starting from 10 seconds) trace.NOTE(std::string("OMIQuery::NoOp() basic query failed. Will try to start the task again in ") .append(std::to_string(retryIntervalSec)).append(" seconds.")); Logger::LogError("Connection to OMI server failed; query task (" + _query + ") not started. Will try to start the task again in " + std::to_string(retryIntervalSec) + " seconds."); _timer.expires_from_now(boost::posix_time::seconds(retryIntervalSec)); _timer.async_wait(boost::bind(&OmiTask::TryToStartAndRetryIfFailed, this, boost::asio::placeholders::error)); _retryCount++; } void OmiTask::Cancel() { Trace trace(Trace::OMIIngest, "OmiTask::Cancel"); std::lock_guard lock(_mutex); _cancelled = true; _timer.cancel(); } void OmiTask::DoWork(const boost::system::error_code& error) { Trace trace(Trace::OMIIngest, "OmiTask::DoWork"); if (error == boost::asio::error::operation_aborted) { // If the timer was cancelled, we have to assume the entire configuration may have been // deleted; don't touch it. 
When an MdsdConfig object is told to self-destruct, it first // cancels all timer-driven actions, then it waits some period of time, then it actually // deletes the object. When the timers are cancelled, the handlers are called with the // cancellation message. The MdsdConfig object is *probably* still valid, and as long // as the timer isn't rescheduled, all should be well. But I'm playing it safe here // and assuming an explicit cancel operation means "the config is gone". // // Of course, if the MdsdConfig is deleted, all the associated objects, including this // very OmiTask object, get deleted as well. Thus, the "don't touch nothin'" rule. trace.NOTE("Timer cancelled"); return; } // Note that, as written, we do NOT hold the lock here; our use of the class instance // needs to be readonly. If that changes, revisit this locking pattern. _omiConn->RunQuery(_qibase); trace.NOTE("Back from RunQuery"); std::lock_guard lock(_mutex); if (error || _cancelled) { return; } trace.NOTE("Rescheduling " + _query); _qibase += MdsTime(_sampleRate); _nextTime = _nextTime + boost::posix_time::seconds(_sampleRate); _timer.expires_at(_nextTime); _timer.async_wait(boost::bind(&OmiTask::DoWork, this, boost::asio::placeholders::error)); } // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/OmiTask.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _OMITASK_HH_ #define _OMITASK_HH_ #include #include #include #include #include #include #include "Priority.hh" #include "MdsTime.hh" #include "MdsEntityName.hh" #include "StoreType.hh" #include "Pipeline.hh" #include "SchemaCache.hh" class OMIQuery; class MdsdConfig; class OmiTask { public: OmiTask(MdsdConfig *config, const MdsEntityName& target, Priority prio, const std::string& nmspc, const std::string& qry, time_t sampleRate); // I want a move constructor... 
OmiTask(OmiTask &&orig); // But do not want a copy constructor nor a default constructor OmiTask(OmiTask &) = delete; OmiTask() = delete; ~OmiTask(); // void AddUnpivot(const std::string &valueAttrName, const std::string &nameAttrName, const std::string &unpivotColumns); void AddStage(PipeStage *stage); void Start(); void Cancel(); const MdsEntityName & Target() const { return _target; } int FlushInterval() const { return _priority.Duration(); } static SchemaCache::IdType SchemaId(const std::string & ns, const std::string &qry); private: MdsdConfig *_config; MdsEntityName _target; Priority _priority; std::string _namespace; std::string _query; time_t _sampleRate; size_t _retryCount; MdsTime _firstTimeTaskStartTried; std::mutex _mutex; boost::asio::deadline_timer _timer; boost::posix_time::ptime _nextTime; bool _cancelled; MdsTime _qibase; static std::map _qryToSchemaId; // You may wonder "why is this allocated on the heap?" // Earlier in development, mdsd used a glib-based XML parser which returned Glib::ustring objects // instead of std::string. Various configuration classes stored those ustring objects and thus needed // the Glib-2.0 headers, which #define TRUE and FALSE; so, unfortunately, do the OMI headers. The // easiest solution to the compiler whining was to keep the OMI headers out of the MdsdConfig headers. // Using a pointer let us achieve that isolation. // In December 2015 we removed all use of Glib, so this was no longer an issue. At the time we // made that change, we decided it was safer to leave this as-is and clean it up in a subsequent // refactoring pass. 
OMIQuery *_omiConn; PipeStage *_head; PipeStage *_tail; void TryToStartAndRetryIfFailed(const boost::system::error_code& error); void DoWork(const boost::system::error_code& error); }; #endif // _OMITASK_HH_ // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/PipeStages.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "Logger.hh" #include "Trace.hh" #include "PipeStages.hh" #include "Batch.hh" #include "CanonicalEntity.hh" #include "IdentityColumns.hh" #include "Credentials.hh" #include "MdsdConfig.hh" #include "MdsSchemaMetadata.hh" #include "StoreType.hh" #include "Utility.hh" #include "MdsTime.hh" #include namespace Pipe { const std::string Unpivot::_name { "Unpivot" }; Unpivot::Unpivot(const std::string &valueName, const std::string &nameName, const std::string &columns, std::unordered_map&& transforms) : _valueName(valueName), _nameName(nameName), _transforms(transforms) { Trace trace(Trace::QueryPipe, "Unpivot constructor"); typedef boost::tokenizer > tokenizer_t; boost::char_separator delim(", "); // space and comma tokenizer_t tokens(columns, delim); for (const auto &item : tokens) { _columns.insert(item); } if (_columns.empty()) { throw std::invalid_argument("No column names specified for "); } else if (_valueName.empty()) { throw std::invalid_argument("Invalid name for unpivot value"); } else if (_nameName.empty()) { throw std::invalid_argument("Invalid name for unpivot name column"); } if (trace.IsActive()) { std::ostringstream msg; msg << "Unpivoting these columns: "; for (const std::string &name : _columns) { msg << "[" << name; const auto& iter = _transforms.find(name); if (iter != _transforms.end()) { ColumnTransform& xform = iter->second; msg << " --> " << xform.Name; if (xform.Scale != 1.0) { msg << " scale " << xform.Scale; } } msg << "]"; } trace.NOTE(msg.str()); } } // Tear apart the input item to 
// produce multiple output items.
void
Unpivot::Process(CanonicalEntity *item)
{
	Trace trace(Trace::QueryPipe, "Unpivot::Process");

	// 1: Run through the item and build a master CanonicalEntity which has only the
	// columns that are *not* to be unpivoted. Count the pivoted columns.
	CanonicalEntity master;
	master.SetPreciseTime(item->GetPreciseTimeStamp());
	unsigned pivotCount = 0;
	for (auto col = item->begin(); col != item->end(); col++) {
		if (_columns.count(col->first)) {
			pivotCount++;
		} else {
			master.AddColumn(col->first, col->second);
			// col->second is an MdsValue* and it's now owned by master;
			// update item's ownership
			col->second = nullptr;
		}
	}

	// 2: If there were no pivoted columns, emit a warning, drop it on the floor,
	// and return. (If we needed to send it to the pipeline, we'd have to dupe
	// it from master into a heap-allocated copy and send that, since we'd
	// already have torn all of the columns into master.)
	if (!pivotCount) {
		std::ostringstream msg;
		msg << " matched no columns for this event: " << *item;
		Logger::LogWarn(msg);
		delete item;
		return;
	}
	if (trace.IsActive()) {
		std::ostringstream msg;
		msg << "Unpivoting " << pivotCount << " columns.";
		trace.NOTE(msg.str());
	}

	// 3: Run through the item again. Each time a column-to-be-unpivoted is found, duplicate
	// the master CE, add the "name" and "value" columns, and send the row to our
	// successor. Apply any translations to "name" at this time.
	for (auto col = item->begin(); col != item->end(); col++) {
		if (_columns.count(col->first)) {
			CanonicalEntity *ce = new CanonicalEntity { master };
			const auto & iter = _transforms.find(col->first);
			if (iter == _transforms.end()) {
				// No transform; use as-is
				ce->AddColumn(_nameName, col->first);
			} else {
				// iter points to a pair
				// So iter->second is a ColumnTransform
				ColumnTransform& xform = iter->second;
				ce->AddColumn(_nameName, xform.Name);
				// Apply the scale factor stored in the transform.
				// MdsValue::scale() does appropriate
				// type conversion and does nothing, silently, if the value is not numeric.
				col->second->scale(xform.Scale);
			}
			ce->AddColumn(_valueName, col->second);
			// col->second is an MdsValue* and it's now owned by the dupe ce;
			// update item's ownership
			col->second = nullptr;
			PipeStage::Process(ce);
		}
	}

	// 4. At this point we're done with the original item, which itself is not forwarded
	// down the pipeline. Delete it.
	delete item;
}

const std::string BatchWriter::_name { "BatchWriter" };

// Final-stage setup: precompute the identity string (identity-column values
// joined by "___") and the zero-filled partition number _Nstr derived from
// the identity string's hash modulo pcount.
BatchWriter::BatchWriter(Batch * b, const ident_vect_t * idvec, unsigned int pcount, StoreType::Type storeType)
	: _batch(b), _idvec(idvec), _identString(), _storeType(storeType)
{
	// NOTE(review): std::vector's template argument was lost in extraction;
	// identValues is declared here but not otherwise used in this constructor.
	std::vector identValues;
	bool firstTime = true;
	for (const auto &iter : *(_idvec)) {
		if (!firstTime) _identString.append("___");
		_identString.append(iter.second);
		firstTime = false;
	}
	// If the CanonicalEntity has identity columns, it may need partition and row keys.
	// The identity column data is sufficient to form the standard MDS partition and row keys,
	// which we do here. Only the data sink knows whether these keys are actually needed.
	_Nstr = MdsdUtil::ZeroFill(MdsdUtil::EasyHash(_identString) % (unsigned long long)pcount, 19);
}

// End of the processing pipeline. Adding the item to a batch is defined as a "copy" operation,
// so we should throw away the "original" after that.
void BatchWriter::Process(CanonicalEntity *item) { Trace trace(Trace::QueryPipe, "BatchWriter::Process"); // Based on the target store type, ensure the proper keys are set if (_storeType == StoreType::XTable) { trace.NOTE("Adding XTable columns"); bool doDefaultColumns = false; std::string rowIndex = MdsdUtil::ZeroFill(RowIndex::Get(), 19); if (item->PartitionKey().empty()) { item->AddColumn("PartitionKey", _Nstr + "___" + MdsdUtil::ZeroFill(_qibase.to_DateTime(), 19)); doDefaultColumns = true; } if (item->RowKey().empty()) { item->AddColumn("RowKey", _identString +"___" + rowIndex); doDefaultColumns = true; } if (doDefaultColumns) { item->AddColumn("PreciseTimeStamp", new MdsValue(item->GetPreciseTimeStamp())); item->AddColumn("N", _Nstr); item->AddColumn("RowIndex", rowIndex); } item->AddColumn("TIMESTAMP", new MdsValue(_qibase)); } _batch->AddRow(*item); delete item; } // Let the batch know we're done writing for now void BatchWriter::Done() { _batch->Flush(); } const std::string Identity::_name { "Identity" }; // Add identity columns to a CanonicalEntity void Identity::Process(CanonicalEntity *item) { std::vector identValues; for (const auto &iter : *(_idvec)) { item->AddColumn(iter.first, iter.second); identValues.push_back(iter.second); } PipeStage::Process(item); } const std::string BuildSchema::_name { "BuildSchema" }; // Track which event schemas have been pushed to the appropriate central SchemasTable // This unordered set tracks the pushed schemas. The key is a string with these components // separated by single forward slashes ("/"): // MDS account moniker (*not* XStore account name) // Full table name (augmented by namespace prefix and NDay suffix as appropriate) // MD5 checksum of the canonicalized schema // This cache is global and never reset (except by agent restart). std::unordered_set BuildSchema::_pushedSchemas; // The "target" metadata tells us where the corresponding SchemasTable should be. 
The // "fixed" flag, if true, claims that all events sent to this stage will have exactly // the same schema. When it is fixed, it need only be computed once at startup and, // if the table rolls every N days, at the beginning of each N day period. BuildSchema::BuildSchema(MdsdConfig *config, const MdsEntityName &target, bool fixed) : _target(target), _schemaIsFixed(fixed), _schemaRequired(false), _lastFullName() { // In order to upload MDS schema metadata, we must use the target's credentials to write to // an arbitrary table. Local and File table have no credentials at all. const Credentials *creds = target.GetCredentials(); if (creds && creds->accessAnyTable()) { // We need to write the schema. All we need to do is get a Batch pointer to which we // can write the SchemasTable entry. _schemaRequired = true; MdsEntityName schemaTarget { config, creds }; _batch = config->GetBatch(schemaTarget, 60); _moniker = creds->Moniker(); _agentIdentity = config->AgentIdentity(); } } void BuildSchema::Process(CanonicalEntity *item) { if (item && _schemaRequired) { Trace trace(Trace::XTable, "Pipe::BuildSchema::Process"); // This preamble does its best to bail out of schema writing as early and cheaply // as possible. We're silent from a tracing standpoint when taking the bailouts. 
std::string fullName = _target.Name(); if (_schemaIsFixed && (fullName == _lastFullName)) { // Schema is constant, and we've already written it for this tablename // (Example: schemas defined by ) // State for this is managed below goto done; } // Construct the key used to see if we've pushed this schema already auto metadata = MdsSchemaMetadata::GetOrMake(_target, item); if (!metadata) { goto done; } std::string key = _moniker + "/" + fullName + "/" + metadata->GetMD5(); if (_pushedSchemas.count(key)) { // We've already written it for this schema and tablename // (Example: schema computed from an OMI reply and written to a 10day table) goto done; } // OK, push the metadata and record it CanonicalEntity schemaCE { 12 }; std::string physicalTableName = _target.PhysicalTableName(); std::string rowkey = physicalTableName + "___" + metadata->GetMD5(); std::string N = MdsdUtil::ZeroFill(physicalTableName.size() % 10, 19); std::string pkey = N + "___" + MdsdUtil::ZeroFill(MdsTime::FakeTimeStampTicks, 19); trace.NOTE("Schema row: pkey " + pkey + " rowkey " + rowkey); utility::datetime timestamp1601; timestamp1601 = timestamp1601 + 1; schemaCE.AddColumn("PartitionKey", pkey); schemaCE.AddColumn("RowKey", rowkey); schemaCE.AddColumn("TIMESTAMP", new MdsValue(timestamp1601)); schemaCE.AddColumn("N", N); schemaCE.AddColumn("PhysicalTableName", physicalTableName); schemaCE.AddColumn("MD5Hash", metadata->GetMD5()); schemaCE.AddColumn("Schema", metadata->GetXML()); schemaCE.AddColumn("Uploader", _agentIdentity); schemaCE.AddColumn("UploadTS", new MdsValue(MdsTime::Now())); schemaCE.AddColumn("Reserved1", ""); schemaCE.AddColumn("Reserved2", ""); schemaCE.AddColumn("Reserved3", ""); _batch->AddRow(schemaCE); _pushedSchemas.insert(key); // If the input to this pipeline is always the same (i.e. fixed schema), then // we only have to do this once (or, perhaps, once every N days). 
if (_schemaIsFixed) { // Manage the state required by the bailout-early preamble if (_target.IsConstant()) { _schemaRequired = false; // Never have to do it again } else { _lastFullName = fullName; } } trace.NOTE("Finished; passing item to next stage"); } done: PipeStage::Process(item); } // End of namespace } // vim: se ai sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/PipeStages.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _PIPESTAGES_HH_ #define _PIPESTAGES_HH_ #include "Pipeline.hh" #include "IdentityColumns.hh" #include "MdsEntityName.hh" #include "RowIndex.hh" #include #include #include class Batch; // Used by Pipe::Unpivot to implement transforms struct ColumnTransform { public: std::string Name; double Scale; ColumnTransform(std::string name, double scale = 1.0) : Name(name), Scale(scale) {} }; // Pipe stages must implement the Process method. // Pipe stages that retain data must implement the Done method. // Pipe stages must implement a constructor, which can have any parameters that might be required. namespace Pipe { class Unpivot : public PipeStage { public: Unpivot(const std::string &valueName, const std::string &nameName, const std::string &columns, std::unordered_map&& transforms); // ~Unpivot() {} void Process(CanonicalEntity *); const std::string& Name() const { return _name; } private: static const std::string _name; const std::string _valueName; const std::string _nameName; std::unordered_set _columns; std::unordered_map _transforms; }; // BatchWriter class is the final stage in a pipe and is responsible for getting the CanonicalEntity ready for // consumption by the sink that lies behind the batch. The principal task is managing the PartitionKey and // RowKey that are needed by some, but not all, sinks. // If Start() is called, then Done() must be called. 
If no call Start() is made, there's no need to call Done. // StoreType::XTable expects Start/Done pairs so it can correctly generate partition keys. class BatchWriter : public PipeStage { public: BatchWriter(Batch * b, const ident_vect_t *idvec, unsigned int pcount, StoreType::Type storeType); void Process(CanonicalEntity *); const std::string& Name() const { return _name; } void Start(const MdsTime QIBase) { _qibase = QIBase; } void AddSuccessor(PipeStage *) { throw std::logic_error("BatchWriter stage may not have a successor stage"); } void Done(); private: static const std::string _name; Batch *_batch; const ident_vect_t * _idvec; std::string _identString; StoreType::Type _storeType; std::string _Nstr; MdsTime _qibase; }; // Add "Identity" columns to the CanonicalEntity and pass it along class Identity : public PipeStage { public: Identity(const ident_vect_t * idvec) : _idvec(idvec) {} void Process(CanonicalEntity *); const std::string& Name() const { return _name; } private: static const std::string _name; const ident_vect_t * _idvec; }; // Build the MDS server-side schema based on the CanonicalEntity. If the combination of schema and full table name // (with NDay suffix as appropriate) has not yet been pushed to the matching SchemasTable, arrange for that to happen. class BuildSchema : public PipeStage { public: BuildSchema(MdsdConfig *config, const MdsEntityName &target, bool schemaIsFixed); void Process(CanonicalEntity *); const std::string& Name() const { return _name; } private: static const std::string _name; const MdsEntityName _target; bool _schemaIsFixed; bool _schemaRequired; std::string _lastFullName; std::string _moniker; std::string _agentIdentity; Batch* _batch; static std::unordered_set _pushedSchemas; }; } #endif // _PIPESTAGES_HH_ // vim: se ai sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/Pipeline.cc ================================================ // Copyright (c) Microsoft Corporation. 
All rights reserved. // Licensed under the MIT license. #include "Pipeline.hh" PipeStage::PipeStage() : _next(nullptr) { } PipeStage::~PipeStage() { if (_next) { delete _next; _next = nullptr; } } void PipeStage::AddSuccessor(PipeStage *next) { _next = next; } void PipeStage::Start(const MdsTime QIBase) { if (_next) { _next->Start(QIBase); } } void PipeStage::Process(CanonicalEntity *item) { if (item != nullptr) { if (_next) { _next->Process(item); } else { // Drop on floor; leak the memory, if any. } } } void PipeStage::Done() { if (_next) { _next->Done(); } } // vim: se ai sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/Pipeline.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _PIPELINE_HH_ #define _PIPELINE_HH_ #include "MdsTime.hh" #include class CanonicalEntity; // You must override Name() // If you override Start() you must finish by calling PipeStage::Start(QIbase) // If you override Process() and want to pass an item down the pipe, use PipeStage::Process(item) // If you override Done() you may call PipeStage::Process(item) and must finish by calling PipeStage::Done() class PipeStage { public: virtual ~PipeStage(); virtual void AddSuccessor(PipeStage *next); virtual void Start(const MdsTime QIbase); virtual void Process(CanonicalEntity *); virtual const std::string& Name() const = 0; virtual void Done(); protected: PipeStage(); private: PipeStage *_next; }; #endif // _PIPELINE_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/PoolMgmt.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #ifndef _POOLMGMT_HH_ #define _POOLMGMT_HH_ #include #include #include #include class PoolMgmt { public: typedef std::basic_string, boost::fast_pool_allocator> PoolString; struct PoolStringHasher { std::size_t operator()(const PoolString& k) const { size_t h = std::hash()(k.data()); return h; } }; struct PoolStringEqualTo { bool operator()(const PoolString& p1, const PoolString& p2) const { return (0 == std::strcmp(p1.data(), p2.data())); } }; // This will release all memory blocks that aren’t used at the moment. // The memory will be returned to OS. static void ReleaseMemory() { boost::singleton_pool::release_memory(); boost::singleton_pool::release_memory(); boost::singleton_pool::release_memory(); } // This will release all memory blocks including those currently used. // The memory will be returned to OS. static void PurgeMemory() { boost::singleton_pool::purge_memory(); boost::singleton_pool::purge_memory(); boost::singleton_pool::purge_memory(); } }; typedef std::unordered_set> PoolStringUnorderedSet; #endif // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/Priority.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "Priority.hh"
// NOTE(review): an #include target was lost in extraction here (likely <map>).
#include
#include "Utility.hh"

// Case-insensitive priority names mapped to flush intervals, in seconds.
// NOTE(review): the std::map template arguments were lost in extraction.
static std::map priorityMap {
        { "high", 60 },
        { "medium", 300 },
        { "normal", 300 },
        { "default", 300 },
        { "low", 900 }
};

// Construct from a priority name; an unrecognized name silently gets the
// "default" duration.
Priority::Priority(const std::string & name)
{
        const auto &iter = priorityMap.find(MdsdUtil::to_lower(name));
        if (iter == priorityMap.end()) {
                _duration = priorityMap["default"];
        } else {
                _duration = iter->second;
        }
}

// Set the duration from a priority name. Returns false (leaving the current
// duration unchanged) if the name is not recognized.
bool
Priority::Set(const std::string & name)
{
        const auto &iter = priorityMap.find(MdsdUtil::to_lower(name));
        if (iter == priorityMap.end()) {
                return false;
        }
        _duration = iter->second;
        return true;
}

// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/Priority.hh
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _PRIORITY_HH_
#define _PRIORITY_HH_

// NOTE(review): two #include targets were lost in extraction here.
#include
#include

// A named flush priority that resolves to a duration in seconds.
class Priority {
public:
        Priority(const std::string & name);
        Priority() : _duration(300) {}          // Default duration: 300 seconds
        ~Priority() {}

        bool Set(const std::string & name);
        time_t Duration() const { return _duration; }

private:
        time_t _duration;       // Flush interval in seconds
};

#endif // _PRIORITY_HH_

// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolHandlerBase.cc
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "ProtocolHandlerBase.hh"

std::mutex ProtocolHandlerBase::_static_lock;
// NOTE(review): the unordered_map template arguments were lost in extraction
// (a map from schema-key string to SchemaCache::IdType).
std::unordered_map ProtocolHandlerBase::_key_id_map;

// Return the process-wide SchemaCache id for the given schema key, allocating
// a fresh id the first time a key is seen. Thread safe: guarded by _static_lock
// because handlers on different connection threads share the map.
SchemaCache::IdType
ProtocolHandlerBase::schema_id_for_key(const std::string& key)
{
        std::lock_guard lock(_static_lock);
        auto it = _key_id_map.find(key);
        if (it != _key_id_map.end()) {
                return it->second;
        } else {
                auto id = SchemaCache::Get().GetId();
                _key_id_map.insert(std::make_pair(key, id));
                return id;
        }
}

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolHandlerBase.hh
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#ifndef _PROTOCOL_HANDLER_BASE_HH
#define _PROTOCOL_HANDLER_BASE_HH

#include "SchemaCache.hh"
// NOTE(review): two #include targets were lost in extraction here
// (likely <mutex> and <unordered_map>).
#include
#include

/*
 * This class exists to eliminate duplicate code shared by the ProtocolHandler classes.
 *
 * The subclasses (e.g. ProtocolHandlerBond) are not, nor intended to be, thread safe.
 * The ProtocolListener classes allocate a separate instance per connection where each connection
 * has a separate thread.
 */
class ProtocolHandlerBase {
protected:
        virtual ~ProtocolHandlerBase() = default;

        // Per-connection cache: connection-local schema id -> process-wide SchemaCache id.
        std::unordered_map _id_map;

        static SchemaCache::IdType schema_id_for_key(const std::string& key);

        static std::mutex _static_lock;                 // Guards _key_id_map
        static std::unordered_map _key_id_map;          // Shared schema-key -> id map
};

#endif //_PROTOCOL_HANDLER_BASE_HH

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolHandlerBond.cc
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "ProtocolHandlerBond.hh"
// NOTE(review): an #include target was lost in extraction here.
#include
#include "Logger.hh"
#include "MdsValue.hh"
#include "CanonicalEntity.hh"
#include "Trace.hh"
#include "LocalSink.hh"
#include "Utility.hh"

extern "C" {
// NOTE(review): an #include target was lost in extraction here (likely <unistd.h>).
#include
}

// Close the connection's fd when the handler goes away.
ProtocolHandlerBond::~ProtocolHandlerBond()
{
        Trace trace(Trace::EventIngest, "ProtocolHandlerBond::Destructor");
        close(_fd);
        Logger::LogInfo(std::string("ProtocolHandlerBond: Connection on ") + std::to_string(_fd) + " closed");
}

// Connection loop: read a message, process it, ack it; repeat until EOF or
// error. Runs on a dedicated per-connection thread.
void
ProtocolHandlerBond::Run()
{
        Trace trace(Trace::EventIngest, "ProtocolHandlerBond::Run");

        while(true) {
                try {
                        mdsdinput::Message msg;
                        mdsdinput::Ack ack;

                        // Read message
                        _io.ReadMessage(msg);

                        // Process message
                        ack.msgId = msg.msgId;
                        ack.code = handleEvent(msg);

                        // Ack message
                        _io.WriteAck(ack);
                } catch (mdsdinput::eof_exception) {
                        Logger::LogInfo(std::string("ProtocolHandlerBond: EOF on ") + std::to_string(_fd));
                        return;
                } catch (mdsdinput::msg_too_large_error) {
                        Logger::LogWarn(std::string("ProtocolHandlerBond: Received oversized message on ") + std::to_string(_fd));
                        return;
                } catch (std::exception& ex) {
                        Logger::LogError(std::string("ProtocolHandlerBond: Unexpected exception while processing messages on ")
                                + std::to_string(_fd) + ": " + ex.what());
                        return;
                }
        }
}

// Adapts decoded message fields into CanonicalEntity columns; passed to the
// message decoder as the field callback object.
class FieldReceiver {
public:
        FieldReceiver(CanonicalEntity& ce) : _ce(ce) {}

        void BoolField(const std::string& name, bool value)
        {
                _ce.AddColumnIgnoreMetaData(name, new MdsValue(value));
        }

        void Int32Field(const std::string& name, int32_t value)
        {
                // NOTE(review): the static_cast target type was lost in extraction.
                _ce.AddColumnIgnoreMetaData(name, new MdsValue(static_cast(value)));
        }

        void Int64Field(const std::string& name, int64_t value)
        {
                // The explicit cast is necessary. Without it, the value will get treated as mt_int32.
                _ce.AddColumnIgnoreMetaData(name, new MdsValue(static_cast(value)));
        }

        void DoubleField(const std::string& name, double value)
        {
                _ce.AddColumnIgnoreMetaData(name, new MdsValue(value));
        }

        void TimeField(const std::string& name, const mdsdinput::Time& value, bool isTimestampField)
        {
                // nsec/1000: MdsTime's sub-second part is taken in microseconds here.
                MdsTime time(value.sec, value.nsec/1000);
                _ce.AddColumnIgnoreMetaData(name, new MdsValue(time));
                if (isTimestampField) {
                        _ce.SetPreciseTime(time);
                }
        }

        void StringField(const std::string& name, const std::string& value)
        {
                _ce.AddColumnIgnoreMetaData(name, new MdsValue(value));
        }

private:
        CanonicalEntity& _ce;
};

// Validate and decode one message, then enqueue the resulting entity on the
// LocalSink named by msg.source. Returns the response code to ack with.
mdsdinput::ResponseCode
ProtocolHandlerBond::handleEvent(const mdsdinput::Message& msg)
{
        Trace trace(Trace::EventIngest, "ProtocolHandlerBond::handleEvent");
        TRACEINFO(trace, "Received msg {MsgId: " << msg.msgId << ", Source: " << msg.source << "}");

        auto sink = LocalSink::Lookup(msg.source);
        if (!sink) {
                Logger::LogWarn("Received an event from source \"" + msg.source + "\" not used elsewhere in the active configuration");
                return mdsdinput::ACK_INVALID_SOURCE;
        }

        // This check may be overly restrictive.
        // Perhaps we should allow it if the message's dynamically defined schema matches the predefined schema.
        if (sink->SchemaId() != 0) {
                Logger::LogWarn("ProtocolHandlerBond: Received an event from source \"" + msg.source + "\" that is not valid for dynamic schema input");
                return mdsdinput::ACK_INVALID_SOURCE;
        }

        // NOTE(review): the make_shared template argument was lost in extraction
        // (presumably CanonicalEntity).
        auto ce = std::make_shared();
        ce->SetPreciseTime(MdsTime::Now());

        FieldReceiver fr(*ce);
        auto responseCode = _decoder.Decode(msg, fr);
        if (mdsdinput::ACK_SUCCESS == responseCode) {
                SchemaCache::IdType schemaId;
                // Map the connection-local schema id to the process-wide SchemaCache id,
                // caching the mapping in _id_map for later messages on this connection.
                auto it = _id_map.find(msg.schemaId);
                if (it != _id_map.end()) {
                        schemaId = it->second;
                } else {
                        schemaId = schema_id_for_key(_decoder.GetSchemaKey(msg.schemaId));
                        _id_map.insert(std::make_pair(msg.schemaId, schemaId));
                        TRACEINFO(trace, "Mapped connection schemaId (" << msg.schemaId << ") to SchemaCache id (" << schemaId << ")");
                }
                ce->SetSchemaId(schemaId);
                TRACEINFO(trace, "Message added to LocalSink with schemaId (" << schemaId << ")");
                sink->AddRow(ce);
        }
        return responseCode;
}

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolHandlerBond.hh
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#ifndef _PROTOCOL_HANDLER_BOND_HH_
#define _PROTOCOL_HANDLER_BOND_HH_

#include "ProtocolHandlerBase.hh"
#include "MdsdInputMessageIO.h"
#include "MdsdInputMessageDecoder.h"

/*
 * This class is not, nor is it intended to be, thread safe.
 *
 * ProtocolListenerBond allocates a separate instance of this class per connection
 * and each connection is handled by a separate thread.
*/
class ProtocolHandlerBond: public ProtocolHandlerBase {
public:
        explicit ProtocolHandlerBond(int fd) : _fd(fd), _fdio(fd), _io(_fdio) {}
        ~ProtocolHandlerBond();

        void Run();

private:
        mdsdinput::ResponseCode handleEvent(const mdsdinput::Message& msg);

        int _fd;                                // Connection file descriptor; closed by destructor
        mdsdinput::FDIO _fdio;
        mdsdinput::MessageIO _io;
        mdsdinput::MessageDecoder _decoder;
};

// vim: set ai sw=8:
#endif // _PROTOCOL_HANDLER_BOND_HH_

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolHandlerJSON.cc
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

// NOTE(review): an #include target was lost in extraction here.
#include
#include "ProtocolHandlerJSON.hh"
#include "Logger.hh"
#include "MdsValue.hh"
#include "CanonicalEntity.hh"
#include "Trace.hh"
#include "LocalSink.hh"
#include "Utility.hh"

extern "C" {
// NOTE(review): an #include target was lost in extraction here (likely <unistd.h>).
#include
}

// Close the connection's fd when the handler goes away.
ProtocolHandlerJSON::~ProtocolHandlerJSON()
{
        Trace trace(Trace::EventIngest, "ProtocolHandlerJSON::Destructor");
        close(_fd);
        Logger::LogInfo(std::string("ProtocolHandlerJSON: Connection on ") + std::to_string(_fd) + " closed");
}

// Connection loop: read one size-prefixed JSON message at a time, handle it,
// and write an ack; repeat until EOF or error. Runs on a per-connection thread.
void
ProtocolHandlerJSON::Run()
{
        Trace trace(Trace::EventIngest, "ProtocolHandlerJSON::Run");

        msg_data_t msg_data;
        while(true) {
                try {
                        mdsdinput::Ack ack;

                        // Read message
                        size_t msg_size = readMsgSize();
                        readMsgData(msg_data, msg_size);

                        // Process message
                        ack = handleMsg(msg_data);

                        // Ack message
                        writeAck(ack.msgId, ack.code);
                } catch (mdsdinput::eof_exception) {
                        Logger::LogInfo(std::string("ProtocolHandlerJSON: EOF on ") + std::to_string(_fd));
                        return;
                } catch (mdsdinput::msg_too_large_error) {
                        Logger::LogWarn(std::string("ProtocolHandlerJSON: Received oversized message on ") + std::to_string(_fd));
                        return;
                } catch (std::exception& ex) {
                        Logger::LogError(std::string("ProtocolHandlerJSON: Unexpected exception while processing messages on ")
                                + std::to_string(_fd) + ": " + ex.what());
                        return;
                }
        }
}

// Read the newline-terminated ASCII decimal message-size prefix from _fd.
// Throws eof_exception on EOF, msg_too_large_error if the size string is too
// long, and std::runtime_error for an out-of-range size.
size_t
ProtocolHandlerJSON::readMsgSize()
{
        Trace trace(Trace::EventIngest, "ProtocolHandlerJSON::readMsgSize");

        char sbuf[8];
        size_t sidx = 0;
        do {
                // One byte at a time until the '\n' terminator; retry on EINTR.
                ssize_t n = read(_fd, &sbuf[sidx], 1);
                if (n < 0) {
                        if (errno != EINTR) {
                                throw std::system_error(errno, std::system_category());
                        }
                } else if (n == 0) {
                        throw mdsdinput::eof_exception();
                } else {
                        if (sbuf[sidx] == '\n') {
                                break;
                        }
                        sidx++;
                }
        } while (sidx < sizeof(sbuf));

        if (sidx == sizeof(sbuf)) {
                throw mdsdinput::msg_too_large_error("ProtocolHandlerJSON: Message size string is too long");
        }

        sbuf[sidx] = 0;         // Overwrite '\n' so stoul sees a clean decimal string
        size_t size = std::stoul(sbuf);
        if (size == 0 || size > MAX_MSG_DATA_SIZE) {
                throw std::runtime_error("Invalid message size");
        }
        return size;
}

// Read exactly `size` bytes of message data into msg_data, then NUL-terminate
// it (required for in-situ JSON parsing). Retries on EINTR; throws on EOF.
void
ProtocolHandlerJSON::readMsgData(msg_data_t& msg_data, size_t size)
{
        Trace trace(Trace::EventIngest, "ProtocolHandlerJSON::readMsgData");

        char* ptr = &msg_data[0];
        size_t idx = 0;
        do {
                ssize_t n = read(_fd, ptr, size - idx);
                if (n < 0) {
                        if (errno != EINTR) {
                                throw std::system_error(errno, std::system_category());
                        }
                } else if (n == 0) {
                        throw mdsdinput::eof_exception();
                } else {
                        idx += n;
                        ptr += n;
                }
        } while (idx < size);
        msg_data[size] = 0;
}

// Write a "<msgId>:<code>\n" ack back to the client.
// NOTE(review): a short (partial) write is reported as eof_exception and the
// remainder is not retried — confirm that is the intended behavior.
void
ProtocolHandlerJSON::writeAck(uint64_t msgId, mdsdinput::ResponseCode rcode)
{
        Trace trace(Trace::EventIngest, "ProtocolHandlerJSON::writeAck");

        std::ostringstream out;
        out << msgId << ":" << rcode << std::endl;
        auto str = out.str();
        ssize_t n = write(_fd, str.c_str(), str.size());
        if (n < 0) {
                throw std::system_error(errno, std::system_category());
        } else if (n < static_cast(str.size())) {
                throw mdsdinput::eof_exception();
        }
}

// Decode one JSON message in place. On success, fills in `source`, populates
// `ce` with one column per schema field, and returns an ack with ACK_SUCCESS.
// Expected layout: [source, msgId, schemaId, schema-or-null, data].
// Element 3 may carry a new schema (registered under the connection-local
// schema id) or be null to reference one sent earlier on this connection.
// Throws std::runtime_error on malformed input; the caller maps that to
// ACK_DECODE_ERROR.
mdsdinput::Ack
ProtocolHandlerJSON::decodeMsg(msg_data_t& msg_data, std::string& source, CanonicalEntity& ce)
{
        Trace trace(Trace::EventIngest, "ProtocolHandlerJSON::decodeMsg");

        mdsdinput::Ack ack;

        rapidjson::Document d;
        d.ParseInsitu(&msg_data[0]);

        ack.code = mdsdinput::ACK_DECODE_ERROR;

        // Build/fetch schema
        if (!d.IsArray()) {
                throw std::runtime_error("Invalid JSON document: Was not an array");
        }

        if (d.Size() != 5) {
                std::ostringstream msg;
                msg << "Invalid JSON document: Array size invalid: Expected 5, got " << d.Size();
                throw std::runtime_error(msg.str());
        }

        const rapidjson::Value& jsource = d[0];
        const rapidjson::Value& jmsgId = d[1];
        const rapidjson::Value& jschemaId = d[2];
        const rapidjson::Value& jschema = d[3];
        const rapidjson::Value& jmsgdata = d[4];

        if (!jsource.IsString()) {
                throw std::runtime_error("Invalid JSON document: source (0) is not a String");
        }
        if (!jmsgId.IsNumber()) {
                throw std::runtime_error("Invalid JSON document: msgId (1) is not a Number");
        }
        if (!jschemaId.IsNumber()) {
                throw std::runtime_error("Invalid JSON document: schemaId (2) is not a Number");
        }
        if (!jmsgdata.IsArray()) {
                throw std::runtime_error("Invalid JSON document: data (4) is not an Array");
        }
        if (!jschema.IsNull() && !jschema.IsArray()) {
                throw std::runtime_error("Invalid JSON document: schema (3) is not an Array");
        }

        auto schema_id = jschemaId.GetUint64();

        // NOTE(review): shared_ptr/make_shared template arguments were lost in extraction.
        std::shared_ptr schema;
        if (!jschema.IsNull()) {
                // The message carries an inline schema definition: parse and register it.
                bool hasTimestampIndex = false;
                uint32_t timestampIndex;
                schema = std::make_shared();
                for (rapidjson::Value::ConstValueIterator it = jschema.Begin(); it != jschema.End(); ++it) {
                        if (it == jschema.Begin() && !it->IsArray()) {
                                // If the first element of the array is not an array, not null, and is an unsigned integer
                                // then use it as the timestamp index.
                                if (!it->IsNull() && it->IsUint()) {
                                        hasTimestampIndex = true;
                                        timestampIndex = static_cast(it->GetUint64());
                                }
                        } else {
                                // Each field definition is a [name, type] pair of strings.
                                if (!it->IsArray() || it->Size() != 2) {
                                        throw std::runtime_error("Invalid Schema");
                                }
                                const rapidjson::Value &name = (*it)[0];
                                const rapidjson::Value &ft = (*it)[1];
                                if (!name.IsString() || !ft.IsString()) {
                                        throw std::runtime_error("Invalid Schema");
                                }
                                mdsdinput::FieldDef fd;
                                fd.name = name.GetString();
                                if (!ToEnum(fd.fieldType, ft.GetString())) {
                                        throw std::runtime_error("Invalid Schema");
                                }
                                schema->fields.push_back(fd);
                        }
                }
                if (hasTimestampIndex) {
                        // An out-of-range timestamp index is silently ignored.
                        if (timestampIndex < schema->fields.size()) {
                                schema->timestampFieldIdx.set(timestampIndex);
                        }
                }
                if (!_schema_cache->AddSchemaWithId(schema, schema_id)) {
                        ack.code = mdsdinput::ACK_DUPLICATE_SCHEMA_ID;
                        return ack;
                }
        } else {
                // No inline schema: the id must reference one registered earlier.
                try {
                        schema = _schema_cache->GetSchema(schema_id);
                } catch(std::out_of_range) {
                        ack.code = mdsdinput::ACK_UNKNOWN_SCHEMA_ID;
                        return ack;
                }
        }

        if (schema->fields.size() != jmsgdata.Size()) {
                std::ostringstream msg;
                msg << "Invalid message data: Array size invalid: Expected " << schema->fields.size() << ", got " << jmsgdata.Size();
                throw std::runtime_error(msg.str());
        }

        ack.msgId = jmsgId.GetInt64();
        source = std::string(jsource.GetString(), jsource.GetStringLength());

        // Convert each data element to an MdsValue column per its declared field type.
        //
        for (int i = 0; i < (int)schema->fields.size(); ++i) {
                mdsdinput::FieldDef fd = schema->fields.at(i);
                const rapidjson::Value& val = jmsgdata[i];
                switch (fd.fieldType) {
                case mdsdinput::FT_INVALID:
                        ack.code = mdsdinput::ACK_DECODE_ERROR;
                        return ack;
                case mdsdinput::FT_BOOL:
                        if (!val.IsBool()) {
                                throw std::runtime_error("Invalid Message data");
                        }
                        ce.AddColumnIgnoreMetaData(fd.name, new MdsValue(val.GetBool()));
                        break;
                case mdsdinput::FT_INT32:
                        if (!val.IsInt()) {
                                throw std::runtime_error("Invalid Message data");
                        }
                        ce.AddColumnIgnoreMetaData(fd.name, new MdsValue(static_cast(val.GetInt())));
                        break;
                case mdsdinput::FT_INT64:
                        if (!val.IsInt64()) {
                                throw std::runtime_error("Invalid Message data");
                        }
                        // The explicit cast is necessary. Without it, the value will get treated as mt_int32.
                        ce.AddColumnIgnoreMetaData(fd.name, new MdsValue(static_cast(val.GetInt64())));
                        break;
                case mdsdinput::FT_DOUBLE:
                        if (!val.IsNumber()) {
                                throw std::runtime_error("Invalid Message data");
                        }
                        ce.AddColumnIgnoreMetaData(fd.name, new MdsValue(val.GetDouble()));
                        break;
                case mdsdinput::FT_TIME:
                        // Time arrives as a [seconds, nanoseconds] pair.
                        if (!val.IsArray() || val.Size() != 2) {
                                throw std::runtime_error("Invalid Message data");
                        }
                        {
                                MdsTime time(val[0].GetUint64(), val[1].GetUint()/1000);
                                ce.AddColumnIgnoreMetaData(fd.name, new MdsValue(time));
                                if (!schema->timestampFieldIdx.empty() && static_cast(i) == *(schema->timestampFieldIdx)) {
                                        ce.SetPreciseTime(time);
                                }
                        }
                        break;
                case mdsdinput::FT_STRING:
                        if (!val.IsString()) {
                                throw std::runtime_error("Invalid Message data");
                        }
                        ce.AddColumnIgnoreMetaData(fd.name, new MdsValue(std::string(val.GetString(), val.GetStringLength())));
                        break;
                default:
                        throw std::runtime_error("Invalid field type in schema");
                }
        }

        // Map the connection-local schema id to the process-wide SchemaCache id,
        // caching the mapping for later messages on this connection.
        SchemaCache::IdType mdsdSchemaId;
        auto it = _id_map.find(schema_id);
        if (it != _id_map.end()) {
                mdsdSchemaId = it->second;
        } else {
                mdsdSchemaId = schema_id_for_key(_schema_cache->GetSchemaKey(schema_id));
                _id_map.insert(std::make_pair(schema_id, mdsdSchemaId));
                TRACEINFO(trace, "Mapped connection schemaId ("+std::to_string(schema_id)+") to SchemaCache id ("+std::to_string(mdsdSchemaId)+")");
        }
        ce.SetSchemaId(mdsdSchemaId);

        ack.code = mdsdinput::ACK_SUCCESS;
        return ack;
}

// Decode a message buffer and, on success, route the entity to the LocalSink
// named by its source. Decode failures are logged and acked as ACK_DECODE_ERROR.
mdsdinput::Ack
ProtocolHandlerJSON::handleMsg(msg_data_t& msg_data)
{
        Trace trace(Trace::EventIngest, "ProtocolHandlerJSON::handleMsg");

        mdsdinput::Ack ack;
        std::string source;
        // NOTE(review): the make_shared template argument was lost in extraction.
        auto ce = std::make_shared();
        ce->SetPreciseTime(MdsTime::Now());

        try {
                ack = decodeMsg(msg_data, source, *ce);
        } catch(std::exception& ex) {
                std::ostringstream strm;
                strm << "ProtocolHandlerJSON: Error decoding message '";
                for(auto c: msg_data) {
                        strm << c;
                }
                strm << "' from fd " << _fd << ": " << ex.what();
                Logger::LogWarn(strm);
                ack.code = mdsdinput::ACK_DECODE_ERROR;
        }

        if (ack.code == mdsdinput::ACK_SUCCESS) {
                auto sink = LocalSink::Lookup(source);
                if (!sink) {
                        Logger::LogWarn("ProtocolHandlerJSON: Received an event from source \"" + source + "\" not used elsewhere in the active configuration");
                        ack.code = mdsdinput::ACK_INVALID_SOURCE;
                } else {
                        TRACEINFO(trace, "Message added to LocalSink");
                        sink->AddRow(ce);
                }
        }
        return ack;
}

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolHandlerJSON.hh
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#ifndef _PROTOCOL_HANDLER_JSON_HH_
#define _PROTOCOL_HANDLER_JSON_HH_

#include "ProtocolHandlerBase.hh"
#include "CanonicalEntity.hh"
#include "MdsdInputMessageIO.h"
#include "MdsdInputMessageDecoder.h"
#include "MdsdInputSchemaCache.h"
// NOTE(review): an #include target was lost in extraction here (likely <array>).
#include
#include "rapidjson/document.h"
#include "rapidjson/stringbuffer.h"

/*
 * This class is not, nor is it intended to be, thread safe.
 *
 * ProtocolListenerDynamicJSON allocates a separate instance of this class per connection
 * and each connection is handled by a separate thread.
*/
class ProtocolHandlerJSON: public ProtocolHandlerBase {
public:
        // Maximum payload size; the buffer reserves one extra byte for the NUL terminator.
        static constexpr size_t MAX_MSG_DATA_SIZE = 128 * 1024-1;
        // NOTE(review): the std::array template arguments were lost in extraction.
        typedef std::array msg_data_t;

        explicit ProtocolHandlerJSON(int fd) : _fd(fd), _schema_cache(std::make_shared()) {}
        ~ProtocolHandlerJSON();

        void Run();

private:
        size_t readMsgSize();
        void readMsgData(msg_data_t& msg_data, size_t size);
        void writeAck(uint64_t msgId, mdsdinput::ResponseCode rcode);
        mdsdinput::Ack decodeMsg(msg_data_t& msg_data, std::string& source, CanonicalEntity& ce);
        mdsdinput::Ack handleMsg(msg_data_t& msg_data);

        int _fd;                                // Connection file descriptor; closed by destructor
        std::shared_ptr _schema_cache;          // Per-connection schema cache
};

// vim: set ai sw=8:
#endif // _PROTOCOL_HANDLER_JSON_HH_

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolListener.cc
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "ProtocolListener.hh"
#include "Logger.hh"
#include "Trace.hh"
#include "Utility.hh"

// NOTE(review): four #include targets were lost in extraction here.
#include
#include
#include
#include

extern "C" {
// NOTE(review): five #include targets were lost in extraction here.
#include
#include
#include
#include
#include

extern void StopProtocolListenerMgr();
}

// Stop the listener (if still running) when it is destroyed.
// NOTE(review): the trace tag "handleEvent" looks copy-pasted; confirm it was
// meant to say "Destructor".
ProtocolListener::~ProtocolListener()
{
        Trace trace(Trace::EventIngest, "ProtocolListener::handleEvent");
        Stop();
}

// Create and bind the Unix-domain listening socket, then open up its file
// permissions so non-root clients can connect.
void
ProtocolListener::openListener()
{
        Trace trace(Trace::EventIngest, "ProtocolListener::openListener");

        int fd = MdsdUtil::CreateAndBindUnixSocket(_file_path);

        // Allow processes under non-root UID (e.g., rsyslogd on Ubuntu) to send msg to this
        mode_t mode = 0666;
        if (-1 == chmod(_file_path.c_str(), mode)) {
                close(fd);
                throw std::system_error(errno, std::system_category(), "chmod(" + _file_path + ", " + std::to_string(mode) + ")");
        }
        _listenfd = fd;
}

// Open the socket, start listening, and launch the accept thread.
// Idempotent: returns immediately if the thread is already running.
void
ProtocolListener::Start()
{
        Trace trace(Trace::EventIngest, "ProtocolListener::Start");

        std::unique_lock lock(_lock);
        if (_thread.get_id() != std::thread::id()) {
                return;
        }

        openListener();

        if (listen(_listenfd, 10)) {
                throw std::system_error(errno, std::system_category(), "listen()");
        }

        std::thread thread([this](){this->run();});
        _thread.swap(thread);
}

// Close the listening socket and detach the accept thread; the thread notices
// the stop via stopCheck() and exits on its own.
void
ProtocolListener::Stop()
{
        Trace trace(Trace::EventIngest, "ProtocolListener::Stop");

        std::unique_lock lock(_lock);
        if (_listenfd != -1) {
                close(_listenfd);
                _listenfd = -1;
                _thread.detach();
        }
}

// True when the listener has been stopped or the calling thread is no longer
// the registered accept thread.
bool
ProtocolListener::stopCheck()
{
        std::lock_guard lock(_lock);
        return _listenfd == -1 || std::this_thread::get_id() != _thread.get_id();
}

// Accept loop, run on the listener's dedicated thread. Polls with a 1-second
// timeout so stop requests are noticed promptly; hands each accepted fd to the
// subclass via HandleConnection().
void
ProtocolListener::run()
{
        Trace trace(Trace::EventIngest, "ProtocolListener::run");

        int lfd;
        {
                std::lock_guard lock(_lock);
                lfd = _listenfd;
        }
        while(!stopCheck()) {
                struct pollfd fds[1];
                fds[0].fd = lfd;
                fds[0].events = POLLIN;
                fds[0].revents = 0;
                int r = poll(&fds[0], 1, 1000);
                if (r < 0) {
                        if (errno == EINTR) {
                                continue;
                        }
                        if (!stopCheck()) {
                                // Log all other errors and return.
                                Logger::LogError(std::string("ProtocolListener(" + _protocol + "): poll() returned an unexpected error: ") + std::strerror(errno));
                                // Initiate a clean process exit.
                                StopProtocolListenerMgr();
                                // After calling StopProtocolListenerMgr() the only safe thing to do is return.
                        }
                        return;
                }
                if (r == 1) {
                        int newfd = accept(lfd, NULL, 0);
                        if (newfd > 0) {
                                if (!stopCheck()) {
                                        HandleConnection(newfd);
                                } else {
                                        // Stop raced the accept; discard the new connection.
                                        close(newfd);
                                }
                        } else {
                                // If accept was interrupted, or the connection was reset (RST)
                                // before it could be accepted, then just continue.
                                if (errno == EINTR || errno == ECONNABORTED) {
                                        continue;
                                }
                                // If the per-process (EMFILE) or system (ENFILE) descriptor limit is reached
                                // then sleep for a while in the hope that the situation will improve.
                                if (errno == EMFILE || errno == ENFILE) {
                                        Logger::LogError(std::string("ProtocolListener(") + _protocol + "): descriptor limit reached: " + std::strerror(errno));
                                        Logger::LogWarn(std::string("ProtocolListener(" + _protocol + "): waiting 1 minute before trying to accept new connections"));
                                        std::this_thread::sleep_for(std::chrono::seconds(60));
                                        continue;
                                }
                                if (!stopCheck()) {
                                        // Log all other errors and return.
                                        Logger::LogError(std::string("ProtocolListener(" + _protocol + "): accept() returned an unexpected error: ") + std::strerror(errno));
                                        // Other errors indicate (probably) fatal conditions.
                                        // Initiate a clean process exit.
                                        StopProtocolListenerMgr();
                                        // After calling StopProtocolListenerMgr() the only safe thing to do is return.
                                }
                                return;
                        }
                }
        }
}

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolListener.hh
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#ifndef _PROTOCOL_LISTENER_HH_
#define _PROTOCOL_LISTENER_HH_

// NOTE(review): three #include targets were lost in extraction here.
#include
#include
#include

// Base class for per-protocol socket listeners. Owns the listening fd and the
// accept thread; subclasses implement HandleConnection() (and may override
// openListener() for non-Unix-domain sockets).
class ProtocolListener {
public:
        virtual ~ProtocolListener();

        std::string Protocol() { return _protocol; }

        void Start();
        void Stop();

        std::string FilePath() { return _file_path; };

protected:
        ProtocolListener(const std::string& prefix, const std::string& protocol)
                : _prefix(prefix), _protocol(protocol), _listenfd(-1)
        {
                _file_path = _prefix + "_" + _protocol + ".socket";
        }

        virtual void openListener();
        virtual void HandleConnection(int fd) = 0;

        std::string _prefix;
        std::string _protocol;
        std::string _file_path;         // Unix socket path: "<prefix>_<protocol>.socket"

        std::mutex _lock;
        int _listenfd;                  // -1 when stopped
        std::thread _thread;            // Accept thread

        bool stopCheck();
        void run();
};

// vim: set ai sw=8:
#endif // _PROTOCOL_LISTENER_HH_

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolListenerBond.cc
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "ProtocolListenerBond.hh"
#include "ProtocolHandlerBond.hh"
#include "Logger.hh"
#include "Trace.hh"
// NOTE(review): an #include target was lost in extraction here.
#include

// Thread entry point: run a Bond protocol handler for one connection.
// The handler closes the fd in its destructor.
static void handler(int fd)
{
        ProtocolHandlerBond h(fd);
        h.Run();
}

// Spawn a detached thread to service the newly accepted connection.
void
ProtocolListenerBond::HandleConnection(int fd)
{
        Trace trace(Trace::EventIngest, "ProtocolListenerBond::HandleConnection");

        std::thread thread(handler, fd);
        std::ostringstream out;
        // Capture the thread id before detaching.
        out << "ProtocolListenerBond: Created BOND thread " << thread.get_id() << " for fd " << fd;
        thread.detach();
        Logger::LogInfo(out.str());
}

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolListenerBond.hh
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#ifndef _PROTOCOL_LISTENER_BOND_HH_
#define _PROTOCOL_LISTENER_BOND_HH_

#include "ProtocolListener.hh"

// Listens on "<prefix>_bond.socket"; one detached handler thread per connection.
class ProtocolListenerBond : public ProtocolListener {
public:
        ProtocolListenerBond(const std::string& prefix)
                : ProtocolListener(prefix, "bond")
        {}

protected:
        virtual void HandleConnection(int fd);
};

// vim: set ai sw=8:
#endif // _PROTOCOL_LISTENER_BOND_HH_

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolListenerDynamicJSON.cc
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "ProtocolListenerDynamicJSON.hh"
#include "ProtocolHandlerJSON.hh"
#include "Logger.hh"
#include "Trace.hh"

// Thread entry point: run a dynamic-JSON protocol handler for one connection.
// The handler closes the fd in its destructor.
static void handler(int fd)
{
        ProtocolHandlerJSON h(fd);
        h.Run();
}

// Spawn a detached thread to service the newly accepted connection.
void
ProtocolListenerDynamicJSON::HandleConnection(int fd)
{
        Trace trace(Trace::EventIngest, "ProtocolListenerDynamicJSON::HandleConnection");

        std::thread thread(handler, fd);
        std::ostringstream out;
        out << "ProtocolListenerDynamicJSON: Created Dynamic JSON thread " << thread.get_id() << " for fd " << fd;
        Logger::LogInfo(out.str());
        thread.detach();
}

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolListenerDynamicJSON.hh
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#ifndef _PROTOCOL_LISTENER_DYNAMIC_JSON_HH_
#define _PROTOCOL_LISTENER_DYNAMIC_JSON_HH_

#include "ProtocolListener.hh"

// Listens on "<prefix>_djson.socket"; one detached handler thread per connection.
class ProtocolListenerDynamicJSON : public ProtocolListener {
public:
        ProtocolListenerDynamicJSON(const std::string& prefix)
                : ProtocolListener(prefix, "djson")
        {}

protected:
        virtual void HandleConnection(int fd);
};

// vim: set ai sw=8:
#endif // _PROTOCOL_LISTENER_DYNAMIC_JSON_HH_

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolListenerJSON.cc
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "ProtocolListenerJSON.hh"
#include "StreamListener.hh"
#include "Logger.hh"
#include "Trace.hh"

// Thread entry point: hand the connection to a StreamListener instance.
// StreamListener::handler takes ownership of the allocated object.
static void handler(int fd)
{
        StreamListener::handler(new StreamListener(fd));
}

// Spawn a detached thread to service the newly accepted connection.
void
ProtocolListenerJSON::HandleConnection(int fd)
{
        Trace trace(Trace::EventIngest, "ProtocolListenerJSON::HandleConnection");

        std::thread thread(handler, fd);
        std::ostringstream out;
        out << "ProtocolListenerJSON: Created JSON thread " << thread.get_id() << " for fd " << fd;
        Logger::LogInfo(out.str());
        thread.detach();
}

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolListenerJSON.hh
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#ifndef _PROTOCOL_LISTENER_JSON_HH_
#define _PROTOCOL_LISTENER_JSON_HH_

#include "ProtocolListener.hh"

// Listens on "<prefix>_json.socket"; one detached handler thread per connection.
class ProtocolListenerJSON : public ProtocolListener {
public:
        ProtocolListenerJSON(const std::string& prefix)
                : ProtocolListener(prefix, "json")
        {}

protected:
        virtual void HandleConnection(int fd);
};

// vim: set ai sw=8:
#endif // _PROTOCOL_LISTENER_JSON_HH_

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolListenerMgr.cc
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "ProtocolListenerMgr.hh"
#include "ProtocolListenerDynamicJSON.hh"
#include "ProtocolListenerJSON.hh"
#include "ProtocolListenerTcpJSON.hh"
#include "ProtocolListenerBond.hh"
#include "Logger.hh"
#include "Trace.hh"
#include "Utility.hh"

extern "C" {
// NOTE(review): an #include target was lost in extraction here (likely <unistd.h>).
#include
}

// File recording this process's pid and the TCP JSON listener's port.
static MdsdUtil::LockedFile pidPortFile;

ProtocolListenerMgr::~ProtocolListenerMgr()
{
        Trace trace(Trace::EventIngest, "ProtocolListenerMgr::Destructor");
}

// Create the singleton manager. Subsequent calls are no-ops.
void
ProtocolListenerMgr::Init(const std::string& prefix, int port, bool retry_random)
{
        Trace trace(Trace::EventIngest, "ProtocolListenerMgr::Init");
        TRACEINFO(trace, "Prefix: " + prefix + ", Port: " + std::to_string(port));

        if (nullptr == _mgr) {
                _mgr = new ProtocolListenerMgr(prefix, port, retry_random);
        }
}

ProtocolListenerMgr* ProtocolListenerMgr::_mgr = nullptr;

// Return the singleton, or nullptr if Init() has not been called.
ProtocolListenerMgr*
ProtocolListenerMgr::GetProtocolListenerMgr()
{
        Trace trace(Trace::EventIngest, "ProtocolListenerMgr::GetProtocolListenerMgr");
        return _mgr;
}

// Start all four listeners (BOND, dynamic JSON, JSON, TCP JSON) and record
// pid + TCP port in the pid-port file. If any listener fails to start, every
// listener is stopped again and false is returned; no-op if already started.
bool
ProtocolListenerMgr::Start()
{
        Trace trace(Trace::EventIngest, "ProtocolListenerMgr::Start");

        std::unique_lock lock(_lock);
        if (_stop) {
                bool failed = false;
                _stop = false;

                pidPortFile.Open(_prefix + ".pidport");
                pidPortFile.WriteLine(std::to_string(getpid()));

                _bond_listener.reset(new ProtocolListenerBond(_prefix));
                _djson_listener.reset(new ProtocolListenerDynamicJSON(_prefix));
                _json_listener.reset(new ProtocolListenerJSON(_prefix));
                _tcp_json_listener.reset(new ProtocolListenerTcpJSON(_prefix, _port, _retry_random));

                try {
                        _bond_listener->Start();
                } catch (std::system_error& ex) {
                        // NOTE(review): unique_ptr::release() abandons the object without
                        // deleting it; confirm whether reset() was intended here (and below).
                        _bond_listener.release();
                        Logger::LogError(std::string("ProtocolListenerMgr: BOND Listener failed to start: ") + ex.what());
                        failed = true;
                }
                if (!failed) {
                        try {
                                _djson_listener->Start();
                        } catch (std::system_error &ex) {
                                _djson_listener.release();
                                Logger::LogError(std::string("ProtocolListenerMgr: Dynamic JSON Listener failed to start: ") + ex.what());
                                failed = true;
                        }
                }
                if (!failed) {
                        try {
                                _json_listener->Start();
                        } catch (std::system_error &ex) {
                                _json_listener.release();
                                Logger::LogError(std::string("ProtocolListenerMgr: JSON Listener failed to start: ") + ex.what());
                                failed = true;
                        }
                }
                if (!failed) {
                        try {
                                _tcp_json_listener->Start();
                                // Record the (possibly randomly chosen) TCP port after a successful bind.
                                // NOTE(review): the static_cast target type was lost in extraction
                                // (presumably ProtocolListenerTcpJSON*).
                                pidPortFile.WriteLine(std::to_string(static_cast(_tcp_json_listener.get())->Port()));
                        } catch (std::system_error &ex) {
                                _tcp_json_listener.release();
                                Logger::LogError(std::string("ProtocolListenerMgr: TCP JSON Listener failed to start: ") + ex.what());
                                failed = true;
                        }
                }

                // One of the listeners failed to start. Stop the manager so things get cleaned up before process exit.
                if (failed) {
                        // Release the lock first; Stop() re-acquires it.
                        _lock.unlock();
                        Stop();
                        return false;
                }
        }
        return true;
}

// Stop and discard all listeners (removing their socket files), remove the
// pid-port file, and wake anyone blocked in Wait(). Idempotent.
void
ProtocolListenerMgr::Stop()
{
        Trace trace(Trace::EventIngest, "ProtocolListenerMgr::Stop");

        std::lock_guard lock(_lock);
        if (!_stop) {
                try {
                        if (_bond_listener) {
                                _bond_listener->Stop();
                                unlink(_bond_listener->FilePath().c_str());
                                // NOTE(review): release() abandons the object without deleting it;
                                // confirm whether reset() was intended here (and below).
                                _bond_listener.release();
                        }
                        if (_djson_listener) {
                                _djson_listener->Stop();
                                unlink(_djson_listener->FilePath().c_str());
                                _djson_listener.release();
                        }
                        if (_json_listener) {
                                _json_listener->Stop();
                                unlink(_json_listener->FilePath().c_str());
                                _json_listener.release();
                        }
                        if (_tcp_json_listener) {
                                // TCP listener has no socket file to unlink.
                                _tcp_json_listener->Stop();
                                _tcp_json_listener.release();
                        }
                } catch(std::exception& ex) {
                        Logger::LogError("Error: ProtocolListenerMgr::Stop() unexpected exception while stopping listeners: " + std::string(ex.what()));
                } catch(...) {
                        Logger::LogError("Error: ProtocolListenerMgr::Stop() unknown exception while stopping listeners.");
                }

                try {
                        pidPortFile.Remove();
                } catch(std::exception& ex) {
                        Logger::LogError("Error: ProtocolListenerMgr::Stop() unexpected exception while trying to remove pid-port file: " + std::string(ex.what()));
                } catch(...) {
                        Logger::LogError("Error: ProtocolListenerMgr::Stop() unknown exception while trying to remove pid-port file.");
                }

                _stop = true;
                _cond.notify_all();
        }
}

// Block the calling thread until Stop() has completed.
void
ProtocolListenerMgr::Wait()
{
        Trace trace(Trace::EventIngest, "ProtocolListenerMgr::Wait");

        std::unique_lock lock(_lock);
        // Wait for stop
        _cond.wait(lock, [this]{return this->_stop;});
}

// C-callable hook used by the listeners to initiate a clean shutdown.
extern "C" void
StopProtocolListenerMgr()
{
        auto plmgmt = ProtocolListenerMgr::GetProtocolListenerMgr();
        if (plmgmt != nullptr) {
                plmgmt->Stop();
        }
}

// C-callable hook to truncate and close the pid-port file (e.g. at exit).
extern "C" void
TruncateAndClosePidPortFile()
{
        try {
                pidPortFile.TruncateAndClose();
        } catch(std::exception& ex) {
                Logger::LogError("Error: TruncateAndClosePidPortFile() unexpected exception: " + std::string(ex.what()));
        } catch(...) {
                Logger::LogError("Error: TruncateAndClosePidPortFile() unknown exception.");
        }
}

================================================
FILE: Diagnostic/mdsd/mdsd/ProtocolListenerMgr.hh
================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#ifndef _PROTOCOL_LISTENER_MGMT_HH_ #define _PROTOCOL_LISTENER_MGMT_HH_ #include "ProtocolListener.hh" #include #include #include #include class ProtocolListenerMgr { public: ~ProtocolListenerMgr(); static void Init(const std::string& prefix, int port, bool retry_random); static ProtocolListenerMgr* GetProtocolListenerMgr(); bool Start(); void Stop(); void Wait(); private: ProtocolListenerMgr(const std::string& prefix, int port, bool retry_random) : _prefix(prefix), _port(port), _retry_random(retry_random), _stop(true) {} static ProtocolListenerMgr* _mgr; std::string _prefix; int _port; bool _retry_random; std::mutex _lock; std::condition_variable _cond; bool _stop; std::unique_ptr _bond_listener; std::unique_ptr _djson_listener; std::unique_ptr _json_listener; std::unique_ptr _tcp_json_listener; }; // vim: set ai sw=8: #endif // _PROTOCOL_LISTENER_MGMT_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/ProtocolListenerTcpJSON.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "ProtocolListenerTcpJSON.hh"
#include "StreamListener.hh"
#include "Logger.hh"
#include "Trace.hh"
#include "Utility.hh"

// Thread entry point: wrap the accepted fd in a StreamListener and run its loop.
static void handler(int fd)
{
    StreamListener::handler(new StreamListener(fd));
}

// Create the loopback TCP listening socket, binding to _port (0 = random).
// If the configured port is busy and _retry_random is set, fall back to a
// random port. On success, _port holds the actual bound port and _listenfd
// the socket. Throws std::system_error on any socket-call failure.
void ProtocolListenerTcpJSON::openListener()
{
    Trace trace(Trace::EventIngest, "ProtocolListenerTcpJSON::openListener");

    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (-1 == fd) {
        throw std::system_error(errno, std::system_category(), "socket(AF_INET, SOCK_STREAM)");
    }
    MdsdUtil::FdCloser fdCloser(fd);    // Closes fd on early exit unless released

    int reuseaddr = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuseaddr, sizeof(reuseaddr))) {
        throw std::system_error(errno, std::system_category(), "setsockopt(SO_REUSEADDR)");
    }

    // NOTE(review): a hand-rolled layout-compatible stand-in for struct linger,
    // with linger disabled -- confirm against upstream why struct linger isn't used.
    struct { int l_onoff; int l_linger; } linger { 0, 0 };
    if (setsockopt(fd, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger))) {
        throw std::system_error(errno, std::system_category(), "setsockopt(SO_LINGER)");
    }

    if (_port == 0) {
        Logger::LogInfo(std::string("ProtocolListenerTcpJSON: Binding to a random port"));
    }

    struct sockaddr_in loopback;
    loopback.sin_family = AF_INET;
    loopback.sin_port = htons(_port);
    loopback.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

    if (bind(fd, (struct sockaddr *)&loopback, sizeof(loopback))) {
        // If the first bind attempt was to a random port, then it doesn't matter what the errno is.
        // Just throw the exception. Trying, again, on a random port is also likely to fail.
        if (errno != EADDRINUSE || !_retry_random || _port == 0) {
            throw std::system_error(errno, std::system_category(),
                                    std::string("bind(AF_INET, ") + std::to_string(_port) + ")");
        }
        Logger::LogWarn("ProtocolListenerTcpJSON: Port " + std::to_string(_port) + " is already in use. Will try a random port.");
        loopback.sin_port = 0;
        if (bind(fd, (struct sockaddr *) &loopback, sizeof(loopback))) {
            throw std::system_error(errno, std::system_category(), std::string("bind(AF_INET, 0)"));
        }
    }

    // Discover the port actually bound (meaningful when we asked for 0 or retried).
    socklen_t len = sizeof(loopback);
    if (getsockname(fd, (struct sockaddr*)&loopback, &len)) {
        throw std::system_error(errno, std::system_category(), "getsockname()");
    }
    auto _requested_port = _port;
    _port = (int)ntohs(loopback.sin_port);
    if (_requested_port != _port) {
        Logger::LogWarn(std::string("ProtocolListenerTcpJSON: Listener port is ") + std::to_string(_port));
    }

    fdCloser.Release();
    _listenfd = fd;
}

// Spin up a detached thread to service one accepted connection.
void ProtocolListenerTcpJSON::HandleConnection(int fd)
{
    Trace trace(Trace::EventIngest, "ProtocolListenerTcpJSON::HandleConnection");

    std::thread thread(handler, fd);
    std::ostringstream out;
    out << "ProtocolListenerTcpJSON: Created TCP JSON thread " << thread.get_id() << " for fd " << fd;
    thread.detach();
    Logger::LogInfo(out.str());
}

================================================ FILE: Diagnostic/mdsd/mdsd/ProtocolListenerTcpJSON.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#ifndef _PROTOCOL_LISTENER_TCP_JSON_HH_
#define _PROTOCOL_LISTENER_TCP_JSON_HH_

#include "ProtocolListener.hh"

/// Loopback TCP listener for JSON-encoded events. Binds the configured port,
/// optionally retrying on a random port when the configured one is in use.
class ProtocolListenerTcpJSON : public ProtocolListener
{
public:
    ProtocolListenerTcpJSON(const std::string& prefix, int port, bool retry_random)
        : ProtocolListener(prefix, "json"), _port(port), _retry_random(retry_random)
    {
        _file_path = _prefix + ".pidport";
    }

    // The port actually bound (valid after openListener()).
    int Port() { return _port; };

protected:
    virtual void openListener();
    virtual void HandleConnection(int fd);

    int _port;          // Requested, then actual, TCP port
    bool _retry_random; // Fall back to a random port on EADDRINUSE
};

// vim: set ai sw=8:

#endif // _PROTOCOL_LISTENER_TCP_JSON_HH_

================================================ FILE: Diagnostic/mdsd/mdsd/RowIndex.cc ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. #include "RowIndex.hh" #include #include #include thread_local unsigned long long RowIndex::_index = ULLONG_MAX; unsigned long long RowIndex::_baseValue = 0; std::mutex RowIndex::_mutex; unsigned long long RowIndex::Get() { if (_index == ULLONG_MAX) { unsigned long long now = (((unsigned long long) time(0)) & 0xfffff) << 32; std::lock_guard lock(_mutex); _index = _baseValue + now; _baseValue += 1ULL << 54; } return _index++; } ================================================ FILE: Diagnostic/mdsd/mdsd/RowIndex.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _ROWINDEX_HH_ #define _ROWINDEX_HH_ #include class RowIndex { public: static unsigned long long Get(); private: static thread_local unsigned long long _index; static unsigned long long _baseValue; static std::mutex _mutex; RowIndex(); }; #endif //_ROWINDEX_HH_ // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/SaxParserBase.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "SaxParserBase.hh" #include #include #include extern "C" { #include #include } /////////////////////////////////////////////////////////////////////// // SAX callback dispatchers that are registered for every SaxParserBase // instance, which will call the actual callbacks in the instance. 
static void OnStartDocumentCallback(void* userData) { auto parser = static_cast(userData); assert(nullptr != parser); parser->OnStartDocument(); } static void OnEndDocumentCallback(void* userData) { auto parser = static_cast(userData); assert(nullptr != parser); parser->OnEndDocument(); } static void OnCommentCallback(void* userData, const xmlChar* comment) { auto parser = static_cast(userData); assert(nullptr != parser); const std::string commentStr((comment == nullptr) ? "" : reinterpret_cast(comment)); parser->OnComment(commentStr); } static void OnStartElementCallback( void* userData, const xmlChar* localname, const xmlChar** attributes ) { auto parser = static_cast(userData); assert(nullptr != parser); std::string name(reinterpret_cast(localname)); SaxParserBase::AttributeMap attrs; while (attributes != nullptr && *attributes != nullptr) { auto key = reinterpret_cast(attributes[0]); auto value = reinterpret_cast(attributes[1]); auto retval = attrs.emplace(key, value); if (!retval.second) { std::ostringstream oss; oss << "An extra instance of attribute \"" << key << "\" in element \"" << name << "\" was seen and ignored"; parser->OnWarning(oss.str()); } attributes += 2; } parser->OnStartElement(name, attrs); } static void OnCharactersCallback( void* userData, const xmlChar* chars, int len ) { auto parser = static_cast(userData); assert(nullptr != parser); const std::string charsStr(reinterpret_cast(chars), static_cast(len)); parser->OnCharacters(charsStr); } static void OnEndElementCallback( void* userData, const xmlChar* localname ) { auto parser = static_cast(userData); assert(nullptr != parser); const std::string name(reinterpret_cast(localname)); parser->OnEndElement(name); } static constexpr size_t MESSAGE_BUFFER_SIZE = 512; static void OnWarningCallback(void* userData, const char* msg, ...) 
{ auto parser = static_cast(userData); assert(nullptr != parser); char buf[MESSAGE_BUFFER_SIZE]; va_list arglist; va_start(arglist, msg); vsnprintf(buf, MESSAGE_BUFFER_SIZE, msg, arglist); va_end(arglist); const std::string warning(buf); parser->OnWarning(warning); } static void OnErrorCallback(void* userData, const char* msg, ...) { auto parser = static_cast(userData); assert(nullptr != parser); char buf[MESSAGE_BUFFER_SIZE]; va_list arglist; va_start(arglist, msg); vsnprintf(buf, MESSAGE_BUFFER_SIZE, msg, arglist); va_end(arglist); const std::string error(buf); parser->OnError(error); } static void OnFatalErrorCallback(void* userData, const char* msg, ...) { auto parser = static_cast(userData); assert(nullptr != parser); char buf[MESSAGE_BUFFER_SIZE]; va_list arglist; va_start(arglist, msg); vsnprintf(buf, MESSAGE_BUFFER_SIZE, msg, arglist); va_end(arglist); const std::string fatalError(buf); parser->OnFatalError(fatalError); } static void OnCDataBlockCallback( void* userData, const xmlChar* chars, int len ) { auto parser = static_cast(userData); assert(nullptr != parser); const std::string cdata(reinterpret_cast(chars), static_cast(len)); parser->OnCDataBlock(cdata); } ///////////////// End of SAX callback dispatchers ///////////////////// /////////////////////////////////////////////////////////////////////// // Helper function to get the xmlSAXHandler pointer with the callback // dispatcher functions already registered. 
static xmlSAXHandler* GetSaxHandler() { static xmlSAXHandler saxHandler = { nullptr, // internalSubset; nullptr, // isStandalone; nullptr, // hasInternalSubset; nullptr, // hasExternalSubset; nullptr, // resolveEntity; nullptr, // getEntity; nullptr, // entityDecl; nullptr, // notationDecl; nullptr, // attributeDecl; nullptr, // elementDecl; nullptr, // unparsedEntityDecl; nullptr, // setDocumentLocator; OnStartDocumentCallback, // startDocument; OnEndDocumentCallback, // endDocument; OnStartElementCallback, // startElement; OnEndElementCallback, // endElement; nullptr, // reference; OnCharactersCallback, // characters; nullptr, // ignorableWhitespace; nullptr, // processingInstruction; OnCommentCallback, // comment; OnWarningCallback, // warning; OnErrorCallback, // error; OnFatalErrorCallback, // fatalError; /* unused error() get all the errors */ nullptr, // getParameterEntity; OnCDataBlockCallback, // cdataBlock; nullptr, // externalSubset; 0, // initialized; /* The following fields are extensions available only on version 2 */ nullptr, // _private; nullptr, // startElementNs; nullptr, // endElementNs; nullptr // serror; }; return &saxHandler; } /////////////////////////////////////////////////////////////////////// // SaxParserBase implementation SaxParserBase::SaxParserBase() : m_ctxt(nullptr) { xmlSAXHandlerPtr saxHander = GetSaxHandler(); m_ctxt = xmlCreatePushParserCtxt(saxHander, NULL, NULL, 0, NULL); if (m_ctxt == nullptr) { throw SaxParserBaseException("Failed to create Xml parser context"); } // The instance pointer should be saved so that the callback // dispatcher functions can route the calls to the proper // SaxParserBase instance. 
m_ctxt->userData = this; } SaxParserBase::~SaxParserBase() { if (m_ctxt != nullptr) { xmlFreeParserCtxt(m_ctxt); } } #define MAX_SAX_CHUNK_SIZE 1024 void SaxParserBase::Parse(const std::string & doc) { if (m_ctxt == nullptr) { throw SaxParserBaseException("Xml parser context wasn't created correctly at the construction time"); } const char* buf = doc.c_str(); size_t totalLen = doc.length(); size_t remainingLen = totalLen; while (remainingLen > 0) { const size_t chunkSize = std::min((size_t)MAX_SAX_CHUNK_SIZE, remainingLen); const int terminate = (chunkSize == remainingLen); int result = xmlParseChunk(m_ctxt, buf, (int)chunkSize, terminate); if (result) { const int offsetBegin = totalLen - remainingLen; const int offsetEnd = offsetBegin + (int)chunkSize; std::ostringstream oss; oss << "xmlParseChunk error between offset " << offsetBegin << " and " << offsetEnd << " (return code = " << result << ")"; this->OnError(oss.str()); return; } remainingLen -= chunkSize; buf += chunkSize; } } void SaxParserBase::ParseChunk(std::string chunk, bool terminate) { if (m_ctxt == nullptr) { throw SaxParserBaseException("Xml parser context wasn't created correctly at the construction time"); } int result = xmlParseChunk(m_ctxt, chunk.c_str(), (int)chunk.length(), terminate); if (result) { std::ostringstream oss; oss << "xmlParseChunk error (return code = " << result << ")"; this->OnError(oss.str()); } } ================================================ FILE: Diagnostic/mdsd/mdsd/SaxParserBase.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #ifndef _SAXPARSERBASE_HH_ #define _SAXPARSERBASE_HH_ #include #include #include #include class SaxParserBaseException : public std::runtime_error { public: SaxParserBaseException(const std::string& message) : std::runtime_error(message) {} SaxParserBaseException(const char* message) : std::runtime_error(message) {} }; /// /// A simple base class for a specific SAX parser. User of this class /// will derive a subclass from this and override necessary On...() methods /// to achieve their desired SAX parsing. Currently not supporting all /// the callbacks that LibXML2 supports. Should add all of them gradually. /// This base class's callback handler methods are all empty so that users /// don't have to implement all those methods, when they need only a few /// of them. /// class SaxParserBase { public: typedef std::unordered_map AttributeMap; SaxParserBase(); virtual ~SaxParserBase(); // Callbacks for various SAX parsing events virtual void OnStartDocument() {} virtual void OnEndDocument() {} virtual void OnComment(const std::string& comment) {} virtual void OnStartElement(const std::string& name, const AttributeMap& attributes) {} virtual void OnCharacters(const std::string& chars) {} virtual void OnEndElement(const std::string& name) {} virtual void OnWarning(const std::string& text) {} virtual void OnError(const std::string& text) {} virtual void OnFatalError(const std::string& text) {} virtual void OnCDataBlock(const std::string& text) {} /// /// Parse an entire XML document passed as a string. /// The entire XML document passed as a string /// void Parse(const std::string & doc); /// /// Parse a chunk of XML document passed as a string. /// This is needed so that a subclass don't have to use the /// libxml's C API to do the chunk parsing. We wanted to separate /// all I/Os from this class, so we'd need to provide this for /// any subclass. 
/// The XML chunk to be parsed, passed as a string /// Indicates whether the passed chunk is the last one /// in the whole XML document void ParseChunk(std::string chunk, bool terminate = false); private: xmlParserCtxtPtr m_ctxt; }; #endif // _SAXPARSERBASE_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/SchemaCache.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "SchemaCache.hh" #include "Trace.hh" #include #include "Crypto.hh" std::mutex SchemaCache::_mutex; SchemaCache* SchemaCache::_singleton; SchemaCache& SchemaCache::Get() { Trace trace(Trace::SchemaCache, "SchemaCache::Get"); // Double-check locking to ensure we instantiate the singleton exactly once. // We already needed the mutex for other purposes, so do this manually instead // of using std::call_once and std::once_flag. if (_singleton == nullptr) { _mutex.lock(); if (_singleton == nullptr) { _singleton = new SchemaCache(); trace.NOTE("Allocating singleton cache"); } _mutex.unlock(); } return *_singleton; } // Store the id, move the schema string from the argument into the object, compute the MD5 hash. // If the caller(s) all along the line enabled move semantics, we should wind up with the schema inside // this object without any copying. 
SchemaCache::Info::Info(SchemaCache::IdType id, std::string schema) : _id(id), _schema(std::move(schema)), _md5(Crypto::MD5HashString(_schema)) { } std::map & SchemaCache::Select(Kind kind) { switch(kind) { case XTable: return _XTableCache; case Bond: return _BondCache; default: throw std::invalid_argument("Access to SchemaCache of unknown kind"); } } bool SchemaCache::IsCached(SchemaCache::IdType id, Kind kind) noexcept { Trace trace(Trace::SchemaCache, "SchemaCache::IsCached"); try { std::lock_guard lock(_mutex); bool found = (Select(kind).count(id) > 0); if (trace.IsActive()) { std::ostringstream msg; msg << "Cache(" << kind; if (found) { msg << ") did "; } else { msg << ") did not "; } msg << "contain key " << id; trace.NOTE(msg.str()); } return found; } catch (std::exception& ex) { // We don't cache anything for unknown kinds of schemas trace.NOTE(std::string("Exception caught: ") + ex.what()); return false; } } SchemaCache::CachedType SchemaCache::Find(SchemaCache::IdType id, Kind kind) { auto& cache = Select(kind); // Lock the map down long enough to copy the result std::unique_lock lock(_mutex); auto it = cache.find(id); lock.unlock(); if (it == cache.end()) { std::ostringstream msg; msg << "SchemaCache(" << kind << ") does not contain id " << id; throw std::runtime_error(msg.str()); } return it->second; } void SchemaCache::Evict(SchemaCache::IdType id, Kind kind) noexcept { // Select is not nothrow, but the only exception it throws is one we want // to ignore (invalid kind). std::map::erase is nothrow. try { std::lock_guard lock(_mutex); (void) Select(kind).erase(id); } catch (...) { } } // Create an info structure by moving the schema into it. 
void SchemaCache::Insert(SchemaCache::IdType id, Kind kind, std::string schema) { Trace trace(Trace::SchemaCache, "SchemaCache::Insert"); auto entry = std::make_shared(id, std::move(schema)); auto & cache = Select(kind); std::lock_guard lock(_mutex); cache[id] = std::move(entry); if (trace.IsActive()) { std::ostringstream msg; msg << "Added id " << id << " to cache of type " << kind; trace.NOTE(msg.str()); } } std::ostream& operator<<(std::ostream& strm, SchemaCache::Kind kind) { switch(kind) { case SchemaCache::Kind::XTable: strm << "XTable"; break; case SchemaCache::Kind::Bond: strm << "Bond"; break; default: strm << "!Unknown!"; break; } return strm; } #ifdef ENABLE_TESTING void TEST__SchemaCache_Reset() { if (SchemaCache::_singleton) { delete SchemaCache::_singleton; SchemaCache::_singleton = nullptr; } } std::map& TEST__SchemaCache_Select(SchemaCache::Kind kind) { return SchemaCache::Get().Select(kind); } #endif // ENABLE_TESTING // vim: set ai sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/SchemaCache.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #ifndef _SCHEMACACHE_HH_ #define _SCHEMACACHE_HH_ #pragma once #include #include #include #include #include #include #include #include "Crypto.hh" /* The notion of "schema" is tied to the CanonicalEvent. Each row in an MDS destination (table, Bond blob, etc.) can have its own schema. The generator of a CanonicalEvent (JSON, OMI, query engine) can "know" the schema. PipeStages that alter the CE can likewise "know" the altered schema. So... o Tag the CE (at instantiation) with its schema id as "known" by the source/generator. o PipeStage augmentors should map from input schema id to output schema id (the new, altered schema produced by the augmenting stage). These are general-purpose augmentor (e.g. "Add identity columns") and should literally keep a map of . 
o A unique augmenter (e.g. a configured query) should gets its own schema ID when the config is processed. If the augmentor performs an identity transformation, it should pass along the input schema ID(s). If it does its own projection, it should have its own (new) ID. */ class SchemaCache { public: ///////////// Types ////////// // The id type using IdType = unsigned long long; // The actual data kept in the cache. // The result of the Schema() method shares lifetime with the Info object; if you want it to live longer // you'll need to copy it. Same for the MD5Hash returned by the Hash() method. class Info { public: Info(SchemaCache::IdType id, std::string schema); SchemaCache::IdType Id() const { return _id; } const std::string& Schema() const { return _schema; } const Crypto::MD5Hash & Hash() const { return _md5; } private: SchemaCache::IdType _id; std::string _schema; Crypto::MD5Hash _md5; }; // The kinds of schema we can store enum Kind { Unknown, XTable, Bond }; // The value type is a shared pointer to the Info object. // When we return the shared_ptr to clients, the ref count on the actual object is // managed for us. When the last shared_ptr is deleted, the underlying object is cleaned up. using CachedType = std::shared_ptr; ///////////// Methods ////////// // Return the singleton instance of the SchemaCache static SchemaCache& Get(); // Allocate a new schema ID and return it. Using an atomic_long, so no locking needed. SchemaCache::IdType GetId() noexcept { return _nextId++; } // Check to see if a schema of a given kind has been cached for a given ID bool IsCached(SchemaCache::IdType id, SchemaCache::Kind kind) noexcept; // Return the cached schema of that kind for that id. Throws if none is cached. CachedType Find(SchemaCache::IdType id, SchemaCache::Kind kind); // Remove a cached schema. Silent if nothing is cached for the id/kind void Evict(SchemaCache::IdType id, SchemaCache::Kind kind) noexcept; // Insert a schema. 
Discard the currently cached schema, if any. // The schema is moved into the Info object, if possible. void Insert(SchemaCache::IdType id, SchemaCache::Kind kind, std::string schema); #ifdef ENABLE_TESTING friend void TEST__SchemaCache_Reset(); friend std::map& TEST__SchemaCache_Select(Kind kind); #endif ////////// Stream IO ////////// friend std::ostream& operator<<(std::ostream&, Kind); private: // Default constructor is private and used by the static accessor. Neither copy nor move or assignment // are allowed. SchemaCache() : _nextId(1) {} SchemaCache(const SchemaCache &) = delete; SchemaCache& operator=(const SchemaCache &) = delete; static SchemaCache * _singleton; // Points to the singleton instance of this class // As a static, the linker will ensure this is all zeroes, the correct bit pattern for nullptr static std::mutex _mutex; // Protects access to the cache std::atomic_ullong _nextId; // Next schema ID to use // We only have two kinds of schemas, so make each schema its own map and provide a simple // method to get a reference to the map for any particular kind. This is moderately scalable; // the Select method is a fast switch() on the Kind. At a certain point, it may become // smarter to change to a single map whose key is pair. Eliminate Select() and // simply build the right key wherever it's used. std::map _BondCache; std::map _XTableCache; // Return a reference to the map which caches the desired schema type std::map& Select(Kind kind); }; #ifdef ENABLE_TESTING void TEST__SchemaCache_Reset(); std::map& TEST__SchemaCache_Select(SchemaCache::Kind kind); #endif // ENABLE_TESTING #endif // _SCHEMACACHE_HH_ // vim: se ai sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/Signals.c ================================================ /* Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT license. 
*/ #define _XOPEN_SOURCE #include #include #include #include #define STACK_DEPTH 50 #ifdef DOING_MEMCHECK extern void RunFinalCleanup(); #endif extern void LogStackTrace(int, void**, int); extern void LogAbort(); extern void CatchSigChld(int signo); extern void CleanupExtensions(); extern void SetCoreDumpLimit(); extern void TruncateAndClosePidPortFile(); extern void StopProtocolListenerMgr(); /* Signals on which we want to backtrace */ static int signalsToBacktrace[] = { SIGSEGV, SIGFPE, SIGILL, SIGTRAP, SIGBUS, SIGSTKFLT, SIGXFSZ }; static int CatchAndMaskAll(int sig, void(*handler)(int)) { struct sigaction sa; sa.sa_handler = handler; sigfillset(&sa.sa_mask); sa.sa_flags = 0; return sigaction(sig, &sa, 0); } static void SetBacktraceSignalHandler(void (*backtraceHandler)()) { int i; for (i = 0; i < sizeof(signalsToBacktrace) / sizeof(int); i++) { CatchAndMaskAll(signalsToBacktrace[i], backtraceHandler); } } static void ResetBacktraceSignalHandlersToDefault() { int i; for (i = 0; i < sizeof(signalsToBacktrace) / sizeof(int); i++) { signal(signalsToBacktrace[i], SIG_DFL); } signal(SIGABRT, SIG_DFL); // We need to reset the SIGABRT handler to default as well, so that our own SIGABRT handler will not be called on this path. } void CatchSigUsr1(int signo) { char msg[] = "In SIGUSR1 handler\n"; static int FirstTime = 1; #ifdef DOING_MEMCHECK write(2, msg, sizeof(msg)); if (FirstTime) { RunFinalCleanup(); FirstTime = 0; /* Let all registered atexit() handlers run */ exit(1); } else { /* What, still here? Die, dang it! 
*/ _exit(1); } #else exit(1); #endif } void CatchSigUsr2(int signo) { extern void RotateLogs(); RotateLogs(); } void CatchSigHup(int signo) { extern void LoadNewConfiguration(); LoadNewConfiguration(); } void EmitStackTrace(int signo) { void *stack[STACK_DEPTH]; int count = backtrace(stack, STACK_DEPTH); LogStackTrace(signo, stack, count); } void CatchFatal(int signo) { // Code below can easily raise another signal // (SIGABRT, most likely), which shouldn't be handled by this handler // again, so we have to reset the handler to default. And we do that // for all signals on which we may want to dump stack trace. ResetBacktraceSignalHandlersToDefault(); EmitStackTrace(signo); CleanupExtensions(); TruncateAndClosePidPortFile(); } void CatchFatalAndExit(int signo) { CatchFatal(signo); _exit(signo); } void CatchFatalAndAbort(int signo) { CatchFatal(signo); abort(); } void CatchTerm(int signo) { CleanupExtensions(); StopProtocolListenerMgr(); // The Main thread will exit once ProtocolListenerMgr has stopped. /* struct sigaction sa_dfl; sa_dfl.sa_handler = SIG_DFL; sigemptyset(&sa_dfl.sa_mask); sa_dfl.sa_restorer = 0; sa_dfl.sa_flags = 0; sigaction(signo, &sa_dfl, 0); raise(signo); */ } void CatchSigAbort(int signo) { LogAbort(); // If the SIGABRT signal is ignored, or caught by a handler that returns, the abort() function // will still terminate the process. It does this by restoring the default disposition for // SIGABRT and then raising the signal for a second time. 
} void BlockSignals() { sigset_t ss; sigemptyset(&ss); sigaddset(&ss, SIGHUP); sigaddset(&ss, SIGALRM); sigprocmask(SIG_BLOCK, &ss, NULL); } void SetSignalCatchers(int coreDumpAtFatal) { CatchAndMaskAll(SIGUSR1, CatchSigUsr1); CatchAndMaskAll(SIGUSR2, CatchSigUsr2); CatchAndMaskAll(SIGINT, CatchTerm); CatchAndMaskAll(SIGTERM, CatchTerm); CatchAndMaskAll(SIGQUIT, CatchTerm); CatchAndMaskAll(SIGHUP, CatchSigHup); // SIGABRT shouldn't try to backtrace, because of a glibc bug (https://sourceware.org/bugzilla/show_bug.cgi?id=16159) // so catch it with a different handler, where it'll just log the event in mdsd.err and really abort. CatchAndMaskAll(SIGABRT, CatchSigAbort); void (*backtraceHandler)(); if (coreDumpAtFatal) { SetCoreDumpLimit(); backtraceHandler = CatchFatalAndAbort; } else { backtraceHandler = CatchFatalAndExit; } SetBacktraceSignalHandler(backtraceHandler); signal(SIGPIPE, SIG_IGN); struct sigaction sa_chld; sa_chld.sa_handler = CatchSigChld; sigemptyset(&sa_chld.sa_mask); sa_chld.sa_flags = SA_NOCLDSTOP; sigaction(SIGCHLD, &sa_chld, 0); } ================================================ FILE: Diagnostic/mdsd/mdsd/StoreType.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "StoreType.hh" #include #include #include "Utility.hh" extern "C" { #include } namespace StoreType { // Names should be all lower case, since from_string canonicalizes to lower case // before searching the map static std::map typeMap { { "local", StoreType::Type::Local }, { "xtable", StoreType::Type::XTable }, { "central", StoreType::Type::XTable }, { "jsonblob", StoreType::Type::XJsonBlob }, { "centraljson", StoreType::Type::XJsonBlob }, // For parity with WAD... { "file", StoreType::Type::File } }; static std::map nameLengthLimit { { StoreType::Type::None, 0 }, { StoreType::Type::XTable, 63 }, { StoreType::Type::XJsonBlob, PATH_MAX /* No explicit limit we've heard about this. 
*/ }, { StoreType::Type::Local, 255 }, { StoreType::Type::File, PATH_MAX } }; static std::map needsSchemaGeneration { { StoreType::Type::None, false }, { StoreType::Type::XTable, true }, { StoreType::Type::XJsonBlob, false }, { StoreType::Type::Local, false }, { StoreType::Type::File, false } }; Type from_string(const std::string & n) { const auto &iter = typeMap.find(MdsdUtil::to_lower(n)); if (iter == typeMap.end()) { return None; } else { return iter->second; } } size_t max_name_length(StoreType::Type t) { const auto & iter = StoreType::nameLengthLimit.find(t); if (iter == StoreType::nameLengthLimit.end()) { return 0; } else { return iter->second; } } bool DoSchemaGeneration(StoreType::Type storetype) { const auto &iter = needsSchemaGeneration.find(storetype); if (iter == needsSchemaGeneration.end()) { throw std::domain_error("Don't know if schema generation is needed for StoreType " + std::to_string(storetype)); } return iter->second; } bool DoAddIdentityColumns(StoreType::Type storetype) { return (storetype != StoreType::Local); } }; // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/StoreType.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _STORETYPE_HH_ #define _STORETYPE_HH_ #include namespace StoreType { enum Type { None, XTable, Bond, XJsonBlob, Local, File }; StoreType::Type from_string(const std::string &); size_t max_name_length(StoreType::Type t); bool DoSchemaGeneration(StoreType::Type storetype); bool DoAddIdentityColumns(StoreType::Type storetype); }; #endif // _STORETYPE_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/StreamListener.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
extern "C" { #include #include #include } #include #include #include #include #include #include #include "StreamListener.hh" #include "MdsTime.hh" #include "Trace.hh" #include "Utility.hh" int ReadFromSocket(int fd, char * buffer, size_t amount) { fd_set readfds; while (1) { FD_ZERO(&readfds); FD_SET(fd, &readfds); auto res = select(fd+1, &readfds, 0, 0, 0); if (0 == res) { // Spurious wakeup; they happen, honest. continue; } if (-1 == res) { auto saved_errno = errno; throw std::system_error(saved_errno, std::system_category(), "StreamListener select() failed."); } int len = read(fd, buffer, amount); if (-1 == len) { // Something unusual happened auto saved_errno = errno; if (EINTR == errno || EWOULDBLOCK == errno || EAGAIN == errno) { continue; } throw std::system_error(saved_errno, std::system_category(), "StreamListener read() failed."); } // If we got here, then we read some data or hit eof; either way, we're done return len; } } void * StreamListener::ProcessLoop() { Trace trace(Trace::EventIngest, "StreamListener::ProcessLoop"); const int msgbuflen=1024; char msgbuf[msgbuflen]; if (-1 == fcntl(fd(), F_SETFL, O_NONBLOCK)) { auto saved_errno = errno; Logger::LogError(std::string("StreamListener failed to set O_NONBLOCK: ").append(MdsdUtil::GetErrnoStr(saved_errno))); return 0; } buflen = 256 * 1024; trigger = 3 * (buflen>>2); // 75% full buffer = (char *)malloc(buflen + 1); // Always room to turn "byte array" into "string" if (0 == buffer) { Logger::LogError("Initial buffer alloc out of memory"); return(0); } current = buffer; // buffer points to the beginning of the allocated buffer. // buflen is the usable size of the buffer (which was allocated with 1 extra byte for a terminal NUL). // current points to the location at which we might try to write into the buffer. // When the buffer is empty, current==buffer // When the buffer is full, current==(buffer+buflen), a valid address at which a single byte can be written. 
while (1) { // Invariant: unparsed data in buffer is less than the threshold for expanding the buffer auto inuse = current - buffer; // How far we were in the old buffer if (inuse >= trigger) { // Sanity check: no legal message is bigger than N MiB if (inuse > 4*1024*1024) { std::ostringstream msg; msg << "Buffered incomplete JSON data (" << inuse << " bytes) exceeds max; probable desync. Buffer head:\n[["; msg << MdsdUtil::StringNCopy(buffer, 1024) << "]]\nDropping connection."; Logger::LogError(msg.str()); return(0); } // Resize the buffer TRACEINFO(trace, "Reallocate ingest buffer; was (buflen " << buflen << ", trigger " << trigger << ")"); if (trace.IsActive() && trace.IsAlsoActive(Trace::IngestContents)) { TRACEINFO(trace, "Old buffer start: [[" << MdsdUtil::StringNCopy(buffer, 1024) << "]]"); } buffer = (char *)realloc(buffer, 2 * buflen + 1); if (0 == buffer) { snprintf(msgbuf, msgbuflen, "Buffer realloc(%ld) out of memory", 2 * buflen + 1); Logger::LogError(msgbuf); return(0); } current = buffer + inuse; buflen *= 2; trigger *= 2; TRACEINFO(trace, "Now (buflen " << buflen << ", trigger " << trigger << ")"); } int len; try { len = ReadFromSocket(fd(), current, (buflen - inuse)); } catch (const std::exception& e) { Logger::LogError(e.what()); return 0; } if (0 == len) { // End of file - closed socket. snprintf(msgbuf, 1024, "End of file on thread %llx - exiting thread", (long long int)pthread_self()); trace.NOTE(msgbuf); return 0; } // OK, I have some characters. Question is - do I have at least one valid // JSON object? Best we can do is guess. // If the last character is a backslash, it's not safe to hand the buffer // to the parser; if the object is actually incomplete, the backslash will // escape the NUL terminator and the parser will go off the edge of the buffer. // Did I receive a right-brace in the most recent receive? If not, then // I can't possibly have a valid object; go read more. 
// If I saw a right brace, I *might* have a valid object; try to parse it. // If I get a NULL back from the parser, I have no valid object; go read more. // If I got a valid pointer, then I had at least one valid object, but they've // been parsed; the pointer tells me where the next object might begin, so // shuffle it to the top of the buffer and go read more. const char * cursor = current + len - 1; // Last character read if (*cursor == '\\') { // Not safe to parse, and there has to be more coming; go read more. current += len; continue; } while (cursor >= current) { if (*cursor == '}') break; cursor--; } if (cursor < current) { // Nope, can't be an object; go read more. current += len; continue; } // Found a right brace. I might have valid objects. Parse the full buffer. *(current+len) = '\0'; try { cursor = Listener::ParseBuffer(buffer, current+len); } catch (const Listener::exception &e) { std::ostringstream msg; msg << MdsTime() << ": closing connection due to JSON parse error: " << e.what(); Logger::LogError(msg.str()); return(0); } if (0 == cursor) { // Nope, no object; go read more. current += len; continue; } // OK, processed something. cursor points to the next possible start of object // (I can rely on ParseBuffer to have clobbered any trailing whitespace.) if (cursor == current+len) { current = buffer; // Processed everything; nothing remains } else { int delta = current + len - cursor; // Remaining unprocessed characters (void) memmove(buffer, cursor, delta); current = buffer + delta; } } /* NOTREACHED */ } // vim: set ai sw=2 expandtab : ================================================ FILE: Diagnostic/mdsd/mdsd/StreamListener.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#ifndef _STREAMLISTENER_HH_ #define _STREAMLISTENER_HH_ #include #include #include "Listener.hh" /// Listens for JSON-encoded events on a TCP socket class StreamListener : public Listener { private: StreamListener(const StreamListener&); // Do not define; copy construction forbidden StreamListener& operator=(const StreamListener &); // Ditto for assignment char * buffer = nullptr; // Data received from client size_t buflen = 0; // Size of buffer ptrdiff_t trigger = 0; // Offset into buffer of leftover data that causes an increase in buffer size char * current = nullptr; // Point at which new data will be added public: StreamListener(int fd) : Listener(fd) {} virtual ~StreamListener() { if (buffer) free(buffer); } void * ProcessLoop(); }; // vim: set ai sw=2 #endif // _STREAMLISTENER_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/Subscription.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "Subscription.hh" #include "Batch.hh" #include "Credentials.hh" #include "MdsEntityName.hh" #include "PipeStages.hh" #include "LocalSink.hh" #include "Trace.hh" Subscription::Subscription(LocalSink *sink, const MdsEntityName &target, Priority pr, const MdsTime& interval) : ITask(interval), _sink(sink), _target(target), _priority(pr), _head(nullptr), _tail(nullptr) { Trace trace(Trace::ConfigLoad, "Subscription constructor(table ref)"); common_constructor(); } Subscription::Subscription(LocalSink *sink, MdsEntityName &&target, Priority pr, const MdsTime& interval) : ITask(interval), _sink(sink), _target(target), _priority(pr), _head(nullptr), _tail(nullptr) { Trace trace(Trace::ConfigLoad, "Subscription constructor(table move)"); common_constructor(); } void Subscription::common_constructor() { Trace trace(Trace::ConfigLoad, "Subscription common constructor path"); _sink->SetRetentionPeriod(interval()); if (trace.IsActive()) { std::ostringstream msg; msg << "Retention period " << _sink->RetentionSeconds(); trace.NOTE(msg.str()); } } // Initial start time is a few seconds past the end of the current interval MdsTime Subscription::initial_start() { Trace trace(Trace::EventIngest, "Subscription::initial_start"); MdsTime start; // Default constructor sets it to "now" start += interval(); start = start.Round(interval().to_time_t()); start += MdsTime(2 + random()%5, random()%1000000); if (trace.IsActive()) { std::ostringstream msg; msg << "Initial time for event: " << start; trace.NOTE(msg.str()); } return start; } void Subscription::AddStage(PipeStage *stage) { if (! 
_tail) { // This is the first stage in the pipeline; set the head to point here _head = stage; } else { // There's already a pipeline; make the old tail point to the newly-added stage _tail->AddSuccessor(stage); } // Either way, we have a new tail in the pipeline _tail = stage; } // Pull everything in the sink on the interval [start, start+duration) // For each event, call _head->Process(new CanonicalEntity(event)) void Subscription::execute(const MdsTime& startTime) { Trace trace(Trace::EventIngest, "Subscription::execute"); if (trace.IsActive()) { std::ostringstream msg; msg << "Start time " << startTime << ", end time " << startTime + interval(); trace.NOTE(msg.str()); } _head->Start(startTime); try { _sink->Foreach(startTime, interval(), [this](const CanonicalEntity& ce){ _head->Process(new CanonicalEntity(ce)); }); } catch (std::exception & ex) { trace.NOTE(std::string("Exception leaked: ") + ex.what()); } trace.NOTE("All lines processed"); _sink->Flush(); // Tell the sink to do its housekeeping _head->Done(); } std::ostream& operator<<(std::ostream& os, const Subscription& sub) { os << &sub << " (Event " << sub._target << ", interval " << sub._priority.Duration() << ")"; return os; } // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/Subscription.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #ifndef _SUBSCRIPTION_HH_ #define _SUBSCRIPTION_HH_ #include #include "Priority.hh" #include "MdsEntityName.hh" #include "CanonicalEntity.hh" #include "Pipeline.hh" #include "ITask.hh" class LocalSink; class Subscription : public ITask { friend std::ostream& operator<<(std::ostream& os, const Subscription& sub); public: //Subscription(const std::string &ev, bool, const MdsdConfig*, const std::string &acct, StoreType::Type, Priority); Subscription(LocalSink *sink, const MdsEntityName& target, Priority, const MdsTime& interval); Subscription(LocalSink *sink, MdsEntityName&& target, Priority, const MdsTime& interval); ~Subscription() { if (_head) delete _head; _head = nullptr; } void AddStage(PipeStage *stage); const MdsEntityName& target() const { return _target; } Priority priority() const { return _priority; } time_t Duration() const { return interval().to_time_t(); } protected: // Returns the time at which the first call should be made MdsTime initial_start(); // Invoked regularly to process data for the interval() seconds beginning at this time void execute(const MdsTime&); private: Subscription(); void common_constructor(); LocalSink *_sink; const MdsEntityName _target; const Priority _priority; // Ingest processing pipeline. When the subscription is deleted, the destructor must tear // down the pipeline. The teardown is recursive; delete the head, and it'll delete its // successor before finishing up. PipeStage *_head; PipeStage *_tail; }; std::ostream& operator<<(std::ostream& os, const Subscription& sub); #endif // _SUBSCRIPTION_HH_ // vim: set ai sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/TableColumn.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "TableColumn.hh" void TableColumn::AppendXmlSchemaElement(std::string& xmlbody) const { xmlbody += ""; } ================================================ FILE: Diagnostic/mdsd/mdsd/TableColumn.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _TABLECOLUMN_HH_ #define _TABLECOLUMN_HH_ #include "MdsValue.hh" class TableColumn { public: TableColumn(const std::string& n, const std::string& t, typeconverter_t& c) : _name(n), _mdstype(t), _converter(c) {} ~TableColumn() {} const std::string& Name() const { return _name; } const std::string& MdsType() const { return _mdstype; } /// Append to the body the MDS XML "schema" definition element for this column /// The XML body to which the generated element should be appended void AppendXmlSchemaElement(std::string& xmlbody) const; /// Convert a cJSON object to the configured MDS type /// The cJSON entity to be converted /// Pointer to a newly-allocated MdsValue. Returns 0 if the conversion failed. MdsValue* Convert(cJSON * in) const { return _converter(in); } private: TableColumn(); const std::string _name; const std::string _mdstype; const typeconverter_t _converter; }; #endif //_TABLECOLUMN_HH_ // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/TableSchema.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "TableSchema.hh" #include "TableColumn.hh" #include "Engine.hh" #include #include TableSchema::~TableSchema() { for (TableColumn * tblcol : _columns) { delete tblcol; } } TableSchema::ErrorCode TableSchema::AddColumn(const std::string& name, const std::string& srctype, const std::string& mdstype) { if (! _legal_types.count(srctype)) return BadSrcType; if (! 
_legal_mdstypes.count(mdstype)) return BadMdsType; for (TableColumn * tblcol : _columns) { if (tblcol->Name() == name) return DupeColumn; } typeconverter_t converter; if (! Engine::GetEngine()->GetConverter(srctype, mdstype, converter)) { return NoConverter; } auto newcolumn = new TableColumn(name, mdstype, converter); _columns.push_back(newcolumn); return Ok; } void TableSchema::PushColumnInfo(std::back_insert_iterator > > inserter) const { for (const auto & tblcol : _columns) { *(inserter++) = std::make_pair(tblcol->Name(), tblcol->MdsType()); } } std::set TableSchema::_legal_types = { "bool", "int", "str", "double", "int-timet", "double-timet", "str-rfc3339", "str-rfc3194" }; std::set TableSchema::_legal_mdstypes = { "mt:bool", "mt:wstr", "mt:float64", "mt:int32", "mt:int64", "mt:utc" }; // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/TableSchema.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _TABLESCHEMA_HH_ #define _TABLESCHEMA_HH_ #include "MdsValue.hh" #include "TableColumn.hh" #include #include #include #include #include class TableSchema { public: TableSchema(const std::string& n) : _name(n) {} ~TableSchema(); enum ErrorCode { Ok = 0, NoConverter = 1, DupeColumn = 2, BadSrcType = 3, BadMdsType = 4 }; /// Add a column to this schema. /// Name of the column /// The JSON type for the column, as the data arrives in an event /// The MDS type for the column in MDS ErrorCode AddColumn(const std::string& n, const std::string& srctype, const std::string& mdstype); // Act kinda like a container; allow iterators on the vector of TableColumn*. 
typedef std::vector::iterator iterator; typedef std::vector::const_iterator const_iterator; iterator begin() { return _columns.begin(); } const_iterator begin() const { return _columns.begin(); } iterator end() { return _columns.end(); } const_iterator end() const { return _columns.end(); } size_t Size() const { return _columns.size(); } /// Push pairs of [column name, column typename] into a vector void PushColumnInfo(std::back_insert_iterator > >) const; const std::string& Name() const { return _name; } private: TableSchema(); const std::string _name; std::vector _columns; static std::set _legal_types; static std::set _legal_mdstypes; }; #endif //_TABLESCHEMA_HH_ ================================================ FILE: Diagnostic/mdsd/mdsd/TermHandler.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "Logger.hh" #include extern "C" { void EmitStackTrace(int signo); } // Log uncaught exception before terminate the process. void TerminateHandler() { try { throw; } catch(const std::exception& e) { Logger::LogError("Error: mdsd is terminated with exception: " + std::string(e.what())); } catch(...) { Logger::LogError("Error: mdsd is terminated with unknown exception."); } EmitStackTrace(0); abort(); } ================================================ FILE: Diagnostic/mdsd/mdsd/Version.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "Version.hh" #include #define QUOTE(x) #x #define VAL(x) QUOTE(x) #define STATIC_VER VAL(MAJOR) "." VAL(MINOR) "." VAL(PATCH) "+" VAL(BUILD_NUMBER) namespace Version { const std::string Version(STATIC_VER); } // vim: se sw=8 ================================================ FILE: Diagnostic/mdsd/mdsd/Version.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT license.

#pragma once
#ifndef _VERSION_HH_
#define _VERSION_HH_

// Static version components; Version.cc combines them with BUILD_NUMBER.
#define MAJOR 1
#define MINOR 6
#define PATCH 100

// NOTE(review): #include target stripped by extraction (needs <string>).
#include

namespace Version {
	extern const std::string Version;
}

#endif //_VERSION_HH_

// vim: se sw=8

================================================
FILE: Diagnostic/mdsd/mdsd/XJsonBlobBlockCountsMgr.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "XJsonBlobBlockCountsMgr.hh"
#include "Utility.hh"
#include "Trace.hh"
// NOTE(review): seven #include targets stripped by extraction (cpprest/pplx
// stream headers, <sstream>, etc.) -- restore from upstream. Template
// arguments (e.g. pplx::task<size_t>, container_buffer<std::string>) were
// stripped throughout this file as well.
#include
#include
#include
#include
#include
#include
#include

// Meyers-singleton accessor.
XJsonBlobBlockCountsMgr& XJsonBlobBlockCountsMgr::GetInstance()
{
	static XJsonBlobBlockCountsMgr s_instance;
	return s_instance;
}

// Record where block counts are persisted. Throws std::invalid_argument on
// an empty path. When mdsdConfigValidationOnly is set, no file I/O will be
// allowed later.
void XJsonBlobBlockCountsMgr::SetPersistDir(const std::string& persistDir, bool mdsdConfigValidationOnly)
{
	Trace trace(Trace::JsonBlob, "XJsonBlobBlockCountsMgr::SetPersistDir");
	TRACEINFO(trace, "persistDir=\"" << persistDir << "\"");
	if (persistDir.empty()) {
		throw std::invalid_argument("persistDir can't be empty.");
	}
	m_persistDir = persistDir;
	m_mdsdConfigValidationOnly = mdsdConfigValidationOnly;
}

// Lazily create the persist directory (mode 01755). No-op if already created
// or in config-validation-only mode; throws if SetPersistDir wasn't called.
void XJsonBlobBlockCountsMgr::CreatePersistDirIfNotDone()
{
	Trace trace(Trace::JsonBlob, "XJsonBlobBlockCountsMgr::CreatePersistDirIfNotDone");
	if (m_persistDirCreated || m_mdsdConfigValidationOnly) {
		return;
	}
	if (m_persistDir.empty()) {
		throw std::runtime_error("Jsonblob block counts persist dir is not set.");
	}
	MdsdUtil::CreateDirIfNotExists(m_persistDir, 01755);
	m_persistDirCreated = true;
}

// Asynchronously read the persisted block count for containerName/blobName.
// Yields 0 when no count file exists, when the file is empty, or when the
// file records a different blob name (i.e. the persisted count is stale).
pplx::task XJsonBlobBlockCountsMgr::ReadBlockCountAsync(
	const std::string& containerName,
	const std::string& blobName) const
{
	Trace trace(Trace::JsonBlob, "XJsonBlobBlockCountsMgr::ReadBlockCountAsync");

	if (m_mdsdConfigValidationOnly) {
		throw std::runtime_error("XJsonBlobBlockCountsMgr::ReadBlockCountAsync: Can't be called when mdsd config validation only");
	}
	if (containerName.empty()) {
		throw std::invalid_argument("XJsonBlobBlockCountsMgr::ReadBlockCountAsync: containerName can't be empty.");
	}
	if (blobName.empty()) {
		throw std::invalid_argument("XJsonBlobBlockCountsMgr::ReadBlockCountAsync: blobName can't be empty.");
	}

	// One count file per container, named after the container.
	std::string file_path(m_persistDir);
	file_path.append("/").append(containerName);

	// If there's no block-count file, then the block count is just 0.
	if (!MdsdUtil::IsRegFileExists(file_path)) {
		return pplx::task_from_result((size_t)0);
	}

	return concurrency::streams::fstream::open_istream(file_path)
	.then([=](concurrency::streams::istream inFile) -> pplx::task
	{
		concurrency::streams::container_buffer streamBuffer;
		return inFile.read_to_end(streamBuffer)
		.then([=](size_t bytesRead) -> pplx::task
		{
			if (bytesRead == 0 && inFile.is_eof()) {
				// Invalid file format. Treat it silently as 0 block count.
				return pplx::task_from_result((size_t)0);
			}
			// File layout (see WriteBlockCountAsync): blob name, then count.
			std::istringstream iss(streamBuffer.collection());
			std::string blobNameInFile;
			iss >> blobNameInFile;
			if (blobNameInFile != blobName) {
				// Persisted block count is for the past, so the block count for the current blob should be 0.
				return pplx::task_from_result((size_t)0);
			}
			// NOTE(review): left uninitialized if this extraction fails
			// (truncated file) -- confirm upstream handles that case.
			size_t blockCountInFile;
			iss >> blockCountInFile;
			return pplx::task_from_result(blockCountInFile);
		})
		.then([=](size_t blockCount) -> pplx::task
		{
			// Close the stream, then re-surface the parsed count.
			return inFile.close()
			.then([=]() -> pplx::task
			{
				return pplx::task_from_result(blockCount);
			});
		});
	});
}

// Asynchronously persist (blobName, blockCount) for containerName.
// Validates arguments up front (a 0 count is never persisted), then writes
// atomically: write a ".tmp" sibling and rename() it into place.
pplx::task XJsonBlobBlockCountsMgr::WriteBlockCountAsync(
	const std::string& containerName,
	const std::string& blobName,
	const size_t blockCount) const
{
	Trace trace(Trace::JsonBlob, "XJsonBlobBlockCountsMgr::WriteBlockCountAsync");

	if (m_mdsdConfigValidationOnly) {
		throw std::runtime_error("XJsonBlobBlockCountsMgr::WriteBlockCountAsync: Can't be called when mdsd config validation only");
	}
	if (containerName.empty()) {
		throw std::invalid_argument("XJsonBlobBlockCountsMgr::WriteBlockCountAsync: containerName can't be empty.");
	}
	if (blobName.empty()) {
		throw std::invalid_argument("XJsonBlobBlockCountsMgr::WriteBlockCountAsync: blobName can't be empty.");
	}
	if (blockCount == 0) {
		throw std::invalid_argument("XJsonBlobBlockCountsMgr::WriteBlockCountAsync: 0 blockCount is not allowed.");
	}

	// m_persistDir + "/" + containerName is the full file path.
	// blobName and blockCount are the only content in the file.
	// First write to a tmp file path and then rename it to the correct path
	std::string file_path(m_persistDir);
	file_path.append("/").append(containerName);
	std::string file_path_tmp(file_path);
	file_path_tmp.append(".tmp");

	return concurrency::streams::fstream::open_ostream(file_path_tmp)
	.then([=](concurrency::streams::ostream outFile) -> pplx::task
	{
		std::string content(blobName);
		content.append("\n").append(std::to_string(blockCount)).append("\n");
		return outFile.print(content)
		.then([=](size_t) -> pplx::task
		{
			return outFile.close();
		});
	})
	.then([=]() -> pplx::task
	{
		if (-1 == rename(file_path_tmp.c_str(), file_path.c_str())) {
			auto errnum = errno;
			std::error_code ec(errnum, std::system_category());
			throw std::runtime_error(std::string("XJsonBlobBlockCountsMgr::WriteBlockCountAsync: "
				"rename(").append(file_path_tmp).append(", ").append(file_path).append(" failed. "
				"Reason: ").append(ec.message()));
		}
		return pplx::task_from_result();
	});
}

================================================
FILE: Diagnostic/mdsd/mdsd/XJsonBlobBlockCountsMgr.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef __XJSONBLOBBLOCKCOUNTSMGR_HH__
#define __XJSONBLOBBLOCKCOUNTSMGR_HH__

// NOTE(review): two #include targets stripped by extraction (<string> and a
// pplx header) -- confirm upstream.
#include
#include

// Singleton pattern
// Persists per-container "how many blocks does this blob already have" state
// under m_persistDir so block lists survive an mdsd restart.
class XJsonBlobBlockCountsMgr
{
public:
	static XJsonBlobBlockCountsMgr& GetInstance();

	XJsonBlobBlockCountsMgr(const XJsonBlobBlockCountsMgr&) = delete;
	XJsonBlobBlockCountsMgr(XJsonBlobBlockCountsMgr&&) = delete;
	XJsonBlobBlockCountsMgr& operator=(const XJsonBlobBlockCountsMgr&) = delete;
	XJsonBlobBlockCountsMgr& operator=(XJsonBlobBlockCountsMgr&&) = delete;

	// Called from main() after mdsd_prefix is determined.
	void SetPersistDir(const std::string& persistDir, bool mdsdConfigValidationOnly);

	// Called from XJsonBlobSink::XJsonBlobSink()
	void CreatePersistDirIfNotDone();

	// Async persisted-count accessors; see the .cc for semantics.
	// (NOTE(review): pplx::task template arguments stripped by extraction.)
	pplx::task ReadBlockCountAsync(const std::string& containerName, const std::string& blobName) const;
	pplx::task WriteBlockCountAsync(const std::string& containerName, const std::string& blobName, const size_t blockCount) const;

private:
	XJsonBlobBlockCountsMgr() : m_persistDirCreated(false), m_mdsdConfigValidationOnly(false) {}
	~XJsonBlobBlockCountsMgr() {}

	bool m_persistDirCreated;
	bool m_mdsdConfigValidationOnly;
	std::string m_persistDir;	// e.g., "/var/run/mdsd/default_jsonblob_block_counts"
};

#endif // __XJSONBLOBBLOCKCOUNTSMGR_HH__

================================================
FILE: Diagnostic/mdsd/mdsd/XJsonBlobRequest.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "XJsonBlobRequest.hh"
#include "XJsonBlobBlockCountsMgr.hh"
// NOTE(review): the bare #include lines in this file lost their targets to
// text extraction (azure-storage-cpp, cpprest, standard headers); template
// arguments on pplx::task, std::shared_ptr, std::vector etc. were stripped
// file-wide as well -- restore from upstream before building.
#include
#include
#include "MdsTime.hh"
#include "Constants.hh"
#include "Crypto.hh"
#include "Logger.hh"
#include "Trace.hh"
#include "Utility.hh"
#include "AzureUtility.hh"
#include "Version.hh"
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

// Build a request for one upload to one JsonBlob blob. Validates all string
// params and composes the blob name from the partition fields, the base time,
// and the interval duration.
XJsonBlobRequest::XJsonBlobRequest(
	const XJsonBlobSink::RequestInfo& info,
	const MdsTime& blobBaseTime,
	const std::string& blobIntervalISO8601Duration,
	const std::string& containerName,
	const std::string& reqId,
	const std::shared_ptr& blocklist)
	: _info(info), _blobBaseTime(blobBaseTime), _containerName(containerName),
	  _requestId(reqId), _totalDataBytes(0), _blockList(blocklist)
{
	Trace trace(Trace::JsonBlob, "XJBR::XJBR, reqId=" + _requestId);

	if (blobIntervalISO8601Duration.empty()) {
		throw std::invalid_argument("Empty string param (blobIntervalISO8601Duration)");
	}
	if (containerName.empty()) {
		throw std::invalid_argument("Empty string param (containerName)");
	}
	if (reqId.empty()) {
		throw std::invalid_argument("Empty string param (reqId)");
	}
	if (!blocklist) {
		throw std::invalid_argument("Null blocklist");
	}

	// Blob name example: resourceId=test_resource_id/i=agentIdentityHash/y=2015/m=05/d=03/h=00/m=00/name=PT1H.json
	std::stringstream blobnamestr;
	if (!_info.primaryPartitionField.empty()) {
		blobnamestr << _info.primaryPartitionField << '/';
	}
	if (!_info.agentIdentityHash.empty()) {
		blobnamestr << _info.agentIdentityHash << '/';
	}
	blobnamestr << blobBaseTime.to_strftime("y=%Y/m=%m/d=%d/h=%H/m=%M/");
	if (!_info.partitionFields.empty()) {
		blobnamestr << _info.partitionFields << '/';
	}
	blobnamestr << blobIntervalISO8601Duration <<".json";
	_blobName = blobnamestr.str();
	TRACEINFO(trace, "Preliminary blobname " << _blobName);
}

XJsonBlobRequest::~XJsonBlobRequest()
{
	// Just to see that this fire-and-forget object is really destructed.
	Trace trace(Trace::JsonBlob, "XJBR::~XJBR, reqId=" + _requestId);
}

// Separator placed between JSON rows inside a block.
static const std::string jsonRowSeparator(",\n");

// Buffer one JSON-encoded row for the next upload; tracks the total byte
// count (including separators) so the sink can cap request size. Empty rows
// are ignored.
void XJsonBlobRequest::AddJsonRow(std::string&& jsonRow)
{
	Trace trace(Trace::JsonBlob, "XJBR::AddJsonRow");

	if (jsonRow.empty()) {
		TRACEINFO(trace, "Empty jsonRow string passed. Nothing to do. Return");
		return;
	}
	if (!_dataset.empty()) {
		_totalDataBytes += jsonRowSeparator.length();
	}
	_totalDataBytes += jsonRow.size();
	_dataset.emplace_back(std::move(jsonRow));

	TRACEINFO(trace, "# rows in dataset = " << _dataset.size() << ", total data bytes = " << _totalDataBytes);
}

// Flatten an azure-storage exception (message + extended error + HTTP status)
// into a single loggable string.
static std::string GetStorageExceptionDetails(const azure::storage::storage_exception& e)
{
	std::ostringstream oss;
	oss << "Storage exception: " << e.what();
	azure::storage::request_result result = e.result();
	azure::storage::storage_extended_error err = result.extended_error();
	if (!err.message().empty()) {
		oss << ", Extended info: " << err.message();
	}
	oss << ", HTTP status code: " << std::to_string(result.http_status_code());
	return oss.str();
}

// Exception carrying the name of the async task that failed, so the final
// continuation in Send() can log which stage broke.
class XJBRAsyncTaskError : public std::runtime_error
{
public:
	XJBRAsyncTaskError(const std::string& taskName, const std::string& message)
		: std::runtime_error(message), _taskName(taskName) {}
	std::string GetTaskName() const { return _taskName; }
private:
	std::string _taskName;
};

// Used to synchronize access to a BlockList. A std::shared_ptr should stay alive
// while exclusive access to the BlockList is needed across tasks/threads. The wrapping
// std::shared_ptr will provide copy counting, and will only deconstruct
// the BlockListOwner when the last instance of the wrapping std::shared_ptr is
// deconstructed.
struct BlockListOwner
{
	std::shared_ptr _blockList;
	const std::string _ownerName;
	const std::string _requestId;

	// Blocks until no other request owns the list, then claims it.
	BlockListOwner(std::shared_ptr blockList, const std::string& ownerName, const std::string& requestId)
		: _blockList(blockList), _ownerName(ownerName), _requestId(requestId)
	{
		Trace trace(Trace::JsonBlob, "BlockListOwner::BlockListOwner");
		TRACEINFO(trace, "Attempting to set block list owner for " << _requestId << " to " << _ownerName);
		_blockList->LockIfOwnedByNoneThenSetOwner(_ownerName);
		TRACEINFO(trace, "Set block list owner for " << _requestId << " to " << _ownerName);
	}

	// Releases ownership and wakes any waiter.
	~BlockListOwner()
	{
		Trace trace(Trace::JsonBlob, "BlockListOwner::~BlockListOwner");
		TRACEINFO(trace, "Resetting block list owner for " << _requestId << " (currently " << _ownerName << ")");
		_blockList->ResetOwnerAndNotify();
	}
};

// Fire-and-forget upload of the buffered dataset: upload the new block (plus
// first/last framing blocks if the blob is new), commit the block list, and
// persist the resulting block count. All failures are logged, never thrown.
/*static*/ void XJsonBlobRequest::Send(
	std::shared_ptr req,
	const std::string& connString)
{
	Trace trace(Trace::JsonBlob, "XJBR::Send id=" + req->_requestId);

	// NOTE(review): req is already dereferenced in the Trace line above, so
	// this null check comes too late to protect anything -- flag upstream.
	if (!req) {
		Logger::LogWarn("XJBR::Send(): Null request was passed. This shouldn't happen. Returning anyway...");
		return;
	}
	if (req->_dataset.empty()) {
		Logger::LogWarn("Nothing to upload to the XJsonBlob blob " + req->_blobName + ". Returning...");
		return;
	}

	try {
		TRACEINFO(trace, "Get reference to container/blob " << req->_containerName << "/" << req->_blobName);
		auto cloudStorageAccount = azure::storage::cloud_storage_account::parse(connString);

		// The endpoint URL and storage account are not really needed, but just for informational purpose...
		auto endpointURL = cloudStorageAccount.blob_endpoint().primary_uri().to_string();
		std::string storageAccountName = MdsdUtil::GetStorageAccountNameFromEndpointURL(endpointURL);
		TRACEINFO(trace, "Storage endpoint URL: " << endpointURL << ", extracted storage account name: "
			<< storageAccountName << ", requestId: " << req->_requestId);

		req->_blobRef = cloudStorageAccount
			.create_cloud_blob_client()
			.get_container_reference(req->_containerName)
			.get_block_blob_reference(req->_blobName);

		// Start only when the mutex is not owned by any other request.
		// Owner name really doesn't matter as long as it's non-empty.
		// requestId is only for logging.
		auto blockListOwner = std::make_shared(req->_blockList, req->_blobName, req->_requestId);

		XJsonBlobRequest::UploadNewBlockAsync(req)
		.then([req]() -> pplx::task
		{
			// This is a value-based continuation, so if the previous task throws,
			// this task is not executed, so no need to do wait on prev_task.
			return XJsonBlobRequest::UploadBlockListAsync(req);
		})
		.then([req]() -> pplx::task
		{
			// Another value-based continuation
			return XJsonBlobBlockCountsMgr::GetInstance().WriteBlockCountAsync(req->_containerName,
				req->_blobName, req->_blockList->get().size());
		})
		// Copy capture the BlockListOwner so that it stays alive through this
		// continuation task.
		.then([req, blockListOwner](pplx::task prev_task)
		{
			// This is a task-based continuation, so this task will be executed
			// even if any previous task throws.
			Trace trace(Trace::JsonBlob, "XJBR::Send final continuation task, req id=" + req->_requestId);
			try {
				// Wait, to handle prev async task exceptions right away
				prev_task.wait();
				// There were no exceptions if we reached this point.
				if (trace.IsActive()) {
					TRACEINFO(trace, "Added new block to blob [" << req->_blobName << "]. Now there are "
						<< req->_blockList->get().size() << " blocks in the blob.");
				}
			}
			catch (const XJBRAsyncTaskError& e) {
				Logger::LogError(e.GetTaskName().append(": ").append(e.what()));
			}
			catch (const std::exception& e) {
				Logger::LogError(std::string("[XJBR::UploadBlockListCompletion]: ").append(e.what()));
			}
			catch (...) {
				// Don't leak any exception from this async function body
				Logger::LogError("[XJBR::UploadBlockListCompletion]: Unknown exception");
			}
		});
	}
	catch (const azure::storage::storage_exception& e) {
		Logger::LogError("Storage exception generated while starting async blob write: " + GetStorageExceptionDetails(e));
	}
	catch (const std::exception& e) {
		Logger::LogError(std::string("Exception generated while starting async blob write: ").append(e.what()));
	}
	catch (...) {
		Logger::LogError("Unknown exception generated while starting async blob write");
	}
}

// MD5 of 'content', base64-encoded, for block upload integrity checks.
static std::string GetBase64HashString(const std::string& content)
{
	azure::storage::core::hash_provider provider = azure::storage::core::hash_provider::create_md5_hash_provider();
	provider.write((const unsigned char*)content.c_str(), content.length());
	provider.close();
	return provider.hash();
}

// Blob framing: block 0 opens the JSON records array, block 49999 closes it;
// content blocks occupy the numeric IDs in between.
static constexpr size_t maxBlocksInBlob = 50000;
static const std::string first_block_id = utility::conversions::to_base64(0);
static const std::string first_block_content = "{\"records\":[\n";
static const std::string last_block_id = utility::conversions::to_base64(maxBlocksInBlob - 1); // 49999
static const std::string last_block_content = "\n]}";

// Upload the new content block (and the first/last framing blocks when the
// blob is brand new), in parallel. Throws XJBRAsyncTaskError on a full blob
// or on block-list format inconsistencies.
/*static*/ pplx::task XJsonBlobRequest::UploadNewBlockAsync(
	const std::shared_ptr& req)
{
	Trace trace(Trace::JsonBlob, "XJBR::UploadNewBlock id=" + req->_requestId);

	// NOTE(review): as in Send(), req was dereferenced just above this check.
	if (!req) {
		throw std::invalid_argument("Null shared_ptr");
	}

	// Handy references
	auto& blobRef = req->_blobRef;
	auto& blockList = req->_blockList->get();
	auto& blobName = req->_blobName;
	auto& newBlockId = req->_newBlockId;
	auto& newBlockContent = req->_newBlockContent;

	if (blockList.size() >= maxBlocksInBlob) {
		std::ostringstream ss;
		ss << "Can't add any more block to blob " << blobName
		   << ". There are already max blobs (" << blockList.size() << ") in the blob.";
		throw XJBRAsyncTaskError("XJBR::UploadNewBlockAsync", ss.str());
	}
	if (!blockList.empty() && blockList.size() < 2) {
		throw XJBRAsyncTaskError("XJBR::UploadNewBlockAsync",
			"Blob format error: No first/last blocks in " + blobName + ". Returning...");
	}
	if (!blockList.empty()
	    && (blockList.front().id() != first_block_id || blockList.back().id() != last_block_id)) {
		throw XJBRAsyncTaskError("XJBR::UploadNewBlockAsync",
			"Blob format error: First block id (" + blockList.front().id()
			+ ") or last block id (" + blockList.back().id()
			+ ") is incorrect in " + blobName + ". Returning.");
	}

	std::vector> blockUploadTasks;	// maximum 3 uploads

	if (blockList.empty()) {
		TRACEINFO(trace, "Blob " << blobName << " is empty. Adding first/last blocks.");
		auto first_block_stream = concurrency::streams::bytestream::open_istream(first_block_content);
		auto taskUploadFirstBlock = blobRef.upload_block_async(first_block_id, first_block_stream, GetBase64HashString(first_block_content));
		blockUploadTasks.push_back(taskUploadFirstBlock);

		auto last_block_stream = concurrency::streams::bytestream::open_istream(last_block_content);
		auto taskUploadLastBlock = blobRef.upload_block_async(last_block_id, last_block_stream, GetBase64HashString(last_block_content));
		blockUploadTasks.push_back(taskUploadLastBlock);
	}

	// Add the new block. New block's id # is blockList.size() - 2 (first/last) + 1 (new block).
	// Above is correct only for non-empty block list. Empty block list case needs to be handled
	// as a special case, as we don't want to update the block list until blocks are really uploaded.
	size_t newBlockNum = blockList.empty() ? 1 : (blockList.size() - 1);
	newBlockId = utility::conversions::to_base64(newBlockNum);
	TRACEINFO(trace, "Adding a new block (numeric ID=" << newBlockNum << ", base64 ID=" << newBlockId << ") to blob " << blobName);

	// Construct the new block content.
	newBlockContent.reserve(req->_totalDataBytes + jsonRowSeparator.length()); // + 2 for possible preceding ",\n"
	if (blockList.size() > 2) {
		// Not the first content block, so prepend ","
		newBlockContent.append(jsonRowSeparator);
	}
	bool first = true;
	for (const auto& row : req->_dataset) {
		if (first) {
			first = false;
		} else {
			newBlockContent.append(jsonRowSeparator);
		}
		newBlockContent.append(row);
	}

	auto new_block_stream = concurrency::streams::bytestream::open_istream(newBlockContent);
	auto taskUploadContentBlock = blobRef.upload_block_async(newBlockId, new_block_stream, GetBase64HashString(newBlockContent));
	blockUploadTasks.push_back(taskUploadContentBlock);

	return pplx::when_all(blockUploadTasks.begin(), blockUploadTasks.end());
}

// Commit the (now-extended) block list to the blob. The in-memory list is
// updated only after the blocks themselves were uploaded successfully.
/*static*/ pplx::task XJsonBlobRequest::UploadBlockListAsync(const std::shared_ptr& req)
{
	Trace trace(Trace::JsonBlob, "XJBR::UploadBlockListAsync, req id=" + req->_requestId);

	// handy references
	auto& request = *req;
	auto& blockList = request._blockList->get();

	// Update block list only after block(s) is/are uploaded successfully
	if (blockList.empty()) {
		blockList.emplace_back(azure::storage::block_list_item(first_block_id));
		blockList.emplace_back(azure::storage::block_list_item(last_block_id));
	}
	blockList.insert(blockList.end() - 1, azure::storage::block_list_item(request._newBlockId));

	// Finally upload the block list!
	return request._blobRef.upload_block_list_async(blockList);
}

// After a restart the in-memory block list is empty; rebuild it from the
// persisted block count (first + N content blocks + last). Invalid or zero
// counts are logged and ignored, leaving the list empty.
/*static*/ void XJsonBlobRequest::ReconstructBlockListIfNeeded(std::shared_ptr req)
{
	Trace trace(Trace::JsonBlob, "XJBR::ReconstructBlockListIfNeeded");

	if (!req->_blockList->get().empty()) {
		throw std::runtime_error("XJBR::ReconstructBlockListIfNeeded: Block list is not empty.");
	}

	// Start only when the mutex is not owned by any other request.
	// Owner name really doesn't matter as long as it's non-empty.
	auto blockListOwner = std::make_shared(req->_blockList, req->_blobName, req->_requestId);

	// Copy capture the BlockListOwner so that it stays alive through the
	// continuation task.
	XJsonBlobBlockCountsMgr::GetInstance().ReadBlockCountAsync(req->_containerName, req->_blobName)
	.then([req, blockListOwner](pplx::task prev_task)
	{
		Trace trace(Trace::JsonBlob, "XJBR::ReconstructBlockListIfNeeded continuation");
		TRACEINFO(trace, "In XJBR::ReconstructBlockListIfNeeded continuation.");
		try {
			auto blockCount = prev_task.get();
			TRACEINFO(trace, "Obtained blockCount=" << blockCount << " for container=" << req->_containerName << " and blob=" << req->_blobName);
			if (blockCount == 0) {
				return;
			}
			if (blockCount < 3	// A persisted block count is always at least 3 blocks.
						// "{ ..." for first block, "}" for last block, at least one content block.
			    || blockCount > maxBlocksInBlob) {
				Logger::LogError(std::string("Invalid block count (").append(std::to_string(blockCount))
					.append(") returned from XJBBlockCountsMgr::ReadBlockCount. "
					"Valid block count is at least 3 and at most ").append(std::to_string(maxBlocksInBlob))
					.append(". Block list won't be reconstructed."));
				return;
			}
			// Finally we can reconstruct the block list.
			auto& blockList = req->_blockList->get();
			blockList.emplace_back(azure::storage::block_list_item(first_block_id));
			size_t lastBlockNum = blockCount - 2;
			for (size_t blockNum = 1; blockNum <= lastBlockNum; blockNum++) {
				blockList.emplace_back(azure::storage::block_list_item(utility::conversions::to_base64(blockNum)));
			}
			blockList.emplace_back(azure::storage::block_list_item(last_block_id));
		}
		catch (std::exception& e) {
			Logger::LogError(std::string("Exception thrown from XJBBlockCountsMgr::ReadBlockCount. "
				"Block list can't be reconstructed. Exception message: ").append(e.what()));
		}
		catch (...) {
			Logger::LogError("Unknown exception thrown from XJBBlockCountsMgr::ReadBlockCount. "
				"Block list can't be reconstructed");
		}
	});
	TRACEINFO(trace, "After ReadBlockCountAsync.");
}

// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/XJsonBlobRequest.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _XJSONBLOBREQUEST_HH_
#define _XJSONBLOBREQUEST_HH_

#include "XJsonBlobSink.hh"
#include "MdsTime.hh"
// NOTE(review): seven #include targets stripped by extraction.
#include
#include
#include
#include
#include
#include
#include

// One upload request against a JsonBlob block blob; see XJsonBlobRequest.cc.
class XJsonBlobRequest
{
public:
	XJsonBlobRequest(
		const XJsonBlobSink::RequestInfo& info,
		const MdsTime& blobBaseTime,
		const std::string& blobIntervalISO8601Duration,
		const std::string& containerName,
		const std::string& reqId,
		const std::shared_ptr& blocklist);

	~XJsonBlobRequest();

	static void Send(
		std::shared_ptr req,
		const std::string & connString);

	const std::string & UUID() const { return _requestId; }
	size_t EstimatedSize() const { return _totalDataBytes; }

	void AddJsonRow(std::string&& jsonRow);

	static void ReconstructBlockListIfNeeded(std::shared_ptr req);

private:
	static pplx::task UploadNewBlockAsync(const std::shared_ptr& req);
	static pplx::task UploadBlockListAsync(const std::shared_ptr& req);

	XJsonBlobSink::RequestInfo _info;
	std::string _containerName;
	std::string _blobName;
	MdsTime _blobBaseTime;		// Base time for the current blob

	// As we add a new _rowbuf to the collection, we accumulate its size so we know when we hit
	// maximum length.
	size_t _totalDataBytes;
	std::vector _dataset;

	// UUID for this request; attached to storage request(s) end-to-end
	std::string _requestId;

	// Async request handling members
	azure::storage::cloud_block_blob _blobRef;
	std::shared_ptr _blockList;
	std::string _newBlockId;
	std::string _newBlockContent;
};

#endif // _XJSONBLOBREQUEST_HH_

// vim: se ai sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/XJsonBlobSink.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "XJsonBlobSink.hh"
#include "XJsonBlobRequest.hh"
#include "XJsonBlobBlockCountsMgr.hh"
#include
#include
#include "CanonicalEntity.hh"
#include "MdsdConfig.hh"
#include "Credentials.hh"
#include "Utility.hh"
#include "AzureUtility.hh"
#include "RowIndex.hh"
#include "Trace.hh"
#include "Logger.hh"
#include "MdsdMetrics.hh"
#include "StoreType.hh"
#include "Constants.hh"
#include "MdsTime.hh"
#include "CfgOboDirectConfig.hh"
#include

// Construct the sink from the loaded mdsd config, the target entity, and
// credentials. Chooses OboDirect vs. LAD initialization based on whether the
// config contains an OboDirect entry for the target event name.
// Throws std::invalid_argument on null config/credentials.
XJsonBlobSink::XJsonBlobSink(MdsdConfig* config, const MdsEntityName &target, const Credentials* c)
    : IMdsSink(StoreType::Type::XJsonBlob), _template(target), _creds(c)
    , _namespace(config->Namespace()), _firstReqCreated(false)
    , _blobBaseTime(0) // Make sure to set _blobBaseTime to a past long time ago when constructing
{
    Trace trace(Trace::JsonBlob, "XJBS::Constructor");

    // NOTE(review): config is dereferenced in the initializer list above
    // (config->Namespace()) before this null check can run.
    if (!config) {
        throw std::invalid_argument("Null MdsdConfig* config");
    }
    if (!c) {
        throw std::invalid_argument("Null Credentials* c");
    }

    auto eventName = target.EventName();
    try {
        // May throw std::out_of_range if eventName is not a key stored in the map.
        auto oboDirectConfig = config->GetOboDirectConfig(eventName);
        InitializeForOboDirect(config, oboDirectConfig);
    }
    catch (const std::out_of_range& e) {
        // No OboDirect config. It's LAD JsonBlob sink scenario.
        InitializeForLadWithoutOboDirect(config);
    }

    // Finally, fill in duration/tenant/role/roleInstance (for metric Json content)
    _template.duration = config->GetDurationForEventName(eventName);
    config->GetIdentityValues(_template.tenant, _template.role, _template.roleInstance);

    // Ensure the directory used to persist per-blob block counts exists.
    XJsonBlobBlockCountsMgr::GetInstance().CreatePersistDirIfNotDone();
}

// Append one "name=value" component to a blob path string, separated by '/'.
// fieldName is looked up in the config's OboDirectPartitionField values;
// fieldNameInBlobPath is the (possibly abbreviated) name emitted in the path.
// Throws if the field name is empty or no value is configured for it.
static void AppendBlobPathComponent(
    const std::string& fieldName,
    const std::string& fieldNameInBlobPath,
    MdsdConfig* config,
    std::string& blobPathComponentString)
{
    if (fieldName.empty()) {
        throw std::invalid_argument("AppendBlobPathComponent(): fieldName cannot be empty");
    }
    auto fieldValue = config->GetOboDirectPartitionFieldValue(fieldName);
    if (fieldValue.empty()) {
        std::string msg = "No CentralJson blob path field value found for field name " + fieldName
            + ". Make sure that your mdsd config XML contains "
              "OboDirectPartitionField element with the corresponding field name "
              "attribute in Management/Identity section.";
        Logger::LogError(msg);
        throw std::runtime_error(msg);
    }
    // '/' separates components; omitted before the first one.
    if (!blobPathComponentString.empty()) {
        blobPathComponentString.append("/");
    }
    blobPathComponentString.append(fieldNameInBlobPath).append("=").append(fieldValue);
}

// Initialize interval and partition-field path components from an OboDirect
// config entry. Falls back to a 1-hour interval if the configured ISO8601
// duration doesn't parse.
void XJsonBlobSink::InitializeForOboDirect(MdsdConfig* config, const std::shared_ptr& oboDirectConfig)
{
    _blobIntervalISO8601Duration = oboDirectConfig->timePeriods;
    _blobIntervalSec = MdsTime::FromIS8601Duration(_blobIntervalISO8601Duration).to_time_t();
    if (0 == _blobIntervalSec) {
        //Logger::LogError("Invalid ISO8601 duration (" + blobIntervalISO8601Duration + ") given. This shouldn't happen. Default 'PT1H' will be used.");
        _blobIntervalSec = 60*60; // 1 hour
        _blobIntervalISO8601Duration = "PT1H";
    }

    const auto& primaryPartitionFieldName = oboDirectConfig->primaryPartitionField; // handy reference
    if (!primaryPartitionFieldName.empty()) {
        // Compose primaryPartitionField (e.g., "name1=xxx")
        AppendBlobPathComponent(primaryPartitionFieldName, primaryPartitionFieldName, config, _template.primaryPartitionField);
    }

    if (!oboDirectConfig->partitionFields.empty()) {
        // Compose partitionFields (e.g., "name1=xxx/name2=yyy").
        // oboDirectConfig.partitionFields is a comma-separated list, e.g., 'name1,name2'.
        std::istringstream iss(oboDirectConfig->partitionFields);
        while (iss.good()) {
            std::string partitionFieldName;
            getline(iss, partitionFieldName, ',');
            if (!partitionFieldName.empty()) {
                AppendBlobPathComponent(partitionFieldName, partitionFieldName, config, _template.partitionFields);
            }
        }
    }
}

// LAD (non-OboDirect) initialization: fixed 1-hour blob interval and a blob
// path built from the configured resourceId and agent identity hash.
void XJsonBlobSink::InitializeForLadWithoutOboDirect(MdsdConfig* config)
{
    // LAD JsonBlob's interval is fixed to 1 hour.
    _blobIntervalSec = 60*60; // 1 hour
    _blobIntervalISO8601Duration = "PT1H";
    AppendBlobPathComponent("resourceId", "resourceId", config, _template.primaryPartitionField);
    // 'i' is the abbreviated path name for the agent identity hash.
    AppendBlobPathComponent("agentIdentityHash", "i", config, _template.agentIdentityHash);
}

// Compute and cache the storage connection string for the target's blob
// service. Logs an error (but does not throw) on failure.
void XJsonBlobSink::ComputeConnString()
{
    Trace trace(Trace::JsonBlob, "XJBS::ComputeConnString");

    const MdsEntityName& Target = _template.target; // Easy to use reference

    // This is pretty easy for XJsonBlob; we currently support shared-key creds only.
    // expires & eventName don't apply to XJsonBlob (at least yet), so just dummy vars passed.
    MdsTime expires;
    std::string eventName;
    if (_creds->ConnectionString(Target, Credentials::ServiceType::Blob, eventName, _connString, expires) ) {
        TRACEINFO(trace, Target << "=[" << _connString << "] expires " << expires << "(N/A for XJsonBlob)");
    }
    else {
        Logger::LogError("Error: Couldn't construct connection string for XJsonBlob eventName " + Target.Basename());
    }
}

// The only credentials that need to be validated are "Shared key" or an account SAS; if we have a service SAS or Autokey,
// we'll find out if they work when we try to use them. We can validate shared key credentials
// by creating the container for the eventName, if it doesn't already exist. Since this gets
// called only during config load, it's reasonable to perform the operation synchronously.
void XJsonBlobSink::ValidateAccess()
{
    Trace trace(Trace::JsonBlob, "XJBS::ValidateAccess");
    // NOTE(review): the dynamic_cast target type was stripped in extraction;
    // presumably a SAS-credentials subclass of Credentials — confirm.
    auto sasCreds = dynamic_cast(_creds);
    if (_creds->Type() == Credentials::SecretType::Key || (sasCreds && sasCreds->IsAccountSas())) {
        ComputeConnString(); // Force computation, since this is called at config time
        // "Container name will be the concatenation of namespace, event name, and event version if present."
        // "For example: obodirectnamespacetestevent1ver2v0"
        // from https://microsoft.sharepoint.com/teams/SPS-AzMon/Shared Documents/Design Documents/Direct Mode Design.docx?web=1
        _containerName = MdsdUtil::to_lower(_namespace + _template.target.Basename()); // Azure Storage allows lowercase only in container name
        MdsdUtil::CreateContainer(_connString, _containerName);
    }
}

XJsonBlobSink::~XJsonBlobSink()
{
    Trace trace(Trace::JsonBlob, "XJBS::Destructor");
}

// Convert the CanonicalEntity to Json and add it to the accumulated buffer. Flush it
// if it fills up.
//
// Note that AddRow() doesn't keep the CanonicalEntity; we copy anything we need from it.
void XJsonBlobSink::AddRow(const CanonicalEntity &row, const MdsTime& qibase)
{
    Trace trace(Trace::JsonBlob, "XJBS::AddRow");
    TRACEINFO(trace, "containerName = " << _containerName << ", blob basetime = " << _blobBaseTime
        << ", blob interval (sec) = " << _blobIntervalSec << ", qibase = " << qibase);

    // If the query interval is beyond blob base time + blob interval,
    // we should flush the current block and reset the base time accordingly.
    if (qibase >= _blobBaseTime + _blobIntervalSec) {
        Flush();
        _blobBaseTime = qibase.Round(_blobIntervalSec); // Make sure to round down to the specified blob interval
        // New blob => the old block list no longer applies.
        _blockList.reset();
        TRACEINFO(trace, "New blob basetime = " << _blobBaseTime);
    }

    // If we have no in-progress request, either because we just flushed or because we're just
    // starting up, make one.
    if (!_request) {
        try {
            std::string requestId = utility::uuid_to_string(utility::new_uuid());
            if (!_blockList) {
                _blockList = std::make_shared();
            }
            _request.reset(new XJsonBlobRequest(_template, _blobBaseTime, _blobIntervalISO8601Duration,
                _containerName, requestId, _blockList));
            // This is the only place we create any XJBReq, so we must check if this is the first time
            // to see if we need to try to reconstruct the block list from a persisted block count file.
            // If there are other places where XJBReq is created, this must be done there as well...
            if (!_firstReqCreated) {
                XJsonBlobRequest::ReconstructBlockListIfNeeded(_request);
                _firstReqCreated = true;
            }
        }
        catch (std::exception & ex) {
            // Request creation failed; the row is dropped (counted in metrics).
            std::ostringstream msg;
            msg << "Exception (" << ex.what() << ") caught while creating new XJsonBlobRequest; dropping row";
            trace.NOTE(msg.str());
            Logger::LogError(msg.str());
            MdsdMetrics::Count("Dropped_Entities");
            return;
        }
    }

    // The XJsonBlobRequest object stores generated json rows that
    // correspond to the generated rows. The object also contains the metadata needed to
    // determine the name of the blob when it gets written. (This includes a sequence number;
    // if the blob fills, we flush it and start accumulating a new one with an
    // incremented sequence.)
    TRACEINFO(trace, "Adding row to request ID " << _request->UUID() << ": " << row);

    std::string jsonRow;
    try {
        jsonRow = row.GetJsonRow(_template.duration, _template.tenant, _template.role, _template.roleInstance);
    }
    catch (std::exception& e) {
        // Serialization failed for this row; log and skip it.
        Logger::LogError(e.what());
        return;
    }
    _request->AddJsonRow(std::move(jsonRow));

    TRACEINFO(trace, "Block now contains " << _request->EstimatedSize() << " bytes");

    // If the size of the accumulated data is "close" to the maximum size of a JSON blob block,
    // flush the block and prepare for the next one
    if (_request->EstimatedSize() > _targetBlockSize) {
        TRACEINFO(trace, "Size of accumulated rows is larger than block size limit; flushing");
        Flush();
    }
}

// Flush any data we're holding. We might never have allocated a request, or it might
// be empty, or we might have data.
// Post-condition: _request is nullptr. Next call to AddRow() will create a new request on demand.
void XJsonBlobSink::Flush()
{
    Trace trace(Trace::JsonBlob, "XJBS::Flush");
    TRACEINFO(trace, "Begin XJBS::Flush on containerName = " << _containerName);

    if (nullptr == _request) {
        // First time through. Just make the post-condition true
        TRACEINFO(trace, "Null _request; no action.");
        return;
    }

    // XJsonBlob must flush if there's any data. Otherwise, just return.
    if (_request->EstimatedSize() == 0) {
        TRACEINFO(trace, "No data to flush; no action.");
        return;
    }

    TRACEINFO(trace, "Flush() request ID " + _request->UUID());

    // NOTE(review): after the early return above, EstimatedSize() > 0 always
    // holds here, so the else branch below is unreachable dead code.
    if (_request->EstimatedSize() > 0) {
        // Detach the request and send it. Send() is fire-and-forget; the request object
        // is responsible for deleting itself after that point.
        try {
            XJsonBlobRequest::Send(std::move(_request), _connString);
        }
        catch (std::exception & ex) {
            trace.NOTE(std::string("Exception leaked from XJBR Send: ") + ex.what());
        }
    }
    else {
        // Since we create these on demand, this really shouldn't happen.
        TRACEINFO(trace, "Empty _request; no action (deleting).");
        _request.reset();
    }
}

// vim: se sw=4 expandtab ts=4 :

================================================ FILE: Diagnostic/mdsd/mdsd/XJsonBlobSink.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#ifndef _XJSONBLOBSINK_HH_
#define _XJSONBLOBSINK_HH_

// NOTE(review): bare #include targets were stripped during extraction.
#include "IMdsSink.hh"
#include
#include "MdsTime.hh"
#include "MdsEntityName.hh"
#include
#include
#include

class CanonicalEntity;
class Credentials;
class MdsdConfig;
class MdsValue;
class XJsonBlobRequest;
namespace mdsd { struct OboDirectConfig; }
namespace azure { namespace storage { class block_list_item; }}

// Thin object wrapper, supporting synchronization across async task threads (with owner name).
// A caller blocks in LockIfOwnedByNoneThenSetOwner() until no other owner is
// set, then becomes the owner; ResetOwnerAndNotify() releases ownership and
// wakes waiters.
// NOTE(review): template parameter list stripped in extraction (presumably
// template<typename T>).
template
class ObjectWithOwnership
{
public:
    ObjectWithOwnership() {}

    // Unsynchronized access to the wrapped object; callers are expected to
    // hold ownership first.
    T& get() { return _object; }

    // Block until no owner is set, then record ownerName as the owner.
    // ownerName must be non-empty (empty means "unowned").
    void LockIfOwnedByNoneThenSetOwner(const std::string& ownerName)
    {
        if (ownerName.empty()) {
            throw std::invalid_argument("Passed ownerName is empty in ObjectWithOwnership::LockIfOwnedByNoneThenSetOwner");
        }
        std::unique_lock lock(_mutex);
        _cv.wait(lock, [this]{ return _ownerName.empty(); });
        _ownerName = ownerName;
    }

    // Caller must make sure that the set owner is itself.
    void ResetOwnerAndNotify()
    {
        std::lock_guard lock(_mutex);
        if (_ownerName.empty()) {
            throw std::runtime_error("Current _ownerName is empty in ObjectWithOwnership::ResetOwnerAndNotify");
        }
        _ownerName.clear();
        _cv.notify_all();
    }

private:
    T _object;
    std::mutex _mutex;
    std::string _ownerName; // Empty string means "not owned".
    std::condition_variable _cv;
};

using BlockListT = ObjectWithOwnership>;

// IMdsSink implementation that accumulates JSON rows and uploads them as
// block blobs (one blob per fixed time interval).
class XJsonBlobSink : public IMdsSink
{
public:
    // Metadata handed to each XJsonBlobRequest: target container plus the
    // path components and identity values used to name/populate the blob.
    struct RequestInfo
    {
    public:
        const MdsEntityName target; // Destination storage container
        std::string primaryPartitionField; // E.g., "resourceId=...". 'resourceId' is obtained from OboDirectConfig.primaryPartitionField,
                                           // and '...' needs to be obtained from somewhere else (Portal/LAD config? -- WAD is blocked on this)
        std::string agentIdentityHash;
        std::string partitionFields; // E.g., "resourceId=xxx/subscriptionId=yyy". 'resourceId' and 'subscriptionId' are obtained from OboDirectConfig.partitionFields,
                                     // and 'xxx' and 'yyy' need to be obtained from somewhere else (OBO service? What about LAD scenario?)
        std::string duration; // E.g., "PT1M" for metric events. "" for non-metric events. Will be used by Json construction
        std::string tenant; // Tenant name in metric Json content
        std::string role; // Role name in metric Json content
        std::string roleInstance; // RoleInstance name in metric Json content

        RequestInfo(const MdsEntityName& t) : target(t) {}
    };

    virtual bool IsXJsonBlob() const { return true; }

    XJsonBlobSink(MdsdConfig* config, const MdsEntityName &target, const Credentials* c);
    virtual ~XJsonBlobSink();

    virtual void AddRow(const CanonicalEntity&, const MdsTime&);
    virtual void Flush();
    virtual void ValidateAccess();

private:
    XJsonBlobSink();

    // This code path is currently really not used (as we haven't actually
    // implemented the OboDirect feature), but just placed for the future.
    void InitializeForOboDirect(MdsdConfig* config, const std::shared_ptr& oboDirectConfig);

    // This will be mostly used for LAD JsonBlob sink scenario.
    void InitializeForLadWithoutOboDirect(MdsdConfig* config);

    void ComputeConnString();

    RequestInfo _template;
    const Credentials* _creds;
    std::string _namespace;
    std::string _containerName;

    std::shared_ptr _request;

    // Per-blob block list that needs to be persisted across multiple requests,
    // so keep it here as a shared ptr. XJBS just maintains a pointer (so that
    // it can be persisted across multiple requests) and all operations on it
    // are done by XJBR.
    std::shared_ptr _blockList;

    // Block list reconstruction from a persisted block count file is needed
    // only for the first request, so remember whether first request was created or not.
    bool _firstReqCreated;

    MdsTime _blobBaseTime; // Base time for which we're currently building a blob.
    time_t _blobIntervalSec; // E.g., 1 hour (3600 sec). Fixed interval in seconds for a blob.
    std::string _blobIntervalISO8601Duration; // E.g., "PT1H". _blobIntervalSec should be computed from this. If this is not a correct ISO8601 string, it should be "PT1H" by default.

    // Maintained by ComputeConnString()
    std::string _connString;

    // Other constants
    static constexpr size_t _targetBlockSize { 4128768 }; // 4MB - 64KB

};

#endif // _XJSONBLOBSINK_HH_

// vim: se sw=4 :

================================================ FILE: Diagnostic/mdsd/mdsd/XTableConst.cc ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "XTableConst.hh"

// Default values; test code may override via the setters in XTableConst.hh.
unsigned int XTableConstants::_backoffBaseTime = 10;
unsigned int XTableConstants::_backoffLimit = 3;
int XTableConstants::_sdkRetryPolicyInterval = 3;
int XTableConstants::_sdkRetryPolicyLimit = 5;
int XTableConstants::_initialOpTimeout = 30;
int XTableConstants::_defaultOpTimeout = 30;

================================================ FILE: Diagnostic/mdsd/mdsd/XTableConst.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#ifndef _XTABLECONST_HH_
#define _XTABLECONST_HH_

// "Constants" used by XTableSink, DataUploader, etc.
// These are generally run-time constants. They've been encapsulated in this class so they
// can be manipulated at run-time by test code, generally to reduce timeouts or retry counts.
class XTableConstants { public: // Getters static unsigned int BackoffBaseTime() { return _backoffBaseTime; } static unsigned int BackoffLimit() {return _backoffLimit; } static int SDKRetryPolicyInterval() { return _sdkRetryPolicyInterval; } static int SDKRetryPolicyLimit() { return _sdkRetryPolicyLimit; } static int InitialOpTimeout() { return _initialOpTimeout; } static int DefaultOpTimeout() { return _defaultOpTimeout; } static unsigned int MaxItemPerBatch() { return 100; } // Not alterable // Setters static void SetBackoffBaseTime(unsigned int val) { _backoffBaseTime = val; } static void SetBackoffLimit(unsigned int val) { _backoffLimit = val; } static void SetSDKRetryPolicyInterval(int val) { _sdkRetryPolicyInterval = val; } static void SetSDKRetryPolicyLimit(int val) { _sdkRetryPolicyLimit = val; } static void SetInitialOpTimeout(int val) { _initialOpTimeout = val; } static void SetDefaultOpTimeout(int val) { _defaultOpTimeout = val; } private: XTableConstants(); XTableConstants(const XTableConstants&) = delete; static unsigned int _backoffBaseTime; static unsigned int _backoffLimit; static int _sdkRetryPolicyInterval; static int _sdkRetryPolicyLimit; static int _initialOpTimeout; static int _defaultOpTimeout; }; #endif // _XTABLECONST_HH_ // vim: se sw=8 : ================================================ FILE: Diagnostic/mdsd/mdsd/XTableHelper.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "XTableHelper.hh" #include "XTableConst.hh" #include "Logger.hh" #include "Trace.hh" #include "Utility.hh" using namespace azure::storage; using std::string; XTableHelper* XTableHelper::GetInstance() { static XTableHelper* s_instance = new XTableHelper(); return s_instance; } XTableHelper::XTableHelper() { } // Delete all the cloud_table objects stored in the cache map. 
XTableHelper::~XTableHelper()
{
    Trace trace(Trace::XTable, "XTableHelper destructor");
    try {
        cloudTableMap.clear();
    }
    catch(const std::exception& e) {
        LogError("Error: ~XTableHelper(): unexpected std::exception: " + string(e.what()));
    }
    catch(...) {
        // Swallow everything: never throw from a destructor.
    }
}

/*
Each new table object is saved in the hash table. Because the hash table will be shared by
multiple threads, using mutex lock for table operations.
The tablename is the actual table name, not a URI. Ie the connStr specifies a SAS, then Azure
requires the tablename to match the tn= component in the SAS; otherwise, either fetching the
table reference will fail (here) or using it will fail (probably by the caller of this function).
Returns nullptr on invalid input or on any exception.
*/
std::shared_ptr XTableHelper::CreateTable(const string& tablename, const string& connStr)
{
    Trace trace(Trace::XTable, "XTableHelper::CreateTable");
    std::shared_ptr tableObj;
    try {
        trace.NOTE("tablename='" + tablename + "'; connection string='" + connStr + "'.");
        if (MdsdUtil::NotValidName(tablename)) {
            LogError("Error: invalid table name: '" + tablename + "'; connection string='" + connStr + "'.");
            return nullptr;
        }
        if (MdsdUtil::NotValidName(connStr)) {
            LogError("Error: invalid connection string: '" + connStr + "'. tablename='" + tablename + "'");
            return nullptr;
        }
        // Cache key includes the connection string so the same table name
        // under different accounts/SAS tokens gets distinct entries.
        const auto key = connStr + tablename;
        std::lock_guard lock(tablemutex);
        auto ctIter = cloudTableMap.find(key);
        if (ctIter != cloudTableMap.end()) {
            trace.NOTE("Found table object in cache. tablename='" + tablename + "'");
            tableObj = ctIter->second;
        }
        else {
            trace.NOTE("Create new cloud_table for '" + tablename + "' with connection string='" + connStr + "'.");
            tableObj = std::make_shared(
                cloud_storage_account::parse(connStr)
                .create_cloud_table_client()
                .get_table_reference(tablename)
                );
            cloudTableMap[key] = tableObj;
        }
    }
    catch(const std::exception& e) {
        LogError("Error: XTableHelper::CreateTable(" + tablename + "): unexpected std::exception: " + string(e.what()) );
    }
    catch(...)
    {
        LogError("Error: XTableHelper::CreateTable(" + tablename + "): unexpected exception");
    }
    return tableObj;
}

/*
Handle storage exception. Return true if the execution can be retried. Return false if no value
to retry. By default, not report error when an exception occurs. The default behavior is to
retry, except the following cases, where the HTTP status code is:
- BadRequest 400: bad API request. report error. (ex: bad credential)
- NotFound 404 : table not found. report error.
- Forbidden 403 : permission denied. report error.
- Conflict 409 : data already uploaded or duplicates found. If this is the first time, report
  error. if not first time, shouldn't report error.
*/
bool XTableHelper::HandleStorageException(const string & tablename, const storage_exception& e,
    size_t * pnerrs, bool isFirstTime, bool * isNoSuchTable)
{
    Trace trace(Trace::XTable, "HandleStorageException");
    bool retryableErr = true;
    bool suppressErrorMsg = false;

    trace.NOTE(std::string("Storage exception: ") + e.what());
    auto msg = std::string(e.what()) + "\n";

    request_result result = e.result();
    storage_extended_error err = result.extended_error();
    if (!err.message().empty()) {
        msg += err.message();
        trace.NOTE("Extended info: " + err.message());
    }

    // the retryable API is not accurate (ex for a client timeout, which retry may work, but retryable
    // is still false. so not use it as of 10/17/14.)
    // bool retryable1 = e.retryable();
    // msg += ustring("exception is retryable? = ") + std::to_string(retryable1);

    web::http::status_code httpcode = result.http_status_code();
    trace.NOTE("HTTP status " + std::to_string(httpcode));
    msg += "\nStatusCode=" + std::to_string(httpcode);

    bool isErr = false;
    {
        using web::http::status_codes;
        if (httpcode == status_codes::NotFound && isNoSuchTable) {
            *isNoSuchTable = true;
            // By handing us a valid isNoSuchTable ptr, caller has indicated
            // a desire to handle the No Such Table error directly.
            suppressErrorMsg = true;
        }
        if (httpcode == status_codes::NotFound || httpcode == status_codes::BadRequest || httpcode == status_codes::Forbidden) {
            isErr = true;
        }
        else if (httpcode == status_codes::Conflict) {
            // Conflict never retries, but only counts as an error on the
            // first attempt over this dataset.
            retryableErr = false;
            if (isFirstTime) {
                isErr = true;
            }
        }
    }

    if (isErr) {
        if (!suppressErrorMsg) {
            LogError("Azure Storage Exception for table \"" + tablename + "\": " + msg);
        }
        retryableErr = false;
        if (pnerrs) (*pnerrs)++;
    }

    if (!retryableErr) {
        trace.NOTE("Status code " + std::to_string(httpcode) + " is not retryable. Abort further retry.");
    }
    return retryableErr;
}

// Populate request options from the XTableConstants tunables: exponential
// retry policy, server timeout, max execution time, and JSON (no metadata)
// payload format.
void XTableHelper::CreateRequestOperation(table_request_options& requestOpt) const
{
    exponential_retry_policy retry_policy(
        std::chrono::seconds(XTableConstants::SDKRetryPolicyInterval()),
        XTableConstants::SDKRetryPolicyLimit());
    requestOpt.set_retry_policy(retry_policy);
    requestOpt.set_server_timeout(std::chrono::seconds(XTableConstants::DefaultOpTimeout()));
    requestOpt.set_maximum_execution_time(std::chrono::seconds(XTableConstants::InitialOpTimeout()));
    requestOpt.set_payload_format(table_payload_format::json_no_metadata);
}

// Tag the operation context with a fresh UUID as the client request id
// (for end-to-end correlation in storage logs).
void XTableHelper::CreateOperationContext(operation_context& c) const
{
    std::string id = utility::uuid_to_string(utility::new_uuid());
    c.set_client_request_id(id);
}

/*
Error level logging. This is to isolate XTableHelper logging.
Prefixes the message with the calling thread id.
*/
void XTableHelper::LogError(const std::string & msg) const
{
    auto msg2 = MdsdUtil::GetTid() + ": " + msg;
    Logger::LogError(msg2);
}

// vim: se sw=4 :    // Would prefer 8, but... c'est la vie

================================================ FILE: Diagnostic/mdsd/mdsd/XTableHelper.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#ifndef _XTABLEHELPER_HH_
#define _XTABLEHELPER_HH_

// NOTE(review): bare #include targets were stripped during extraction.
#include
#include
#include "was/storage_account.h"
#include "was/table.h"
#include
#include

/*
This class manages storage table operations.
Because same event can be uploaded to multiple tables, instead of creating new table, uploading,
then deleting it, each new table object will be saved in cache. All tables are freed at final
destruction time.
*/
class XTableHelper
{
public:
    static XTableHelper* GetInstance();

    // disable copy and move contructors
    XTableHelper(XTableHelper&& h) = delete;
    XTableHelper& operator=(XTableHelper&& h) = delete;
    XTableHelper(const XTableHelper&) = delete;
    XTableHelper& operator=(const XTableHelper &) = delete;

    // create a new cloud table using connection string (ex: AccountName/Key, or SAS Key)
    // The table will be stored in a cache for future fast reference.
    // tablename: the actual tablename, not a URI.
    std::shared_ptr CreateTable(const std::string& tablename, const std::string& connStr);

    // Handle storage exception. Return true if the execution can be retried. Return false
    // if no value to retry. Return the number of errors found by pnerrs. If isFirstTry is
    // true, it means this is the first time to run the upload operation on this dataset.
    // Only updates *pnerrs and *isNoSuchTable if those pointers are not nullptr.
    bool HandleStorageException(const std::string& tablename, const azure::storage::storage_exception& e,
        size_t * pnerrs, bool isFirstTry, bool * isNoSuchTable);

    // Create a new request operation object.
    void CreateRequestOperation(azure::storage::table_request_options& options) const;

    // Create a new operation context object.
    void CreateOperationContext(azure::storage::operation_context & context) const;

private:
    XTableHelper();
    ~XTableHelper();

    // Log error message. This function is to make isolated test easiler.
    void LogError(const std::string& msg) const;

    // This will store all the created cloud_table objects. Key=tableUri;
    std::unordered_map> cloudTableMap;
    std::mutex tablemutex; // Guards cloudTableMap.
};

#endif // _XTABLEHELPER_HH_

================================================ FILE: Diagnostic/mdsd/mdsd/XTableRequest.cc ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "XTableRequest.hh"
#include "XTableConst.hh"
#include "Logger.hh"
#include "Trace.hh"
#include "XTableHelper.hh"
#include "MdsdMetrics.hh"
#include
#include

// Build one batch-upload request for the given table, using the shared
// XTableHelper for the (cached) cloud_table reference, retry options, and
// operation context. Throws std::runtime_error if the table reference
// cannot be obtained.
XTableRequest::XTableRequest(const std::string& connStr, const std::string& tablename)
    : _tablename(tablename), _rowCount(0)
{
    Trace trace(Trace::XTable, "XTR Constructor");
    auto helper = XTableHelper::GetInstance();
    _table = helper->CreateTable(tablename, connStr);
    if (!_table) {
        std::ostringstream msg;
        msg << "CreateTable(" << tablename << ", '" << connStr << "') returned nullptr";
        trace.NOTE(msg.str());
        throw std::runtime_error(msg.str());
    }
    helper->CreateRequestOperation(_requestOptions);
    helper->CreateOperationContext(_context);
    _useUpsert = (tablename == "SchemasTable"); // Ugh - such a hack
}

// Append one entity to the pending batch. Returns false (row ignored) when
// the batch already holds the maximum number of items per storage batch.
bool XTableRequest::AddRow(const azure::storage::table_entity & row)
{
    Trace trace(Trace::XTable, "XTR::AddRow");
    if (_rowCount == XTableConstants::MaxItemPerBatch()) {
        trace.NOTE("Batch is already full; ignoring row");
        return false;
    }
    if (_useUpsert) {
        // SchemasTable rows may legitimately recur; replace on conflict.
        _batchOperation.insert_or_replace_entity(row);
    } else {
        _batchOperation.insert_entity(row);
    }
    _rowCount++;
    return true;
}

// Fire-and-forget: hand the batch to the async pipeline (DoWork). Takes
// ownership of the request; completion and errors are reported via
// metrics/logging only.
/*static*/
void XTableRequest::Send(std::unique_ptr req)
{
    Trace trace(Trace::XTable, "XTR::Send");
    req->_rowCount = req->_batchOperation.operations().size();
    MdsdMetrics::Count("XTable_send");
    MdsdMetrics::Count("XTable_rowsSent", req->_rowCount);
    if (req->_rowCount == 0) {
        trace.NOTE("Shortcut completion: zero row count");
        return;
    }
    // Need to convert the unique_ptr to shared_ptr for lambda capture inside
    XTableRequest::DoWork(std::shared_ptr(req.release()), boost::system::error_code());
}

/*static*/
// Execute the batch asynchronously. The first continuation inspects the
// per-row results (DoContinuation); the second is a catch-all that logs any
// exception escaping the pipeline. The shared_ptr capture keeps the request
// alive until the chain completes.
void XTableRequest::DoWork(std::shared_ptr req, const boost::system::error_code &error)
{
    Trace trace(Trace::XTable, "XTR::DoWork");
    if (error) {
        std::ostringstream msg;
        msg << "DoWork() observed error " << error << " from previous task";
        trace.NOTE(msg.str());
        Logger::LogError(msg.str());
        return;
    }
    req->_table->execute_batch_async(req->_batchOperation, req->_requestOptions, req->_context)
    .then([req](pplx::task > t)
    {
        DoContinuation(req, t);
    })
    .then([=](pplx::task previous_task)
    {
        try {
            previous_task.wait();
        }
        catch (std::exception & e) {
            MdsdMetrics::Count("XTable_failedGeneralException");
            std::ostringstream msg;
            msg << "Writing to table '" << req->_tablename << "' " << "caught exception: " << e.what();
            Logger::LogError("XTR::DoWork(): " + msg.str());
        }
        catch(...) {
            MdsdMetrics::Count("XTable_failedUnknownException");
            Logger::LogError("XTR::DoWork() caught unknown exception.");
        }
    });
}

// Inspect the batch results. On success, count per-row errors (any non-204
// status) and update metrics. On storage_exception, delegate to
// HandleStorageException; if the failure was "no such table", create the
// table asynchronously and retry the whole batch once it exists.
/*static*/
void XTableRequest::DoContinuation(std::shared_ptr req, pplx::task > t)
{
    Trace trace(Trace::XTable, "XTR::DoContinuation");
    size_t errcount = 0;
    try {
        t.wait();
        for (const auto &result : t.get() ) {
            // 204 NoContent is the expected per-row success status.
            if (result.http_status_code() != web::http::status_codes::NoContent) {
                std::ostringstream msg;
                msg << "Unexpected HTTP status " << result.http_status_code() << " when writing to " << req->_tablename;
                trace.NOTE(msg.str());
                Logger::LogError(msg.str());
                errcount++;
            }
        }
        if (errcount) {
            std::ostringstream msg;
            msg << "Total of " << errcount << ((errcount==1)?"error":"errors") << " while writing to " << req->_tablename;
            Logger::LogError(msg.str());
            MdsdMetrics::Count("XTable_completeWithErrors");
            trace.NOTE("Completed but some rows not successful");
        } else {
            MdsdMetrics::Count("XTable_complete");
            trace.NOTE("Complete");
        }
        MdsdMetrics::Count("XTable_rowsSuccess", req->_rowCount - std::min(errcount, req->_rowCount));
    }
    catch (azure::storage::storage_exception & e) {
        trace.NOTE("Caught storage exception for table " + req->_tablename);
        bool isNoSuchTable = false;
        XTableHelper::GetInstance()->HandleStorageException(req->_tablename, e, &errcount, true, &isNoSuchTable);
        if (isNoSuchTable) {
            // Table doesn't exist. Let's see if we can create it.
            trace.NOTE("Trying to create table " + req->_tablename);
            MdsdMetrics::Count("XTable_tableCreate");
            req->_table->create_if_not_exists_async(req->_requestOptions, req->_context)
            .then([req](pplx::task t)
            {
                Trace trace(Trace::XTable, "XTR Create Table lambda");
                try {
                    t.wait();
                    (void) t.get(); // Don't care if it was already created
                    // If we get here, the table exists; let's retry the initial operation
                    MdsdMetrics::Count("XTable_retries");
                    XTableRequest::DoWork(req, boost::system::error_code());
                    return;
                }
                catch (azure::storage::storage_exception & e) {
                    // Just emit the necessary error messages
                    (void)XTableHelper::GetInstance()
                        ->HandleStorageException(req->_tablename, e, nullptr, true, nullptr);
                }
                catch (std::exception& e) {
                    std::string msg = "While trying to create table " + req->_tablename + " Caught exception: " + e.what();
                    trace.NOTE(msg);
                    Logger::LogError(msg);
                }
                catch (...) {
                    std::string msg = "While trying to create table " + req->_tablename + " Caught unknown exception.";
                    trace.NOTE(msg);
                    Logger::LogError(msg);
                }
            });
            return;
        }
    }
    catch (std::exception & e) {
        MdsdMetrics::Count("XTable_failedGeneralException");
        std::ostringstream msg;
        msg << "Caught exception: " << e.what();
        trace.NOTE(msg.str());
        Logger::LogError("XTR::DoContinuation(): " + msg.str());
    }
    catch (...) {
        MdsdMetrics::Count("XTable_failedUnknownException");
        trace.NOTE("Caught unknown exception.");
        Logger::LogError("XTR::DoContinuation() caught unknown exception.");
    }
}

// vim: se sw=8 :

================================================ FILE: Diagnostic/mdsd/mdsd/XTableRequest.hh ================================================

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _XTABLEREQUEST_HH_
#define _XTABLEREQUEST_HH_

// NOTE(review): the header names below and all template arguments were stripped
// by text extraction (empty "#include" lines, bare "std::shared_ptr",
// "std::unique_ptr", "pplx::task >"); restore from the upstream file.
#include
#include
#include
#include
#include
#include
#include

// One batched asynchronous write to an Azure XTable. Built up row-by-row via
// AddRow(), then handed off whole to the static Send(); the request object is
// expected to own its own lifetime after Send() (see XTableSink::Flush()).
class XTableRequest
{
public:
    XTableRequest(const std::string& connStr, const std::string& tablename);
    bool AddRow(const azure::storage::table_entity &row);
    static void Send(std::unique_ptr req);
    // Number of rows accumulated so far.
    size_t Size() { return _rowCount; }

private:
    std::shared_ptr _table;					// target table client
    std::string _tablename;
    azure::storage::table_batch_operation _batchOperation;	// the accumulated batch
    azure::storage::table_request_options _requestOptions;
    azure::storage::operation_context _context;
    size_t _rowCount;						// rows in _batchOperation
    bool _useUpsert;

    // Async pipeline: DoWork starts the batch, DoContinuation handles results.
    static void DoWork(std::shared_ptr req, const boost::system::error_code&);
    static void DoContinuation(std::shared_ptr req, pplx::task > t);
};

#endif // _XTABLEREQUEST_HH_
// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/XTableSink.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "XTableSink.hh"
// NOTE(review): two stdlib header names stripped by extraction.
#include
#include
#include "CanonicalEntity.hh"
#include "Engine.hh"
#include "MdsdConfig.hh"
#include "Credentials.hh"
#include "Utility.hh"
#include "RowIndex.hh"
#include "Trace.hh"
#include "MdsdMetrics.hh"
#include "XTableRequest.hh"
#include "StoreType.hh"
#include "stdafx.h"
#include "was/table.h"
#include "was/common.h"

using std::string;
using azure::storage::entity_property;

// Sink that accumulates rows into XTableRequest batches. For non-schema
// tables, precomputes the joined identity-column string and a zero-filled
// partition number _N derived from its hash modulo the configured partition count.
XTableSink::XTableSink(MdsdConfig* config, const MdsEntityName &target, const Credentials* c)
    : IMdsSink(StoreType::Type::XTable), _config(config), _target(target), _creds(c)
{
    Trace trace(Trace::XTable, "XTS::Constructor");
    if (!target.IsSchemasTable()) {
        // Build the identity columns metadata only once
        // Similarly, compute the partition data only once.
        // SchemasTable has no identity columns (in this sense) and does partitioning differently
        config->GetIdentityColumnValues(std::back_inserter(_identityColumns));
        std::vector identValues;		// NOTE(review): element type stripped by extraction
        identValues.reserve(_identityColumns.size());
        for (const ident_col_t& idpair : _identityColumns) {
            identValues.push_back(idpair.second);
        }
        _identColumnString = MdsdUtil::Join(identValues, "___");
        unsigned long long N = MdsdUtil::EasyHash(_identColumnString) % (unsigned long long)(_config->PartitionCount());
        _N = MdsdUtil::ZeroFill(N, 19);
    }
    _estimatedBytes = 0;
}

// Refresh _fullTableName/_connString/_rebuildTime from the credentials object;
// logs (but does not throw) on failure, leaving the previous values in place.
void XTableSink::ComputeConnString()
{
    Trace trace(Trace::XTable, "XTS::ComputeConnString");
    if (_creds->ConnectionString(_target, Credentials::ServiceType::XTable, _fullTableName, _connString, _rebuildTime) ) {
        if (trace.IsActive()) {
            std::ostringstream msg;
            msg << _fullTableName << "=[" << _connString << "] expires " << _rebuildTime;
            trace.NOTE(msg.str());
        }
    } else {
        Logger::LogError("Couldn't construct connection string for table " + _target.Name());
    }
}

XTableSink::~XTableSink()
{
    Trace trace(Trace::XTable, "XTS::Destructor");
}

// Convert the CanonicalEntity to a table_entity and add it to our internal request. Flush
// the request if it fills up.
//
// Note that AddRow() doesn't keep the CanonicalEntity; we copy anything we need from it.
void XTableSink::AddRow(const CanonicalEntity &row, const MdsTime& qibase)
{
    Trace trace(Trace::XTable, "XTS::AddRow");
    // If this row is for a different partition, flush what we have and track the new partition
    if (row.PartitionKey() != _pkey) {
        Flush();
        _pkey = row.PartitionKey();
    }
    // If we have no in-progress request, either because we just flushed or because we're just
    // starting up, make one.
    if (! _request) {
        try {
            ComputeConnString();
            _request.reset(new XTableRequest(_connString, _fullTableName));
        }
        catch (std::exception &ex) {
            std::ostringstream msg;
            msg << "Exception (" << ex.what() << ") caught while creating new XTableRequest; dropping row";
            trace.NOTE(msg.str());
            Logger::LogError(msg.str());
            MdsdMetrics::Count("Dropped_Entities");
            return;
        }
    }
    azure::storage::table_entity e { _pkey, row.RowKey() };
    azure::storage::table_entity::properties_type& properties = e.properties();
    // Running size estimate in bytes; strings are counted at 2 bytes/char (UTF-16 in XStore).
    size_t byteCount = 2 * (_pkey.length() + row.RowKey().length()) + 4;
    bool oversize = false;
    for (const auto & col : row) {
        // col is pair
        auto namesize = 2 * col.first.length();
        byteCount += namesize;	// Account for the column name, which is stored in the entity in XStore
        switch((col.second)->type) {
        case MdsValue::MdsType::mt_bool:
            properties[col.first] = entity_property((col.second)->bval);
            byteCount += 1;
            break;
        case MdsValue::MdsType::mt_wstr:
            {
                properties[col.first] = entity_property(*((col.second)->strval));
                auto colsize = 2 * ((col.second)->strval->length()) + 2;
                byteCount += colsize;
                if (colsize + namesize > 65536) {
                    // XStore max attribute size is 64Ki
                    std::ostringstream msg;
                    msg << "Column " << col.first << " oversize: colsize " << colsize << " namesize " << namesize;
                    trace.NOTE(msg.str());
                    oversize = true;
                }
            }
            break;
        case MdsValue::MdsType::mt_float64:
            properties[col.first] = entity_property((col.second)->dval);
            byteCount += 8;
            break;
        case MdsValue::MdsType::mt_int32:
            properties[col.first] = entity_property((int32_t)(col.second)->lval);
            byteCount += 4;
            break;
        case MdsValue::MdsType::mt_int64:
            properties[col.first] = entity_property((int64_t)(col.second)->llval);
            byteCount += 8;
            break;
        case MdsValue::MdsType::mt_utc:
            properties[col.first] = entity_property((col.second)->datetimeval);
            byteCount += 8;
            break;
        }
    }
    if (oversize || (byteCount > 1024*1024)) {
        // XStore max table size is 1024Ki
        trace.NOTE("Entity or column too large - dropped");
        std::ostringstream msg;
        msg << "Dropping oversize entity: " << row;
        Logger::LogWarn(msg.str());
        MdsdMetrics::Count("Dropped_Entities");
        MdsdMetrics::Count("Overlarge_Entities");
        return;
    }
    // Keep the whole batch under ~4MB; flush and start a fresh request first if needed.
    if ((_estimatedBytes + byteCount) > 4000000) {
        trace.NOTE("Batch would be too big; flushing before adding this entity");
        Flush();
        try {
            ComputeConnString();
            _request.reset(new XTableRequest(_connString, _fullTableName));
        }
        catch (std::exception & ex) {
            std::ostringstream msg;
            msg << "Exception (" << ex.what() << ") caught while creating new XTableRequest; dropping row";
            trace.NOTE(msg.str());
            Logger::LogError(msg.str());
            MdsdMetrics::Count("Dropped_Entities");
            return;
        }
    }
    _request->AddRow(e);
    _estimatedBytes += byteCount;
    if (trace.IsActive()) {
        std::ostringstream msg;
        msg << "We have " << _request->Size() << " rows";
        trace.NOTE(msg.str());
    }
    // 100 entities is the table-batch row limit; flush when we reach it.
    if (_request->Size() == 100) {
        Flush();
    }
}

// Flush any data we're holding. We might never have allocated a request, or it might
// be empty, or we might have data.
// Post-condition: _request is nullptr. Next call to AddRow() will create a new request on demand.
void XTableSink::Flush()
{
    Trace trace(Trace::XTable, "XTS::Flush");
    if (!_request) {
        // First time through. Just make the post-condition true
        trace.NOTE("Null _request; no action.");
    } else {
        if (_request->Size() > 0) {
            // Detach the request and send it. Send() is fire-and-forget; the request object
            // is responsible for deleting itself after that point.
            trace.NOTE("Writing to " + _fullTableName + " with connection string " + _connString);
            XTableRequest::Send(std::move(_request));
        } else {
            // Since we create these on demand, this really shouldn't happen.
            trace.NOTE("Empty _request; no action (deleting).");
        }
        _request.reset();
        _estimatedBytes = 0;
    }
}
// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/XTableSink.hh
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifndef _XTABLESINK_HH_
#define _XTABLESINK_HH_

#include "IMdsSink.hh"
// NOTE(review): stdlib header names stripped by extraction.
#include
#include
#include
#include "stdafx.h"
#include "IdentityColumns.hh"
#include "MdsTime.hh"
#include "MdsEntityName.hh"

class CanonicalEntity;
class Credentials;
class MdsdConfig;
class XTableRequest;

// IMdsSink implementation that batches CanonicalEntity rows into XTableRequest
// objects and writes them to an Azure XTable (see XTableSink.cc).
class XTableSink : public IMdsSink
{
public:
    virtual bool IsXTable() const { return true; }
    XTableSink(MdsdConfig* config, const MdsEntityName &target, const Credentials* c);
    virtual ~XTableSink();
    virtual void AddRow(const CanonicalEntity&, const MdsTime&);
    virtual void Flush();

private:
    XTableSink();			// default construction not permitted
    void ComputeConnString();

    MdsdConfig* _config;
    MdsEntityName _target;
    const Credentials* _creds;
    ident_vect_t _identityColumns;	// (name, value) identity columns
    std::string _identColumnString;	// identity values joined with "___"
    MdsTime _QIBase;
    std::string _pkey;			// partition key of in-progress batch
    std::string _TIMESTAMP;
    std::string _N;			// zero-filled partition number
    std::string _connString;
    std::string _fullTableName;
    MdsTime _rebuildTime;		// when _connString must be recomputed
    std::unique_ptr _request;		// in-progress batch, created on demand (NOTE(review): element type stripped)
    unsigned long _estimatedBytes;	// size estimate of in-progress batch
};

#endif // _XTABLESINK_HH_
// vim: se sw=8 :

================================================
FILE: Diagnostic/mdsd/mdsd/cJSON.c
================================================
/*
 Copyright (c) 2009 Dave Gamble

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
 in the Software without restriction, including without limitation the rights
 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 copies of the Software, and to permit persons to whom the Software is
 furnished to do so, subject to the following conditions:

 The above copyright notice and this permission notice shall be included in
 all copies or substantial portions of the Software.

 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
*/

/* cJSON */
/* JSON parser in C. */

//#define _GNU_SOURCE
/* NOTE(review): the stdlib header names below were stripped by extraction. */
#include
#include
#include
#include
#include
#include
#include
#include "cJSON.h"

/* Pointer into the input at the point of the last parse error (see cJSON_GetErrorPtr). */
static const char *ep;

const char *cJSON_GetErrorPtr(void) {return ep;}

/* Case-insensitive string compare; a NULL argument compares unequal to a
   non-NULL one. NOTE(review): the loop passes plain char to tolower(), unlike
   the final return which casts to unsigned char first. */
static int cJSON_strcasecmp(const char *s1,const char *s2)
{
	if (!s1) return (s1==s2)?0:1;if (!s2) return 1;
	for(; tolower(*s1) == tolower(*s2); ++s1, ++s2)	if(*s1 == 0)	return 0;
	return tolower(*(const unsigned char *)s1) - tolower(*(const unsigned char *)s2);
}

/* Allocator hooks; default to the C library, replaceable via cJSON_InitHooks(). */
static void *(*cJSON_malloc)(size_t sz) = malloc;
static void (*cJSON_free)(void *ptr) = free;

/* Duplicate a NUL-terminated string with the configured allocator; 0 on OOM. */
static char* cJSON_strdup(const char* str)
{
	size_t len;
	char* copy;

	len = strlen(str) + 1;
	if (!(copy = (char*)cJSON_malloc(len))) return 0;
	memcpy(copy,str,len);
	return copy;
}

void cJSON_InitHooks(cJSON_Hooks* hooks)
{
    if (!hooks) { /* Reset hooks */
        cJSON_malloc = malloc;
        cJSON_free = free;
        return;
    }

	cJSON_malloc = (hooks->malloc_fn)?hooks->malloc_fn:malloc;
	cJSON_free	 = (hooks->free_fn)?hooks->free_fn:free;
}

/* Internal constructor. */
static cJSON *cJSON_New_Item(void)
{
	cJSON* node = (cJSON*)cJSON_malloc(sizeof(cJSON));
	if (node) memset(node,0,sizeof(cJSON));
	return node;
}

/* Delete a cJSON structure. Walks the sibling chain iteratively, recursing
   into children; reference nodes do not own their child/valuestring. */
void cJSON_Delete(cJSON *c)
{
	cJSON *next;
	while (c)
	{
		next=c->next;
		if (!(c->type&cJSON_IsReference) && c->child) { cJSON_Delete(c->child); c->child = 0; }
		if (!(c->type&cJSON_IsReference) && c->valuestring) { cJSON_free(c->valuestring); c->valuestring = 0; }
		if (c->string) { cJSON_free(c->string); c->string = 0; }
		cJSON_free(c);
		c=next;
	}
}

/* Parse the input text to generate a number, and populate the result into item.
*/
static const char *parse_number(cJSON *item,const char *num)
{
	double n=0,sign=1,scale=0;int subscale=0,signsubscale=1;

	if (*num=='-') sign=-1,num++;	/* Has sign? */
	if (*num=='0') num++;			/* is zero */
	if (*num>='1' && *num<='9')	do	n=(n*10.0)+(*num++ -'0');	while (*num>='0' && *num<='9');	/* Number? */
	if (*num=='.' && num[1]>='0' && num[1]<='9') {num++;		do	n=(n*10.0)+(*num++ -'0'),scale--; while (*num>='0' && *num<='9');}	/* Fractional part? */
	if (*num=='e' || *num=='E')		/* Exponent? */
	{	num++;if (*num=='+') num++;	else if (*num=='-') signsubscale=-1,num++;		/* With sign? */
		while (*num>='0' && *num<='9') subscale=(subscale*10)+(*num++ - '0');	/* Number? */
	}

	n=sign*n*pow(10.0,(scale+subscale*signsubscale));	/* number = +/- number.fraction * 10^+/- exponent */

	/* Store both representations; valueint truncates toward zero. */
	item->valuedouble=n;
	item->valueint=(long long)n;
	item->type=cJSON_Number;
	return num;
}

/* Render the number nicely from the given item into a string.
   Uses the integer form when valueint round-trips exactly to valuedouble. */
static char *print_number(cJSON *item)
{
	char *str;
	double d=item->valuedouble;
	if (fabs(((double)item->valueint)-d)<=DBL_EPSILON && d<=LLONG_MAX && d>=LLONG_MIN)
	{
		str=(char*)cJSON_malloc(21);	/* 2^64+1 can be represented in 21 chars. */
		if (str) sprintf(str,"%lld",item->valueint);
	}
	else
	{
		const size_t buffSize = 64;
		str=(char*)cJSON_malloc(buffSize);	/* This is a nice tradeoff. */
		if (str)
		{
			if (fabs(floor(d)-d)<=DBL_EPSILON && fabs(d)<1.0e60)snprintf(str, buffSize, "%.0f",d);
			else if (fabs(d)<1.0e-6 || fabs(d)>1.0e9)			snprintf(str, buffSize, "%e",d);
			else												snprintf(str, buffSize, "%f",d);
		}
	}
	return str;
}

/* Parse four hex digits into an unsigned value; returns 0 on any non-hex digit
   (indistinguishable from a legitimate "0000"). */
static unsigned parse_hex4(const char *str)
{
	unsigned h=0;
	if (*str>='0' && *str<='9') h+=(*str)-'0'; else if (*str>='A' && *str<='F') h+=10+(*str)-'A'; else if (*str>='a' && *str<='f') h+=10+(*str)-'a'; else return 0;
	h=h<<4;str++;
	if (*str>='0' && *str<='9') h+=(*str)-'0'; else if (*str>='A' && *str<='F') h+=10+(*str)-'A'; else if (*str>='a' && *str<='f') h+=10+(*str)-'a'; else return 0;
	h=h<<4;str++;
	if (*str>='0' && *str<='9') h+=(*str)-'0'; else if (*str>='A' && *str<='F') h+=10+(*str)-'A'; else if (*str>='a' && *str<='f') h+=10+(*str)-'a'; else return 0;
	h=h<<4;str++;
	if (*str>='0' && *str<='9') h+=(*str)-'0'; else if (*str>='A' && *str<='F') h+=10+(*str)-'A'; else if (*str>='a' && *str<='f') h+=10+(*str)-'a'; else return 0;
	return h;
}

/* Parse the input text into an unescaped cstring, and populate item. */
static const unsigned char firstByteMark[7] = { 0x00, 0x00, 0xC0, 0xE0, 0xF0, 0xF8, 0xFC };
static const char *parse_string(cJSON *item,const char *str)
{
	const char *ptr=str+1;char *ptr2;char *out;int len=0;unsigned uc,uc2;
	if (*str!='\"') {ep=str;return 0;}	/* not a string! */

	/* NOTE(review): allocation is sized by this rough pre-scan ("roughly" per
	   the comment below); upstream cJSON later hardened this path against
	   malformed \u escapes — verify before using on untrusted input. */
	while (*ptr!='\"' && *ptr && ++len) if (*ptr++ == '\\') ptr++;	/* Skip escaped quotes. */

	out=(char*)cJSON_malloc(len+1);	/* This is how long we need for the string, roughly. */
	if (!out) return 0;

	ptr=str+1;ptr2=out;
	while (*ptr!='\"' && *ptr)
	{
		if (*ptr!='\\') *ptr2++=*ptr++;
		else
		{
			ptr++;
			switch (*ptr)
			{
				case 'b': *ptr2++='\b';	break;
				case 'f': *ptr2++='\f';	break;
				case 'n': *ptr2++='\n';	break;
				case 'r': *ptr2++='\r';	break;
				case 't': *ptr2++='\t';	break;
				case 'u':	 /* transcode utf16 to utf8. */
					uc=parse_hex4(ptr+1);ptr+=4;	/* get the unicode char. */

					if ((uc>=0xDC00 && uc<=0xDFFF) || uc==0)	break;	/* check for invalid. */

					if (uc>=0xD800 && uc<=0xDBFF)	/* UTF16 surrogate pairs.	*/
					{
						if (ptr[1]!='\\' || ptr[2]!='u')	break;	/* missing second-half of surrogate.	*/
						uc2=parse_hex4(ptr+3);ptr+=6;
						if (uc2<0xDC00 || uc2>0xDFFF)		break;	/* invalid second-half of surrogate.	*/
						uc=0x10000 + (((uc&0x3FF)<<10) | (uc2&0x3FF));
					}

					len=4;if (uc<0x80) len=1;else if (uc<0x800) len=2;else if (uc<0x10000) len=3; ptr2+=len;

					/* Emit UTF-8 bytes back-to-front (deliberate switch fallthrough). */
					switch (len) {
						case 4: *--ptr2 =((uc | 0x80) & 0xBF); uc >>= 6;
						case 3: *--ptr2 =((uc | 0x80) & 0xBF); uc >>= 6;
						case 2: *--ptr2 =((uc | 0x80) & 0xBF); uc >>= 6;
						case 1: *--ptr2 =(uc | firstByteMark[len]);
					}
					ptr2+=len;
					break;
				default:  *ptr2++=*ptr; break;
			}
			ptr++;
		}
	}
	*ptr2=0;
	if (*ptr=='\"') ptr++;
	item->valuestring=out;
	item->type=cJSON_String;
	return ptr;
}

/* Render the cstring provided to an escaped version that can be printed. */
static char *print_string_ptr(const char *str)
{
	const char *ptr;char *ptr2,*out;int len=0;unsigned char token;

	if (!str) return cJSON_strdup("");
	/* Pre-scan to size the output: +1 per simple escape, +5 per control char (\uXXXX). */
	ptr=str;while ((token=*ptr) && ++len) {if (strchr("\"\\\b\f\n\r\t",token)) len++; else if (token<32) len+=5;ptr++;}

	out=(char*)cJSON_malloc(len+3);
	if (!out) return 0;

	ptr2=out;ptr=str;
	*ptr2++='\"';
	while (*ptr)
	{
		if ((unsigned char)*ptr>31 && *ptr!='\"' && *ptr!='\\') *ptr2++=*ptr++;
		else
		{
			*ptr2++='\\';
			switch (token=*ptr++)
			{
				case '\\':	*ptr2++='\\';	break;
				case '\"':	*ptr2++='\"';	break;
				case '\b':	*ptr2++='b';	break;
				case '\f':	*ptr2++='f';	break;
				case '\n':	*ptr2++='n';	break;
				case '\r':	*ptr2++='r';	break;
				case '\t':	*ptr2++='t';	break;
				default: sprintf(ptr2,"u%04x",token);ptr2+=5;	break;	/* escape and print */
			}
		}
	}
	*ptr2++='\"';*ptr2++=0;
	return out;
}
/* Invote print_string_ptr (which is useful) on an item. */
static char *print_string(cJSON *item)	{return print_string_ptr(item->valuestring);}

/* Predeclare these prototypes.
static const char *parse_value(cJSON *item,const char *value);
static char *print_value(cJSON *item,int depth,int fmt);
static const char *parse_array(cJSON *item,const char *value);
static char *print_array(cJSON *item,int depth,int fmt);
static const char *parse_object(cJSON *item,const char *value);
static char *print_object(cJSON *item,int depth,int fmt);

/* Utility to jump whitespace and cr/lf */
static const char *skip(const char *in) {while (in && *in && (unsigned char)*in<=32) in++; return in;}

/* Parse an object - create a new root, and populate. */
cJSON *cJSON_ParseWithOpts(const char *value,const char **return_parse_end,int require_null_terminated)
{
	const char *end=0;
	cJSON *c=cJSON_New_Item();
	ep=0;
	if (!c) return 0;       /* memory fail */

	end=parse_value(c,skip(value));
	if (!end)	{cJSON_Delete(c);return 0;}	/* parse failure. ep is set. */

	/* if we require null-terminated JSON without appended garbage, skip and then check for a null terminator */
	if (require_null_terminated) {end=skip(end);if (*end) {cJSON_Delete(c);ep=end;return 0;}}
	if (return_parse_end) *return_parse_end=end;
	return c;
}
/* Default options for cJSON_Parse */
cJSON *cJSON_Parse(const char *value) {return cJSON_ParseWithOpts(value,0,0);}

/* Render a cJSON item/entity/structure to text. */
char *cJSON_Print(cJSON *item)				{return print_value(item,0,1);}
char *cJSON_PrintUnformatted(cJSON *item)	{return print_value(item,0,0);}

/* Parser core - when encountering text, process appropriately. */
static const char *parse_value(cJSON *item,const char *value)
{
	if (!value)						return 0;	/* Fail on null. */
	if (!strncmp(value,"null",4))	{ item->type=cJSON_NULL;  return value+4; }
	if (!strncmp(value,"false",5))	{ item->type=cJSON_False; return value+5; }
	if (!strncmp(value,"true",4))	{ item->type=cJSON_True; item->valueint=1;	return value+4; }
	if (*value=='\"')				{ return parse_string(item,value); }
	if (*value=='-' || (*value>='0' && *value<='9'))	{ return parse_number(item,value); }
	if (*value=='[')				{ return parse_array(item,value); }
	if (*value=='{')				{ return parse_object(item,value); }

	ep=value;return 0;	/* failure. */
}

/* Render a value to text. */
static char *print_value(cJSON *item,int depth,int fmt)
{
	char *out=0;
	if (!item) return 0;
	switch ((item->type)&255)
	{
		case cJSON_NULL:	out=cJSON_strdup("null");	break;
		case cJSON_False:	out=cJSON_strdup("false");break;
		case cJSON_True:	out=cJSON_strdup("true"); break;
		case cJSON_Number:	out=print_number(item);break;
		case cJSON_String:	out=print_string(item);break;
		case cJSON_Array:	out=print_array(item,depth,fmt);break;
		case cJSON_Object:	out=print_object(item,depth,fmt);break;
	}
	return out;
}

/* Build an array from input text. */
static const char *parse_array(cJSON *item,const char *value)
{
	cJSON *child;
	if (*value!='[')	{ep=value;return 0;}	/* not an array! */

	item->type=cJSON_Array;
	value=skip(value+1);
	if (*value==']') return value+1;	/* empty array. */

	item->child=child=cJSON_New_Item();
	if (!item->child) return 0;		 /* memory fail */
	value=skip(parse_value(child,skip(value)));	/* skip any spacing, get the value. */
	if (!value) return 0;

	while (*value==',')
	{
		cJSON *new_item;
		if (!(new_item=cJSON_New_Item())) return 0; 	/* memory fail */
		child->next=new_item;new_item->prev=child;child=new_item;
		value=skip(parse_value(child,skip(value+1)));
		if (!value) return 0;	/* memory fail */
	}

	if (*value==']') return value+1;	/* end of array */
	ep=value;return 0;	/* malformed. */
}

/* Render an array to text */
static char *print_array(cJSON *item,int depth,int fmt)
{
	char **entries;
	char *out=0,*ptr,*ret;int len=5;
	cJSON *child=item->child;
	int numentries=0,i=0,fail=0;

	/* How many entries in the array? */
	while (child) numentries++,child=child->next;
	/* Explicitly handle numentries==0 */
	if (!numentries)
	{
		out=(char*)cJSON_malloc(3);
		if (out) strcpy(out,"[]");
		return out;
	}
	/* Allocate an array to hold the values for each */
	entries=(char**)cJSON_malloc(numentries*sizeof(char*));
	if (!entries) return 0;
	memset(entries,0,numentries*sizeof(char*));
	/* Retrieve all the results: */
	child=item->child;
	while (child && !fail)
	{
		ret=print_value(child,depth+1,fmt);
		entries[i++]=ret;
		if (ret) len+=strlen(ret)+2+(fmt?1:0); else fail=1;
		child=child->next;
	}

	/* If we didn't fail, try to malloc the output string */
	if (!fail) out=(char*)cJSON_malloc(len);
	/* If that fails, we fail. */
	if (!out) fail=1;

	/* Handle failure. */
	/* NOTE(review): the following line is damaged by extraction — an
	   angle-bracketed span was stripped, swallowing the rest of print_array
	   (cleanup loop and output assembly) and the head of parse_object.
	   Restore both functions from upstream cJSON before compiling. */
	if (fail) { for (i=0;itype=cJSON_Object; value=skip(value+1); if (*value=='}') return value+1; /* empty array. */

	item->child=child=cJSON_New_Item();
	if (!item->child) return 0;
	value=skip(parse_string(child,skip(value)));
	if (!value) return 0;
	child->string=child->valuestring;child->valuestring=0;
	if (*value!=':') {ep=value;return 0;}	/* fail! */
	value=skip(parse_value(child,skip(value+1)));	/* skip any spacing, get the value. */
	if (!value) return 0;

	while (*value==',')
	{
		cJSON *new_item;
		if (!(new_item=cJSON_New_Item()))	return 0; /* memory fail */
		child->next=new_item;new_item->prev=child;child=new_item;
		value=skip(parse_string(child,skip(value+1)));
		if (!value) return 0;
		child->string=child->valuestring;child->valuestring=0;
		if (*value!=':') {ep=value;return 0;}	/* fail! */
		value=skip(parse_value(child,skip(value+1)));	/* skip any spacing, get the value. */
		if (!value) return 0;
	}

	if (*value=='}') return value+1;	/* end of array */
	ep=value;return 0;	/* malformed. */
}

/* Render an object to text.
*/
static char *print_object(cJSON *item,int depth,int fmt)
{
	char **entries=0,**names=0;
	char *out=0,*ptr,*ret,*str;int len=7,i=0,j;
	cJSON *child=item->child;
	int numentries=0,fail=0;
	/* Count the number of entries. */
	while (child) numentries++,child=child->next;
	/* Explicitly handle empty object case */
	if (!numentries)
	{
		out=(char*)cJSON_malloc(fmt?depth+4:3);
		if (!out)	return 0;
		ptr=out;*ptr++='{';
		/* NOTE(review): the following line is damaged by extraction — a stripped
		   angle-bracketed span swallowed the formatting loop and the allocation
		   of the entries/names arrays. Restore from upstream cJSON. */
		if (fmt) {*ptr++='\n';for (i=0;ichild;depth++;if (fmt) len+=depth;
	while (child)
	{
		names[i]=str=print_string_ptr(child->string);
		entries[i++]=ret=print_value(child,depth,fmt);
		if (str && ret) len+=strlen(ret)+strlen(str)+2+(fmt?2+depth:0); else fail=1;
		child=child->next;
	}

	/* Try to allocate the output string */
	if (!fail) out=(char*)cJSON_malloc(len);
	if (!out) fail=1;

	/* Handle failure */
	/* NOTE(review): damaged by extraction — this stripped span swallowed the
	   failure-cleanup loop, the output assembly/return of print_object, and
	   the signature of cJSON_GetArraySize. Restore from upstream cJSON. */
	if (fail) { for (i=0;ichild;int i=0;while(c)i++,c=c->next;return i;}
cJSON *cJSON_GetArrayItem(cJSON *array,int item)		{cJSON *c=array->child;  while (c && item>0) item--,c=c->next; return c;}
cJSON *cJSON_GetObjectItem(cJSON *object,const char *string)	{cJSON *c=object->child; while (c && cJSON_strcasecmp(c->string,string)) c=c->next; return c;}

/* Utility for array list handling. */
static void suffix_object(cJSON *prev,cJSON *item) {prev->next=item;item->prev=prev;}
/* Utility for handling references. */
static cJSON *create_reference(cJSON *item) {cJSON *ref=cJSON_New_Item();if (!ref) return 0;memcpy(ref,item,sizeof(cJSON));ref->string=0;ref->type|=cJSON_IsReference;ref->next=ref->prev=0;return ref;}

/* Add item to array/object. */
void   cJSON_AddItemToArray(cJSON *array, cJSON *item)						{cJSON *c=array->child;if (!item) return; if (!c) {array->child=item;} else {while (c && c->next) c=c->next; suffix_object(c,item);}}
void   cJSON_AddItemToObject(cJSON *object,const char *string,cJSON *item)	{if (!item) return; if (item->string) cJSON_free(item->string);item->string=cJSON_strdup(string);cJSON_AddItemToArray(object,item);}
void	cJSON_AddItemReferenceToArray(cJSON *array, cJSON *item)						{cJSON_AddItemToArray(array,create_reference(item));}
void	cJSON_AddItemReferenceToObject(cJSON *object,const char *string,cJSON *item)	{cJSON_AddItemToObject(object,string,create_reference(item));}

/* Unlink item 'which' from the array's child chain without freeing it; 0 if out of range. */
cJSON *cJSON_DetachItemFromArray(cJSON *array,int which)			{cJSON *c=array->child;while (c && which>0) c=c->next,which--;if (!c) return 0;
	if (c->prev) c->prev->next=c->next;if (c->next) c->next->prev=c->prev;if (c==array->child) array->child=c->next;c->prev=c->next=0;return c;}
void   cJSON_DeleteItemFromArray(cJSON *array,int which)			{cJSON_Delete(cJSON_DetachItemFromArray(array,which));}
cJSON *cJSON_DetachItemFromObject(cJSON *object,const char *string) {int i=0;cJSON *c=object->child;while (c && cJSON_strcasecmp(c->string,string)) i++,c=c->next;if (c) return cJSON_DetachItemFromArray(object,i);return 0;}
void   cJSON_DeleteItemFromObject(cJSON *object,const char *string) {cJSON_Delete(cJSON_DetachItemFromObject(object,string));}

/* Replace array/object items with new ones.
*/
void   cJSON_ReplaceItemInArray(cJSON *array,int which,cJSON *newitem)		{cJSON *c=array->child;while (c && which>0) c=c->next,which--;if (!c) return;
	newitem->next=c->next;newitem->prev=c->prev;if (newitem->next) newitem->next->prev=newitem;
	if (c==array->child) array->child=newitem; else newitem->prev->next=newitem;c->next=c->prev=0;cJSON_Delete(c);}
void   cJSON_ReplaceItemInObject(cJSON *object,const char *string,cJSON *newitem){int i=0;cJSON *c=object->child;while(c && cJSON_strcasecmp(c->string,string))i++,c=c->next;if(c){newitem->string=cJSON_strdup(string);cJSON_ReplaceItemInArray(object,i,newitem);}}

/* Create basic types: */
cJSON *cJSON_CreateNull(void)					{cJSON *item=cJSON_New_Item();if(item)item->type=cJSON_NULL;return item;}
cJSON *cJSON_CreateTrue(void)					{cJSON *item=cJSON_New_Item();if(item)item->type=cJSON_True;return item;}
cJSON *cJSON_CreateFalse(void)					{cJSON *item=cJSON_New_Item();if(item)item->type=cJSON_False;return item;}
cJSON *cJSON_CreateBool(int b)					{cJSON *item=cJSON_New_Item();if(item)item->type=(b?cJSON_True:cJSON_False);return item;}
cJSON *cJSON_CreateNumber(double num)			{cJSON *item=cJSON_New_Item();if(item){item->type=cJSON_Number;item->valuedouble=num;item->valueint=(long long)num;}return item;}
cJSON *cJSON_CreateString(const char *string)	{cJSON *item=cJSON_New_Item();if(item){item->type=cJSON_String;item->valuestring=cJSON_strdup(string);}return item;}
cJSON *cJSON_CreateArray(void)					{cJSON *item=cJSON_New_Item();if(item)item->type=cJSON_Array;return item;}
cJSON *cJSON_CreateObject(void)					{cJSON *item=cJSON_New_Item();if(item)item->type=cJSON_Object;return item;}

/* Create Arrays: */
/* NOTE(review): all four Create*Array loop bodies below are damaged by
   extraction — a stripped angle-bracketed span ("i<count;...if(!i)a->")
   was eaten from each "for" loop. Restore from upstream cJSON. */
cJSON *cJSON_CreateIntArray(const int *numbers,int count)		{int i;cJSON *n=0,*p=0,*a=cJSON_CreateArray();for(i=0;a && ichild=n;else suffix_object(p,n);p=n;}return a;}
cJSON *cJSON_CreateFloatArray(const float *numbers,int count)	{int i;cJSON *n=0,*p=0,*a=cJSON_CreateArray();for(i=0;a && ichild=n;else suffix_object(p,n);p=n;}return a;}
cJSON *cJSON_CreateDoubleArray(const double *numbers,int count)	{int i;cJSON *n=0,*p=0,*a=cJSON_CreateArray();for(i=0;a && ichild=n;else suffix_object(p,n);p=n;}return a;}
cJSON *cJSON_CreateStringArray(const char **strings,int count)	{int i;cJSON *n=0,*p=0,*a=cJSON_CreateArray();for(i=0;a && ichild=n;else suffix_object(p,n);p=n;}return a;}

/* Duplication */
cJSON *cJSON_Duplicate(cJSON *item,int recurse)
{
	cJSON *newitem,*cptr,*nptr=0,*newchild;
	/* Bail on bad ptr */
	if (!item) return 0;
	/* Create new item */
	newitem=cJSON_New_Item();
	if (!newitem) return 0;
	/* Copy over all vars */
	newitem->type=item->type&(~cJSON_IsReference),newitem->valueint=item->valueint,newitem->valuedouble=item->valuedouble;
	if (item->valuestring)	{newitem->valuestring=cJSON_strdup(item->valuestring);	if (!newitem->valuestring)	{cJSON_Delete(newitem);return 0;}}
	if (item->string)		{newitem->string=cJSON_strdup(item->string);			if (!newitem->string)		{cJSON_Delete(newitem);return 0;}}
	/* If non-recursive, then we're done! */
	if (!recurse) return newitem;
	/* Walk the ->next chain for the child. */
	cptr=item->child;
	while (cptr)
	{
		newchild=cJSON_Duplicate(cptr,1);		/* Duplicate (with recurse) each item in the ->next chain */
		if (!newchild) {cJSON_Delete(newitem);return 0;}
		if (nptr)	{nptr->next=newchild,newchild->prev=nptr;nptr=newchild;}	/* If newitem->child already set, then crosswire ->prev and ->next and move on */
		else		{newitem->child=newchild;nptr=newchild;}					/* Set newitem->child and move to it */
		cptr=cptr->next;
	}
	return newitem;
}

/* Strip whitespace and comments from 'json' in place. */
void cJSON_Minify(char *json)
{
	char *into=json;
	while (*json)
	{
		if (*json==' ') json++;
		else if (*json=='\t') json++;	// Whitespace characters.
		else if (*json=='\r') json++;
		else if (*json=='\n') json++;
		else if (*json=='/' && json[1]=='/')  while (*json && *json!='\n') json++;	// double-slash comments, to end of line.
		else if (*json=='/' && json[1]=='*') {while (*json && !(*json=='*' && json[1]=='/')) json++;json+=2;}	// multiline comments.
		else if (*json=='\"'){*into++=*json++;while (*json && *json!='\"'){if (*json=='\\') *into++=*json++;*into++=*json++;}*into++=*json++;} // string literals, which are \" sensitive.
		else *into++=*json++;			// All other characters.
	}
	*into=0;	// and null-terminate.
}

================================================
FILE: Diagnostic/mdsd/mdsd/cJSON.h
================================================
/*
 Copyright (c) 2009 Dave Gamble

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
 in the Software without restriction, including without limitation the rights
 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 copies of the Software, and to permit persons to whom the Software is
 furnished to do so, subject to the following conditions:

 The above copyright notice and this permission notice shall be included in
 all copies or substantial portions of the Software.

 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
*/

#ifndef cJSON__h
#define cJSON__h

#ifdef __cplusplus
extern "C"
{
#endif

/* cJSON Types: */
#define cJSON_False 0
#define cJSON_True 1
#define cJSON_NULL 2
#define cJSON_Number 3
#define cJSON_String 4
#define cJSON_Array 5
#define cJSON_Object 6

#define cJSON_IsReference 256

/* The cJSON structure: */
typedef struct cJSON {
	struct cJSON *next,*prev;	/* next/prev allow you to walk array/object chains. Alternatively, use GetArraySize/GetArrayItem/GetObjectItem */
	struct cJSON *child;		/* An array or object item will have a child pointer pointing to a chain of the items in the array/object. */

	int type;					/* The type of the item, as above. */

	char *valuestring;			/* The item's string, if type==cJSON_String */
	long long valueint;			/* The item's number, if type==cJSON_Number */
	double valuedouble;			/* The item's number, if type==cJSON_Number */

	char *string;				/* The item's name string, if this item is the child of, or is in the list of subitems of an object. */
} cJSON;

typedef struct cJSON_Hooks {
      void *(*malloc_fn)(size_t sz);
      void (*free_fn)(void *ptr);
} cJSON_Hooks;

/* Supply malloc, realloc and free functions to cJSON */
extern void cJSON_InitHooks(cJSON_Hooks* hooks);

/* Supply a block of JSON, and this returns a cJSON object you can interrogate. Call cJSON_Delete when finished. */
extern cJSON *cJSON_Parse(const char *value);
/* Render a cJSON entity to text for transfer/storage. Free the char* when finished. */
extern char  *cJSON_Print(cJSON *item);
/* Render a cJSON entity to text for transfer/storage without any formatting. Free the char* when finished. */
extern char  *cJSON_PrintUnformatted(cJSON *item);
/* Delete a cJSON entity and all subentities. */
extern void   cJSON_Delete(cJSON *c);

/* Returns the number of items in an array (or object). */
extern int	  cJSON_GetArraySize(cJSON *array);
/* Retrieve item number "item" from array "array". Returns NULL if unsuccessful. */
extern cJSON *cJSON_GetArrayItem(cJSON *array,int item);
/* Get item "string" from object. Case insensitive. */
extern cJSON *cJSON_GetObjectItem(cJSON *object,const char *string);

/* For analysing failed parses. This returns a pointer to the parse error. You'll probably need to look a few chars back to make sense of it. Defined when cJSON_Parse() returns 0. 0 when cJSON_Parse() succeeds. */
extern const char *cJSON_GetErrorPtr(void);

/* These calls create a cJSON item of the appropriate type. */
extern cJSON *cJSON_CreateNull(void);
extern cJSON *cJSON_CreateTrue(void);
extern cJSON *cJSON_CreateFalse(void);
extern cJSON *cJSON_CreateBool(int b);
extern cJSON *cJSON_CreateNumber(double num);
extern cJSON *cJSON_CreateString(const char *string);
extern cJSON *cJSON_CreateArray(void);
extern cJSON *cJSON_CreateObject(void);

/* These utilities create an Array of count items. */
extern cJSON *cJSON_CreateIntArray(const int *numbers,int count);
extern cJSON *cJSON_CreateFloatArray(const float *numbers,int count);
extern cJSON *cJSON_CreateDoubleArray(const double *numbers,int count);
extern cJSON *cJSON_CreateStringArray(const char **strings,int count);

/* Append item to the specified array/object. */
extern void cJSON_AddItemToArray(cJSON *array, cJSON *item);
extern void	cJSON_AddItemToObject(cJSON *object,const char *string,cJSON *item);
/* Append reference to item to the specified array/object. Use this when you want to add an existing cJSON to a new cJSON, but don't want to corrupt your existing cJSON. */
extern void cJSON_AddItemReferenceToArray(cJSON *array, cJSON *item);
extern void	cJSON_AddItemReferenceToObject(cJSON *object,const char *string,cJSON *item);

/* Remove/Detatch items from Arrays/Objects. */
extern cJSON *cJSON_DetachItemFromArray(cJSON *array,int which);
extern void   cJSON_DeleteItemFromArray(cJSON *array,int which);
extern cJSON *cJSON_DetachItemFromObject(cJSON *object,const char *string);
extern void   cJSON_DeleteItemFromObject(cJSON *object,const char *string);

/* Update array items. */
extern void cJSON_ReplaceItemInArray(cJSON *array,int which,cJSON *newitem);
extern void cJSON_ReplaceItemInObject(cJSON *object,const char *string,cJSON *newitem);

/* Duplicate a cJSON item */
extern cJSON *cJSON_Duplicate(cJSON *item,int recurse);
/* Duplicate will create a new, identical cJSON item to the one you pass, in new memory that will
need to be released. With recurse!=0, it will duplicate any children connected to the item.
The item->next and ->prev pointers are always zero on return from Duplicate. */

/* ParseWithOpts allows you to require (and check) that the JSON is null terminated, and to retrieve the pointer to the final byte parsed. */
extern cJSON *cJSON_ParseWithOpts(const char *value,const char **return_parse_end,int require_null_terminated);

extern void cJSON_Minify(char *json);

/* Macros for creating things quickly. */
#define cJSON_AddNullToObject(object,name)		cJSON_AddItemToObject(object, name, cJSON_CreateNull())
#define cJSON_AddTrueToObject(object,name)		cJSON_AddItemToObject(object, name, cJSON_CreateTrue())
#define cJSON_AddFalseToObject(object,name)		cJSON_AddItemToObject(object, name, cJSON_CreateFalse())
#define cJSON_AddBoolToObject(object,name,b)	cJSON_AddItemToObject(object, name, cJSON_CreateBool(b))
#define cJSON_AddNumberToObject(object,name,n)	cJSON_AddItemToObject(object, name, cJSON_CreateNumber(n))
#define cJSON_AddStringToObject(object,name,s)	cJSON_AddItemToObject(object, name, cJSON_CreateString(s))

/* When assigning an integer value, it needs to be propagated to valuedouble too. */
#define cJSON_SetIntValue(object,val)			((object)?(object)->valueint=(object)->valuedouble=(val):(val))

#ifdef __cplusplus
}
#endif
#endif

================================================
FILE: Diagnostic/mdsd/mdsd/cryptutil.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "cryptutil.hh" #include #include #include #include #include #include extern "C" { #include #include #include #include #include #include #include } using namespace std; namespace cryptutil { using uniqueEvpKey = std::unique_ptr; using uniqueCms = std::unique_ptr; using uniqueP12 = std::unique_ptr; // True if file exists, false if not bool FileExists(const string& filename) { struct stat buffer; return ((stat(filename.c_str(), &buffer)==0) && S_ISREG(buffer.st_mode)); } // Convert a hex string into a vector of bytes bool DecodeString(const string& encoded, vector& byteBuf) { if (encoded.length() < 2) { return false; } auto bufLen = encoded.length() / 2; byteBuf = vector(bufLen); size_t idx = 0; for (size_t i = 0; i < bufLen; i++) { BYTE data1 = (BYTE)(encoded[idx] - '0'); if (data1 > 9) { data1 = (BYTE)((encoded[idx] - 'A') + 10); } BYTE data2 = (BYTE)(encoded[idx+1] - '0'); if (data2 > 9) { data2 = (BYTE)((encoded[idx+1] - 'A') + 10); } byteBuf[i] = (data1 << 4) | data2; idx += 2; } return true; } // Read a string from the data in a BIO object string GetStringFromBio(BIO *mem) { if (mem == nullptr) { throw invalid_argument("A nullptr was passed in place of a BIO argument"); } const int bufSize = 10; char buf[bufSize] = ""; stringstream ss; while (BIO_gets(mem, buf, bufSize) > 0) { ss << buf; } return ss.str(); } // Open a PKCS12 (.pfx) file, and return a suitable object or throw an exception uniqueP12 GetPkcs12FromFile(const string& privKeyPath) { FILE *p12_file = fopen(privKeyPath.c_str(), "rb"); if (p12_file == nullptr) { throw system_error(errno, system_category(), string("Unable to read PKCS12 file " + privKeyPath)); } PKCS12 *p12 = nullptr; d2i_PKCS12_fp(p12_file, &p12); fclose(p12_file); if (p12 == nullptr) { throw cryptutilException("PKCS12 structure could not be parsed from " + privKeyPath); } uniqueP12 retP12(p12, PKCS12_free); return retP12; } // Return the EVP_PKEY contained in the specified pkcs12 file, or throw an exception uniqueEvpKey 
GetPrivateKeyFromPkcs12(const string& privKeyPath, const string& keyPass) { EVP_PKEY *pkey = nullptr; X509 *cert = nullptr; uniqueP12 p12 = GetPkcs12FromFile(privKeyPath); if (!PKCS12_parse(p12.get(), keyPass.c_str(), &pkey, &cert, (STACK_OF(X509)**)nullptr)) { throw cryptutilException("Could not parse private key from PKCS12 file " + privKeyPath); } uniqueEvpKey retKey(pkey, EVP_PKEY_free); // clear certs X509_free(cert); return retKey; } // Return the EVP_PKEY contained in the specified PEM file, or NULL if a failure occurs. uniqueEvpKey GetPrivateKeyFromPem(const string& privKeyPath) { BIO *keyBio = BIO_new_file(privKeyPath.c_str(), "r"); if (keyBio == nullptr) { throw cryptutilException("Unable to read PEM file " + privKeyPath); } EVP_PKEY *pkey = PEM_read_bio_PrivateKey(keyBio, NULL, 0, NULL); BIO_free(keyBio); if (pkey == nullptr) { throw cryptutilException("Unable to parse private key from PEM file " + privKeyPath); } uniqueEvpKey retKey(pkey, EVP_PKEY_free); return retKey; } // Try to parse the specified file as PKCS12 (PFX) or PEM, return the private key or NULL uniqueEvpKey GetPrivateKeyFromUnknownFileType(const string& privKeyPath, const string& keyPass) { try { return GetPrivateKeyFromPem(privKeyPath); } catch (exception& ex) { // File isn't a PEM, but it might be a PFX. We don't care unless BOTH fail. } // This function can throw cryptutilException and system_error. // No need to catch/rethrow the exception - just let it go unhindered // The last call in this function should always allow any exceptions // to pass through to the caller. 
return GetPrivateKeyFromPkcs12(privKeyPath, keyPass); } // Parse the specified file as Cryptographic Message Syntax (CMS) or return NULL uniqueCms GetCMSFromEncodedString(const string& encoded) { // Decode text from hex chars to binary vector byteBuf; if(!DecodeString(encoded, byteBuf)) { throw cryptutilException("Unable to decode provided string to CMS"); } BIO *mem = BIO_new_mem_buf(byteBuf.data(), byteBuf.size()); // Read encrypted text CMS_ContentInfo *cms = d2i_CMS_bio(mem, NULL); BIO_free(mem); if (cms == nullptr) { throw cryptutilException("Unable to parse CMS from decoded string"); } uniqueCms retCms(cms, CMS_ContentInfo_free); return retCms; } // Given a private key and CMS object,return decrypted string // or throw an exception string DecryptCMSWithPrivateKey(uniqueEvpKey& pkey, uniqueCms& cms) { if (pkey.get() == nullptr) { throw invalid_argument("The provided private key must not be a nullptr"); } if (cms.get() == nullptr) { throw invalid_argument("The provided CMS must not be a nullptr"); } // Decrypt file contents BIO *out = BIO_new(BIO_s_mem()); int res = CMS_decrypt(cms.get(), pkey.get(), NULL, NULL, out, 0); if (!res) { BIO_free(out); int error = ERR_get_error(); const char* errstr = ERR_reason_error_string(error); if (errstr) { throw cryptutilException("Error decrypting cipher text [" + string(errstr) + "]"); } else { throw cryptutilException("Error decrypting cipher text"); } } string plaintext = GetStringFromBio(out); BIO_free(out); return plaintext; } // Given an encrypted STRING (CMS encoded as hex chars), a private key file, and an optional password, // decode and decrypt the CMS and return the decrypted string, or throw a cryptutilException if it fails string DecodeAndDecryptString(const string& privKeyPath, const string& encoded, const string& keyPass) { if (privKeyPath.empty()) { throw invalid_argument("The private key path must not be an empty string"); } if (encoded.empty()) { throw invalid_argument("The encoded ciphertext must not be an 
empty string"); } if (!FileExists(privKeyPath)) { throw runtime_error("Private key file was not found at path: " + privKeyPath); } OpenSSL_add_all_algorithms(); ERR_load_crypto_strings(); // Read Private Key uniqueEvpKey pkey = GetPrivateKeyFromUnknownFileType(privKeyPath, keyPass); uniqueCms cms = GetCMSFromEncodedString(encoded); return DecryptCMSWithPrivateKey(pkey, cms); } } ================================================ FILE: Diagnostic/mdsd/mdsd/cryptutil.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #ifndef _CRYPTUTIL_H_ #define _CRYPTUTIL_H_ #include #include #include typedef unsigned char BYTE; namespace cryptutil { bool DecodeString(const std::string& encodedString, std::vector& results); std::string DecodeAndDecryptString(const std::string& privKeyPath, const std::string& encodedString, const std::string& keyPassword = ""); // Custom exception class class cryptutilException : public std::exception { std::string exMessage; public: cryptutilException(const std::string& errDetail) : exMessage(errDetail) {} virtual const char* what() const throw() { return exMessage.c_str(); } }; } #endif ================================================ FILE: Diagnostic/mdsd/mdsd/fdelt_chk.c ================================================ /* Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT license. */ #include # define strong_alias(name, aliasname) _strong_alias(name, aliasname) # define _strong_alias(name, aliasname) \ extern __typeof (name) aliasname __attribute__ ((alias (#name))); /* * 'unsigned' dropped from the original source, to match * the prototype defined in select2.h. 
*/ long int __fdelt_chk (long int d) { if (d >= FD_SETSIZE) __chk_fail (); return d / __NFDBITS; } strong_alias (__fdelt_chk, __fdelt_warn) ================================================ FILE: Diagnostic/mdsd/mdsd/mdsautokey.h ================================================ // -------------------------------------------------------------------------------------------------------------------- // // Copyright (c) Microsoft Corporation. All rights reserved. // // -------------------------------------------------------------------------------------------------------------------- // The autokey feature is never used by the Linux Diagnostic Extension; this stub disables the feature. #ifndef _AUTOKEY_H_ #define _AUTOKEY_H_ #include #include namespace mdsautokey { enum autokeyResultStatus { autokeySuccess, autokeyPartialSuccess, autokeyFailure }; class autokeyResult { public: autokeyResultStatus status; autokeyResult(autokeyResultStatus stat) : status(stat) {} autokeyResult() : status(autokeyResultStatus::autokeySuccess) {} }; autokeyResult GetLatestMdsKeys(const std::string& autokeyCfg, const std::string& nmspace, int eventVersion, std::map, std::string>& keys) { return autokeyResult(autokeyResultStatus::autokeyFailure); } } #endif ================================================ FILE: Diagnostic/mdsd/mdsd/mdsd.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "Logger.hh" #include "ProtocolListenerMgr.hh" #include "MdsdConfig.hh" #include "LocalSink.hh" #include "Engine.hh" #include "Version.hh" #include "Trace.hh" #include "DaemonConf.hh" #include "ExtensionMgmt.hh" #include "Utility.hh" #include "HttpProxySetup.hh" #include "EventHubUploaderMgr.hh" #include "XJsonBlobBlockCountsMgr.hh" #include #include #include #include #include #include #include #include #include #include #include extern "C" { #include #include #include #include #include #include #include } using std::string; using std::to_string; using std::cerr; using std::endl; void usage(); extern "C" { void SetSignalCatchers(int); } void TerminateHandler(); // This is a file-scope string static std::string config_file_path; static std::string autokey_config_path; int main(int argc, char **argv) { int mdsd_port = 29130; // Default port number, grabbed out of the air Engine* engine = Engine::GetEngine(); bool mdsdConfigValidationOnly = false; bool runAsDaemon = false; // If true, run at Daemon mode instead of application mode. bool coreDumpAtFatal = false; // If true, create core dump when received fatal signals. 
std::string proxy_setting_string; // E.g., "[http:]//[username:password@]www.xyz.com:8080/" bool disableLogging = false; // Useful for development testing bool retryRandomPort = false; Logger::Init(); std::string mdsd_config_dir; std::string mdsd_run_dir; std::string mdsd_log_dir; try { mdsd_config_dir = MdsdUtil::GetEnvDirVar("MDSD_CONFIG_DIR", "/etc/mdsd.d"); mdsd_run_dir = MdsdUtil::GetEnvDirVar("MDSD_RUN_DIR", "/var/run/mdsd"); mdsd_log_dir = MdsdUtil::GetEnvDirVar("MDSD_LOG_DIR", "/var/log"); } catch (std::runtime_error& ex) { Logger::LogError(ex.what()); exit(1); } config_file_path = mdsd_config_dir + "/mdsd.xml"; autokey_config_path = mdsd_config_dir + "/mdsautokey.cfg"; const std::string config_cache_dir = mdsd_config_dir + "/config-cache"; std::string mdsd_prefix = mdsd_run_dir + "/"; std::string mdsd_role = "default"; // altered by '-r' std::string mdsd_role_prefix = mdsd_prefix + mdsd_role; // replaced with '-r' value if it starts with '/' std::string ehSaveDir = mdsd_run_dir + "/eh"; // Full path to save failed Event Hub events. // default mdsd log file paths, they can be overwritten by input args. 
std::string mdsdInfoFile = mdsd_log_dir + "/mdsd.info"; std::string mdsdWarnFile = mdsd_log_dir + "/mdsd.warn"; std::string mdsdErrFile = mdsd_log_dir + "/mdsd.err"; LocalSink::Initialize(); { int opt; while ((opt = getopt(argc, argv, "bc:CDde:jo:P:p:Rr:S:T:vVw:")) != -1) { switch (opt) { case 'b': engine->BlackholeEvents(); break; case 'c': config_file_path = optarg; break; case 'C': coreDumpAtFatal = true; break; case 'D': disableLogging = true; break; case 'd': runAsDaemon = true; break; case 'e': mdsdErrFile = optarg; break; case 'j': Trace::AddInterests(Trace::EventIngest); break; case 'o': mdsdInfoFile = optarg; break; case 'P': proxy_setting_string = optarg; try { MdsdUtil::CheckProxySettingString(proxy_setting_string); } catch (const MdsdUtil::HttpProxySetupException& e) { cerr << "Invalid proxy specification for -P option: " << e.what() << endl; usage(); } break; case 'p': mdsd_port = atoi(optarg); if (mdsd_port < 0) { // We now allow '-p 0' (binding to a random port) usage(); } break; case 'R': retryRandomPort = true; break; case 'r': if (*optarg == '/') { // Special case to allow overriding of the default mdsd_prefix (e.g. /var/run/mdsd). // This may be needed in cases where mdsd will not be able to create or write to /var/run/mdsd. // This is useful during dev testing and might also be needed for LAD. mdsd_role_prefix = optarg; } else { mdsd_role_prefix = mdsd_prefix + std::string(optarg); } break; case 'S': ehSaveDir = optarg; if (ehSaveDir.empty()) { cerr << "'-S' requires a valid pathname." << endl; usage(); } break; case 'T': try { unsigned long val = std::stol(string(optarg), 0, 0); Trace::AddInterests(static_cast(val)); } catch (std::exception & ex) { usage(); } break; case 'v': mdsdConfigValidationOnly = true; break; case 'V': cerr << Version::Version << endl; exit(0); case 'w': mdsdWarnFile = optarg; break; default: /* '?' */ usage(); } } } // For config xml validation only, log to console. 
if (!mdsdConfigValidationOnly) { // Only try to create the mdsd_run_dir dir if it wasn't overridden via '-r' option. if (mdsd_role_prefix.substr(0, mdsd_run_dir.length()) == mdsd_run_dir) { try { MdsdUtil::CreateDirIfNotExists(mdsd_run_dir, 01755); } catch (std::exception &e) { Logger::LogError("Fatal error: unexpected exception at creating dir '" + mdsd_run_dir + "'. " + "Reason: " + e.what()); exit(1); } } try { MdsdUtil::CreateDirIfNotExists(ehSaveDir, 01755); } catch(std::exception & e) { Logger::LogError("Fatal error: unexpected exception at creating dir '" + ehSaveDir + "'. " + "Reason: " + e.what()); exit(1); } if (!disableLogging) { Logger::SetInfoLog(mdsdInfoFile.c_str()); Logger::SetWarnLog(mdsdWarnFile.c_str()); Logger::SetErrorLog(mdsdErrFile.c_str()); } if (0 == geteuid() && runAsDaemon) { // Change ownership of logs if we're running as root DaemonConf::Chown(mdsdInfoFile); DaemonConf::Chown(mdsdWarnFile); DaemonConf::Chown(mdsdErrFile); if (mdsd_role_prefix.substr(0, mdsd_run_dir.length()) == mdsd_run_dir) { DaemonConf::Chown(mdsd_run_dir); } DaemonConf::Chown(ehSaveDir); } } try { XJsonBlobBlockCountsMgr::GetInstance().SetPersistDir(mdsd_role_prefix + "_jsonblob_blkcts", mdsdConfigValidationOnly); } catch (std::exception& e) { Logger::LogError(std::string("Unexpected exception from setting JsonBlobBlockCountsMgr persist dir. Reason: ").append(e.what())); exit(1); } if (runAsDaemon) { DaemonConf::RunAsDaemon(mdsd_role_prefix + ".pid"); } SetSignalCatchers(coreDumpAtFatal); std::set_terminate(TerminateHandler); if (mdsdConfigValidationOnly) { std::unique_ptr newconfig(new MdsdConfig(config_file_path, autokey_config_path)); int status = 0; if (newconfig->GotMessages(MdsdConfig::anySeverity)) { cerr << "Parse reported these messages:" << endl; newconfig->MessagesToStream(cerr, MdsdConfig::anySeverity); status = 1; } else { cerr << "Parse succeeded with no messages." 
<< endl; } newconfig.reset(); exit(status); } if (!mdsd::EventHubUploaderMgr::GetInstance().SetTopLevelPersistDir(ehSaveDir)) { exit(1); } ProtocolListenerMgr::Init(mdsd_role_prefix, mdsd_port, retryRandomPort); MdsdConfig* newconfig = new MdsdConfig(config_file_path, autokey_config_path); auto valid = newconfig->ValidateConfig(true); if (!valid || !newconfig->IsUseful()) { Logger::LogError("Error: Config invalid or not useful (if there's no config parse error). Abort mdsd."); delete newconfig; exit(1); } Engine::SetConfiguration(newconfig); try { MdsdUtil::SetStorageHttpProxy(proxy_setting_string, { "MDSD_http_proxy", "https_proxy", "http_proxy" }); } catch(const std::exception & ex) { Logger::LogError(ex.what()); exit(1); } ExtensionMgmt::StartExtensionsAsync(Engine::GetEngine()->GetConfig()); // Start the listeners auto plmgmt = ProtocolListenerMgr::GetProtocolListenerMgr(); try { if (!plmgmt->Start()) { Logger::LogError("One or more listeners failed to start."); exit(1); } } catch(std::exception& ex) { Logger::LogError("Error: unexpected exception while starting listeners: " + std::string(ex.what())); exit(1); } catch(...) { Logger::LogError("Error: unknown exception while starting listeners."); exit(1); } // Wait to be stopped plmgmt->Wait(); return 0; } void usage() { cerr << "Usage:" << endl << "mdsd [-Abdjv] [-c path] [-e path] [-o path] [-p port] [-P proxy_setting] [-r path] [-S path] [-T flags] [-w path]" << endl << endl << "-A Don't enable config auto management." << endl << "-b Don't forward events to MDS (blackhole them instead)" << endl << "-c Specifies the path to the configuration XML file" << endl << "-C Don't suppress core dump when dying due to fatal signals" << endl << "-D Disable logging to files. All log output will instead go to STDERR (fd 2)." 
<< endl << "-d Run mdsd as a daemon" << endl << "-e Specifies the path to which mdsd error logs are dumped" << endl << "-j Dump all JSON events to stdout as they're received" << endl << "-o Specifies the path to which mdsd informative logs are dumped" << endl << "-p Specifies the port on which the daemon listens for stream connections (0 can be passed" << endl << " as port, in which case a randomly available port will be picked). The port will only be" << endl << " bound to 127.0.0.1 (loopback). If the specified non-zero port is in use," << endl << " and '-R' is specified, then mdsd will try to bind to a randomly available port instead." << endl << " Either way, the bound port number will be written to a file whose path is derived" << endl << " from -r info or default (/var/run/mdsd/default.pidport)." << endl << "-P Specifies an HTTP proxy. If not set, use environment variable in order of MDSD_http_proxy," << endl << " https_proxy, http_proxy, with first one tried first. If -P is set, override environment variables." << endl << "-R Try binding to a random port if binding to the default/specified port fails." << endl << "-r Specifies the role name or file prefix that mdsd will use to construct the paths to the" << endl << " pidport and unix domain socket files. If the argument starts with '/' then the value is" << endl << " used as the file prefix, otherwise it is used as the role name and the file prefix is " << endl << " '/var/run/mdsd/' + role name (e.g. if role name is 'test' then the prefix is '/var/run/mdsd/test')." << endl << "-S Specifies directory to save Event Hub events. syslog user needs to have rwx" << endl << " access to it. If the directory does not exist, mdsd will try to create it." 
<< endl << "-T Enable tracing for modules selected by flags" << endl << "-v Validate configuration file and exit" << endl << "-V Print version and exit" << endl << "-w Specifies the path to which mdsd warning logs are dumped" << endl; exit(1); } extern "C" void LoadNewConfiguration() { Trace trace(Trace::ConfigLoad, "LoadNewConfiguration"); Logger::LogInfo("Reloading configuration (SIGHUP caught)"); MdsdConfig *newconfig = new MdsdConfig(config_file_path, autokey_config_path); bool valid = newconfig->ValidateConfig(true); if (!valid || !newconfig->IsUseful()) { delete newconfig; } else { Engine::SetConfiguration(newconfig); ExtensionMgmt::StartExtensionsAsync(newconfig); } } extern "C" void SetCoreDumpLimit() { Logger::LogInfo("Set resource limits for core dump."); struct rlimit core_limit; if (getrlimit(RLIMIT_CORE, &core_limit) < 0) { std::string errstr = MdsdUtil::GetErrnoStr(errno); Logger::LogError("Error: getrlimit failed. Reason: " + errstr); return; } if (RLIM_INFINITY != core_limit.rlim_cur) { core_limit.rlim_cur = RLIM_INFINITY; core_limit.rlim_max = core_limit.rlim_cur; if (setrlimit(RLIMIT_CORE, &core_limit) < 0) { std::string errstr = MdsdUtil::GetErrnoStr(errno); Logger::LogError("Error: setrlimit failed. Reason: " + errstr); } } } // vim: set tabstop=4 softtabstop=4 shiftwidth=4 expandtab : ================================================ FILE: Diagnostic/mdsd/mdsd/wrap_memcpy.c ================================================ /* Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT license. 
*/ #include /* some systems do not have newest memcpy@@GLIBC_2.14 - stay with old good one */ asm (".symver memcpy, memcpy@GLIBC_2.2.5"); void *__wrap_memcpy(void *dest, const void *src, size_t n) { return memcpy(dest, src, n); } ================================================ FILE: Diagnostic/mdsd/mdsd.8 ================================================ .\"Created with GNOME Manpages Editor Wizard .\"http://sourceforge.net/projects/gmanedit2 .TH mdsd 8 "August 28, 2017" "" "Azure MDS Daemon" .SH NAME mdsd \- azure MDS daemon .SH SYNOPSIS .B mdsd .RI "[-AbCDdjvV] [-c path] [-e path] [-o path] [-p port] [-P proxy_setting] [-r role_name/path] [-T flags] [-w path]" .br .SH DESCRIPTION .PP \fBmdsd\fP is the mandated logging infrastructure for Azure services. It delivers event data (logs, collected metrics, etc) to Azure storage for consumption by various downstream users. .SH OPTIONS .TP .BI "\-b" Don't forward events to storage (blackhole them instead) .TP .BI "\-c " "config file" Specifies the path to the configuration XML file (default /etc/mdsd.d/mdsd.xml) .TP .BI "\-C" Don't suppress core dump when dying due to fatal signals .TP .BI "\-D" Disable logging to files. All log output will instead go to STDERR (fd 2). .TP .BI "\-d" Run mdsd as a daemon .TP .BI "\-e " "log path" Specifies the path to which mdsd error logs are dumped .TP .BI "\-j" Dump all JSON events to stdout as they're received .TP .BI "\-o " "log path" Specifies the path to which received object strings are dumped .TP .BI "\-P " "proxy_setting" Specifies the http proxy which the daemon should use for all outbound http/https connections. An example proxy_setting is something like "http://username:password@proxy_host_name:proxy_port_number". The same can be specified using one of the "MDSD_http_proxy" or "https_proxy" or "http_proxy" environment variables (searched in that order and the first hit is used), and this option (using -P) will override the environment variable (when -P is specified). 
DO NOT pass a password on the command line. If a password needs to be given,
specify it as one of the environment variables mentioned earlier.
.TP
.BI "\-p " "port"
Specifies the port on which the daemon listens for stream connections (0 can
be passed as port, in which case a randomly available port will be picked).
The port will only be bound to 127.0.0.1 (loopback). If the specified non-zero
port is in use, and '-R' is specified, then mdsd will try to bind to a
randomly available port instead. Either way, the bound port number will be
written to a file whose path is derived from -r info or default
(/var/run/mdsd/default.pidport).
.TP
.BI "\-R "
Try binding to a random port if binding to the default/specified port fails.
.TP
.BI "\-r " "role_name/path"
Specifies the role name or file prefix that mdsd will use to construct the
paths to the pidport and unix domain socket files. If the argument starts
with '/', the value is used as the file prefix; otherwise, the value is used
as the role name and the file prefix is '/var/run/mdsd/' + role name. For
example, if role name is 'test', then the prefix is '/var/run/mdsd/test'.
The pidport file path is 'prefix' + '.pidport'. The unix domain socket files
paths are 'prefix' + '_' + 'protocol' + '.socket', where the protocol is
'bond', 'djson', and 'json'. The default paths are:
/var/run/mdsd/default.pidport
/var/run/mdsd/default_bond.socket
/var/run/mdsd/default_djson.socket
/var/run/mdsd/default_json.socket
.TP
.BI "\-S " "directory"
Specifies directory to save Event Hub events. syslog user needs to have rwx
access to it. If the directory does not exist, mdsd will try to create it.
.TP
.BI "\-T"
Enable tracing for modules selected by flags.
.TP
.BI "\-v"
Validate configuration file and exit
.TP
.BI "\-V"
Print version and exit
.TP
.BI "\-w " "log path"
Specifies the path to which mdsd warning logs are dumped
.SH ENVIRONMENT
.TP
.BI "MDSD_CONFIG_DIR"
If set, overrides the default value of "/etc/mdsd.d".
.TP .BI "MDSD_RUN_DIR" If set, overrides the default value of "/var/run/mdsd" .TP .BI "MDSD_LOG_DIR" If set, overrides the default value of "/var/log" .SH "SEE ALSO" .BR logger (1), .BR syslog (2), .BR syslog (3) ================================================ FILE: Diagnostic/mdsd/mdsdcfg/CMakeLists.txt ================================================ include_directories( ${CMAKE_SOURCE_DIR}/mdsd ${CMAKE_SOURCE_DIR}/mdsdlog ) set(SOURCES EventPubCfg.cc MdsdEventCfg.cc ) # static lib only add_library(${MDSDCFG_LIB_NAME} STATIC ${SOURCES}) install(TARGETS ${MDSDCFG_LIB_NAME} ARCHIVE DESTINATION ${CMAKE_BINARY_DIR}/release/lib ) ================================================ FILE: Diagnostic/mdsd/mdsdcfg/EventPubCfg.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "EventPubCfg.hh" #include "MdsdEventCfg.hh" #include "Trace.hh" using namespace mdsd; EventPubCfg::EventPubCfg( const std::shared_ptr& mdsdEventCfg ) : m_mdsdEventCfg(mdsdEventCfg), m_dataChecked(false) { if (!mdsdEventCfg) { throw std::invalid_argument("EventPubCfg ctor: invalid NULL pointer for mdsdEventCfg param."); } } void EventPubCfg::AddServiceBusAccount( const std::string & moniker, std::string connStr ) { if (moniker.empty()) { throw std::invalid_argument("AddServiceBusAccount(): moniker param cannot be empty."); } if (connStr.empty()) { throw std::invalid_argument("AddServiceBusAccount(): connStr param cannot be empty."); } // throw if key already exists if (m_sbAccountMap.find(moniker) != m_sbAccountMap.end()) { throw std::runtime_error("AddServiceBusAccount(): key " + moniker + " already exists."); } m_sbAccountMap[moniker] = std::move(connStr); m_dataChecked = false; } void EventPubCfg::AddAnnotationKey( const std::string & publisherName, std::string saskey ) { if (publisherName.empty()) { throw std::invalid_argument("AddAnnotationKey(): publisherName param cannot be empty."); } if 
(saskey.empty()) { throw std::invalid_argument("AddAnnotationKey(): saskey param cannot be empty."); } // throw if key already exists if (m_annotationKeyMap.find(publisherName) != m_annotationKeyMap.end()) { throw std::runtime_error("AddAnnotationKey(): key " + publisherName + " already exists."); } m_annotationKeyMap[publisherName] = std::move(saskey); m_dataChecked = false; } std::unordered_set EventPubCfg::CheckForInconsistencies( bool hasAutoKey ) { Trace trace(Trace::ConfigLoad, "EventPubCfg::CheckForInconsistencies"); if (m_dataChecked) { TRACEINFO(trace, "EventPubCfg was already checked for inconsistencies. Do nothing."); return std::unordered_set(); } // clear any previous data m_nameMonikers.clear(); m_embeddedSasMap.clear(); std::unordered_set invalidItems; for (const auto & publisherName : m_mdsdEventCfg->GetEventPublishers()) { try { ValidateSasKey(publisherName, hasAutoKey); } catch(const std::exception & ex) { invalidItems.insert(publisherName); } } m_dataChecked = true; DumpEmbeddedSasInfo(); return invalidItems; } void EventPubCfg::ValidateSasKey( const std::string & publisherName, bool hasAutoKey ) { if (publisherName.empty()) { throw std::invalid_argument("ValidateSasKey(): publisherName param cannot be empty."); } auto monikers = m_mdsdEventCfg->GetEventPubMonikers(publisherName); if (monikers.empty()) { throw std::runtime_error("ValidateSasKey(): no moniker is found for publisher " + publisherName); } m_nameMonikers[publisherName] = monikers; if (!hasAutoKey) { ValidateEmbeddedKey(publisherName, monikers); } } void EventPubCfg::ValidateEmbeddedKey( const std::string & publisherName, const std::unordered_set& monikers ) { // The SAS Key should be defined in either // or auto annotationItem = m_annotationKeyMap.find(publisherName); if (annotationItem != m_annotationKeyMap.end()) { // search annotation key first auto & saskey = annotationItem->second; for (const auto & moniker: monikers) { m_embeddedSasMap[publisherName][moniker] = saskey; } } else 
{ // search service bus account info for (const auto & moniker: monikers) { auto sbitem = m_sbAccountMap.find(moniker); if (sbitem != m_sbAccountMap.end()) { m_embeddedSasMap[publisherName][moniker] = sbitem->second; } else { throw std::invalid_argument("ValidateEmbeddedKey(): failed to find EH SAS key for " + publisherName); } } } } void EventPubCfg::DumpEmbeddedSasInfo() { Trace trace(Trace::ConfigLoad, "EventPubCfg::DumpEmbeddedSasInfo"); if (!trace.IsActive()) { return; } if (m_embeddedSasMap.empty()) { TRACEINFO(trace, "EventPublisher map is empty"); } else { for (const auto & iter : m_embeddedSasMap) { auto & publisherName = iter.first; auto & itemsmap = iter.second; if (itemsmap.empty()) { TRACEINFO(trace, "EventPublisher='" << publisherName << "'; Moniker/SAS: N/A."); } else { for (const auto& item : itemsmap) { auto & moniker = item.first; auto & saskey = item.second; TRACEINFO(trace, "EventPublisher='" << publisherName << "'; Moniker='" << moniker << "'; SAS: " << saskey.substr(0, saskey.size()/2)); } } } } } std::unordered_map> EventPubCfg::GetEmbeddedSasData() const { if (!m_dataChecked) { throw std::runtime_error("Check EventPubCfg for inconsistencies before GetEmbeddedSasData()."); } return m_embeddedSasMap; } std::unordered_map> EventPubCfg::GetNameMonikers() const { if (!m_dataChecked) { throw std::runtime_error("Check EventPubCfg for inconsistencies before GetNameMonikers()."); } return m_nameMonikers; } ================================================ FILE: Diagnostic/mdsd/mdsdcfg/EventPubCfg.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __EVENTPUBCFG__HH__ #define __EVENTPUBCFG__HH__ #include #include #include #include namespace mdsd { class MdsdEventCfg; /// This class handles event publishing configurations and error detection. 
/// /// The design is based on the fact that /// - Some errors could only be detected after all configuration data had been gathered. /// - One piece of information (whether mdsd xml uses AutoKey or not), not managed by /// this class and MdsdEventCfg, was needed to that final error detection. /// /// Usage pattern: /// - Add raw configuration data (like service bus accounts, annotation keys). /// Read general event data from MdsdEventCfg. /// - Extract event publisher SAS keys, monikers using CheckForInconsistencies(). /// Handle any inconsistencies. /// If new service bus account, annotation key data are added after /// CheckForInconsistencies(), CheckForInconsistencies() needs to be called again. /// - Use SAS keys, moniker info for event publishing (GetEmbeddedSasData(), /// GetNameMonikers(), etc. /// /// NOTE: this class is not designed for thread-safe. /// class EventPubCfg { public: EventPubCfg(const std::shared_ptr& mdsdEventCfg); ~EventPubCfg() = default; /// /// Save event publisher credential info defined in . /// If the moniker already exists, throw exception. /// void AddServiceBusAccount(const std::string & moniker, std::string connStr); /// /// Save each Event Publisher's SAS key defined in /// If the publisherName already exists, throw exception. /// /// event publisher name. It is source name for non-OMI query, /// or eventName for OMIQuery /// SAS Key for event publishing void AddAnnotationKey(const std::string & publisherName, std::string saskey); /// /// Using SBAccounts, AnnotationKeys and data from mdsdEventCfg, /// extract all publisher names, their monikers and sas keys. /// Return all the invalid publisher names if any. /// NOTE: this API applies to either AutoKey or embedded keys. /// /// If true, validate autokey related info; If false, validate /// embedded keys info. std::unordered_set CheckForInconsistencies(bool hasAutoKey); /// /// Return a map containing moniker, saskey info for each publisher name. 
/// The saskeys are from embedded keys only. /// map key: publisher name /// map value: a map of /// /// Throw exception if required CheckForInconsistencies() is not called. /// std::unordered_map> GetEmbeddedSasData() const; /// /// Get all the publisher names and their monikers. /// Each publisher has one or more monikers. /// NOTE: this function works for both embedded keys and AutoKeys. /// Return a map with key=publishername; value: monikers /// /// Throw exception if required CheckForInconsistencies() is not called. /// std::unordered_map> GetNameMonikers() const; private: /// /// Get the SAS key for given event publisher, and store the result to _ehPubMap. /// Throw exception if no SAS key or no moniker is found for the event publisher. /// void ValidateSasKey(const std::string & publisherName, bool hasAutoKey); /// /// Validate embedded keys. /// Throw exception if no key is found for given publisher name. /// void ValidateEmbeddedKey(const std::string & publisherName, const std::unordered_set& monikers); /// Dump all embedded sas configuration data for tracing purpose. void DumpEmbeddedSasInfo(); private: std::shared_ptr m_mdsdEventCfg; /// Whether data are checked or not. /// CheckForInconsistencies() must be called before any lookup methods are called. bool m_dataChecked; /// To store Event Publisher connection string defined in /// in mdsd xml. /// map key: moniker; value: event publisher connection string. std::unordered_map m_sbAccountMap; /// To store Event Publisher SAS key defined in in mdsd xml. /// NOTE: for each event publisher defined in EventStreamingAnnotations, the SAS key /// must be defined: /// - For non-Geneva, either ServiceBusAccountInfos or EventStreamingAnnotations. /// - For Geneva, AutoKey only. /// /// map key: publisher name; value: event publisher SAS key. /// publisher name: source name for non-OMIQuery, or eventName for OMIQuery. std::unordered_map m_annotationKeyMap; /// This stores moniker, saskey for each publisher name. 
/// These information are calculated based on raw xml embedded configurations. /// map key=publisher name; map value=a map of std::unordered_map> m_embeddedSasMap; /// This stores monikers for each publisher name. /// Each publisher has one or more monikers. /// These information are calculated based on raw xml configurations. /// map key = publisher name; map value=monikers std::unordered_map> m_nameMonikers; }; } // namespace #endif // __EVENTPUBCFG__HH__ ================================================ FILE: Diagnostic/mdsd/mdsdcfg/EventSinkCfgInfo.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __EVENTSINKCFGINFO__HH__ #define __EVENTSINKCFGINFO__HH__ #include #include "StoreType.hh" #include "EventType.hh" namespace mdsd { // This class is about mdsd event sink/destination configuration info. // It records what's defined in mdsd xml. struct EventSinkCfgInfo { std::string m_eventName; std::string m_moniker; StoreType::Type m_storeType = StoreType::None; std::string m_sourceName; EventType m_eventType; EventSinkCfgInfo(const std::string & eventName, const std::string & moniker, StoreType::Type storeType, const std::string & sourceName, EventType eventType ) : m_eventName(eventName), m_moniker(moniker), m_storeType(storeType), m_sourceName(sourceName), m_eventType(eventType) {} /// Return true if this is a valid entry. Return false otherwise. /// NOTE: sourceName can be empty (e.g. OMIQuery). 
bool IsValid() const { if (m_moniker.empty() || StoreType::None == m_storeType || (EventType::None == m_eventType && !m_eventName.empty()) || (EventType::None != m_eventType && m_eventName.empty()) ) { return false; } return true; } bool operator==(const EventSinkCfgInfo& other) const { return ((m_eventName == other.m_eventName) && (m_moniker == other.m_moniker) && (m_storeType == other.m_storeType) && (m_sourceName == other.m_sourceName) && (m_eventType == other.m_eventType) ); } bool operator!=(const EventSinkCfgInfo& other) const { return !(*this == other); } // Return the name of the local sink that holds the CanonicalEntities // that are supposed to be pushed to EventHub. // For OMIQuery and DerivedEvent events, this is their event name. // For other events, this is their source name. std::string GetLocalSinkName() const { if (EventType::OMIQuery != m_eventType && EventType::DerivedEvent != m_eventType) { return m_sourceName; } else { return m_eventName; } } }; } // namespace #endif // __EVENTSINKCFGINFO__HH__ ================================================ FILE: Diagnostic/mdsd/mdsdcfg/EventType.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __EVENTTYPE__HH__ #define __EVENTTYPE__HH__ namespace mdsd { // This defines event type specified in mdsd configuration file. enum class EventType { None, OMIQuery, // event defined by RouteEvent, // event defined by DerivedEvent, // event defined by EtwEvent // event defined by }; } // namespace #endif // __EVENTTYPE__HH__ ================================================ FILE: Diagnostic/mdsd/mdsdcfg/MdsdEventCfg.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include #include "MdsdEventCfg.hh" #include "Trace.hh" using namespace mdsd; void MdsdEventCfg::AddEventSinkCfgInfoItem( const EventSinkCfgInfo & item ) { if (!item.IsValid()) { throw std::invalid_argument("MdsdEventCfg::AddEventSinkCfgInfoItem(): item param must be valid."); } m_eventSinkCfgInfoList.push_back(item); m_dataUpdated = true; } void MdsdEventCfg::SetEventAnnotationTypes( std::unordered_map&& eventtypes ) { m_eventAnnotationTypes = std::move(eventtypes); for (const auto & item : m_eventAnnotationTypes) { if (item.second & EventAnnotationType::EventPublisher) { m_eventPublishers.insert(item.first); } } m_dataUpdated = true; } void MdsdEventCfg::UpdateMoniker( const std::string & eventName, const std::string & oldMoniker, const std::string & newMoniker ) { Trace trace(Trace::ConfigUse, "MdsdEventCfg::UpdateMoniker"); if (eventName.empty()) { throw std::invalid_argument("MdsdEventCfg::UpdateMoniker(): eventName param cannot be empty."); } if (oldMoniker.empty()) { throw std::invalid_argument("MdsdEventCfg::UpdateMoniker(): oldMoniker param cannot be empty."); } if (newMoniker.empty()) { throw std::invalid_argument("MdsdEventCfg::UpdateMoniker(): newMoniker param cannot be empty."); } for (auto & item : m_eventSinkCfgInfoList) { if (eventName == item.m_eventName && oldMoniker == item.m_moniker) { item.m_moniker = newMoniker; m_dataUpdated = true; } } } std::unordered_set MdsdEventCfg::GetInvalidAnnotations() { ExtractEventCfg(); std::unordered_set result; for (const auto & item : m_eventAnnotationTypes) { auto & name = item.first; auto & anntype = item.second; if (EventAnnotationType::EventPublisher == anntype) { if (!m_ehpubMonikers.count(name)) { result.insert(name); } } else { if (!m_eventNames.count(name)) { result.insert(name); } } } return result; } void MdsdEventCfg::ExtractEventCfg() { if (!m_dataUpdated) { return; } Trace trace(Trace::ConfigUse, "MdsdEventCfg::ExtractEventCfg"); // Clean any previous data if any m_eventNames.clear(); 
m_ehpubMonikers.clear(); m_ehMonikers.clear(); auto publishers = GetEventPublishers(); for (const auto & item : m_eventSinkCfgInfoList) { auto & eventname = item.m_eventName; auto & moniker = item.m_moniker; auto & storetype = item.m_storeType; m_eventNames.insert(eventname); auto localSinkName = item.GetLocalSinkName(); m_ehpubMonikers[localSinkName].insert(moniker); if (storetype == StoreType::Bond) { m_ehMonikers.insert(moniker); } else if (storetype == StoreType::Local) { if (publishers.count(localSinkName)) { m_ehMonikers.insert(moniker); } } } m_dataUpdated = false; } std::unordered_map> MdsdEventCfg::GetCentralBondEvents() const { std::unordered_map> cbEvents; for (const auto & item : m_eventSinkCfgInfoList) { if (StoreType::Bond == item.m_storeType) { cbEvents[item.m_eventName].insert(item.m_moniker); } } return cbEvents; } std::unordered_set MdsdEventCfg::GetEventPubMonikers( const std::string & publisherName ) { if (publisherName.empty()) { throw std::invalid_argument("GetEventPubMonikers(): publisherName param cannot be empty."); } ExtractEventCfg(); auto item = m_ehpubMonikers.find(publisherName); if (item != m_ehpubMonikers.end()) { return item->second; } return std::unordered_set(); } ================================================ FILE: Diagnostic/mdsd/mdsdcfg/MdsdEventCfg.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __MDSDEVENTCFG__HH__ #define __MDSDEVENTCFG__HH__ #include #include #include #include #include "EventSinkCfgInfo.hh" #include "CfgEventAnnotationType.hh" namespace mdsd { enum class EventType; /// This class handles general mdsd event configurations. /// Usage pattern: /// - Collect raw event info data: AddEventSinkCfgInfoItem(), SetEventAnnotationTypes(), /// UpdateMoniker(), etc. /// - Use aggregated results: GetCentralBondEvents(), IsEventHubEnabled(), etc. 
/// The event configuration data are lazily extracted and aggregated at "Get" time. /// /// NOTE: This class is not designed for thread-safe. /// class MdsdEventCfg { public: MdsdEventCfg() = default; ~MdsdEventCfg() = default; /// /// Add an eventSinkCfgInfo object to internal data structure if it is valid. /// Throw exception if eventSinkCfgInfo is invalid. /// void AddEventSinkCfgInfoItem(const EventSinkCfgInfo & item); /// /// Set event annotation types object. /// void SetEventAnnotationTypes(std::unordered_map&& eventtypes); /// /// For all m_eventSinkCfgInfoList entries where eventName='eventName' /// and moniker='oldMoniker', update moniker to 'newMoniker'. /// Throw exception if any input parameter string is empty. /// void UpdateMoniker(const std::string & eventName, const std::string & oldMoniker, const std::string & newMoniker); /// /// Get a map of for all CentralBond store type events. /// std::unordered_map> GetCentralBondEvents() const; /// /// Return the names of all event publishers in mdsd xml . /// This includes anything that could be invalid if any. /// std::unordered_set GetEventPublishers() const { return m_eventPublishers; } /// /// Return all the monikers used by given publisherName, which can be either a source name, /// or an EventName (e.g. OMIQuery or DerivedEvent). /// Return empty set if publisherName is not found /// std::unordered_set GetEventPubMonikers(const std::string & publisherName); /// /// Get invalid names in mdsd xml /// std::unordered_set GetInvalidAnnotations(); /// /// Returns boolean specifying whether provided moniker (input parameter) /// has a companion Event Hub. /// bool IsEventHubEnabled(const std::string & moniker) { ExtractEventCfg(); return m_ehMonikers.count(moniker); } size_t GetNumEventSinkCfgInfoItems() const { return m_eventSinkCfgInfoList.size(); } private: /// /// Extract event configuration data and store them to internal data structures. /// - a set to store all the event names. 
/// - publisher name -> monikers map for all events. /// - All monikers that are used by EventHub notice or Event publishing. /// void ExtractEventCfg(); private: /// Whether any config data are updated bool m_dataUpdated = false; /// Store information about all the events in mdsd xml file. std::vector m_eventSinkCfgInfoList; /// Store all the eventNames std::unordered_set m_eventNames; /// This map tracks all the EventHub publication monikers to which each new /// CanonicalEvent, when added to the LocalSink, should be published. /// /// map key: LocalSink name /// map value: all the monikers used by the LocalSink std::unordered_map> m_ehpubMonikers; /// key: item name; value: EventAnnotationType std::unordered_map m_eventAnnotationTypes; /// Store all the event publisher names. std::unordered_set m_eventPublishers; /// Store the moniker names when EventHub is enabled on the moniker: /// A companion Event Hub exists if /// - a moniker has an event of store type 'CentralBond' /// - a moniker has an event of store type 'Local', which is also listed /// under EventStreamingAnnotation as an EventPublisher. std::unordered_set m_ehMonikers; }; } // namespace #endif // __MDSDEVENTCFG__HH__ ================================================ FILE: Diagnostic/mdsd/mdsdinput/CMakeLists.txt ================================================ set(SOURCES mdsd_input_types.cpp mdsd_input_apply.cpp MdsdInputSchemaCache.cpp MdsdInputMessageBuilder.cpp MdsdInputMessageIO.cpp ) add_library(${INPUT_LIB_NAME} STATIC ${SOURCES}) install(TARGETS ${INPUT_LIB_NAME} ARCHIVE DESTINATION ${CMAKE_BINARY_DIR}/release/lib ) ================================================ FILE: Diagnostic/mdsd/mdsdinput/MdsdInputMessageBuilder.cpp ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include "MdsdInputMessageBuilder.h" #include "bond/core/bond.h" #include #include namespace mdsdinput { void MessageBuilder::MessageBegin() { _output.reset(new bond::OutputBuffer(_buffer, BUFFER_SIZE)); _writer.reset(new bond::SimpleBinaryWriter(*(_output.get()))); _schema = std::make_shared(); } std::shared_ptr MessageBuilder::MessageEnd(const std::string& source) { auto id = (_schema_cache->AddSchema(_schema)).first; auto msg = std::make_shared(); msg->schemaId = id; msg->source = source; auto out = _output->GetBuffer(); auto buf = boost::allocate_shared_noinit(std::allocator(), out.length()); std::copy(out.begin(), out.end(), buf.get()); msg->data.assign(buf, out.length()); return msg; } void MessageBuilder::AddBool(const std::string& name, bool value) { FieldDef fd; fd.name = name; fd.fieldType = FT_BOOL; _schema->fields.push_back(fd); _writer->Write(value); } void MessageBuilder::AddInt32(const std::string& name, int32_t value) { FieldDef fd; fd.name = name; fd.fieldType = FT_INT32; _schema->fields.push_back(fd); _writer->Write(value); } void MessageBuilder::AddInt64(const std::string& name, int64_t value) { FieldDef fd; fd.name = name; fd.fieldType = FT_INT64; _schema->fields.push_back(fd); _writer->Write(value); } void MessageBuilder::AddDouble(const std::string& name, double value) { FieldDef fd; fd.name = name; fd.fieldType = FT_DOUBLE; _schema->fields.push_back(fd); _writer->Write(value); } void MessageBuilder::AddTime(const std::string& name, const Time& value, bool isTimestampField) { FieldDef fd; fd.name = name; fd.fieldType = FT_TIME; if (isTimestampField) { _schema->timestampFieldIdx.set(static_cast(_schema->fields.size())); } _schema->fields.push_back(fd); _writer->Write(value.sec); _writer->Write(value.nsec); } void MessageBuilder::AddString(const std::string& name, const std::string& value) { FieldDef fd; fd.name = name; fd.fieldType = FT_STRING; _schema->fields.push_back(fd); _writer->Write(value); } } 
================================================ FILE: Diagnostic/mdsd/mdsdinput/MdsdInputMessageBuilder.h ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #include "mdsd_input_reflection.h" #include #include "MdsdInputSchemaCache.h" namespace mdsdinput { class MessageBuilder { public: static const size_t BUFFER_SIZE = 32 * 1024; MessageBuilder() : _schema_cache(std::make_shared()) , _buffer(boost::make_shared_noinit(BUFFER_SIZE)) {} MessageBuilder(std::shared_ptr& schemaCache) : _schema_cache(schemaCache) , _buffer(boost::make_shared_noinit(BUFFER_SIZE)) {} MessageBuilder(const MessageBuilder&) = delete; MessageBuilder(MessageBuilder&&) = default; MessageBuilder& operator=(const MessageBuilder&) = delete; MessageBuilder& operator=(MessageBuilder&&) = default; std::shared_ptr GetSchemaCache() { return _schema_cache; } // Start a new message. All previous data is discarded. void MessageBegin(); // Return a constructed message. std::shared_ptr MessageEnd(const std::string& source); void AddBool(const std::string& name, bool value); void AddInt32(const std::string& name, int32_t value); void AddInt64(const std::string& name, int64_t value); void AddDouble(const std::string& name, double value); void AddTime(const std::string& name, const Time& value, bool isTimestampField); void AddString(const std::string& name, const std::string& value); protected: std::shared_ptr _schema_cache; std::shared_ptr _schema; boost::shared_ptr _buffer; std::unique_ptr _output; std::unique_ptr > _writer; }; } ================================================ FILE: Diagnostic/mdsd/mdsdinput/MdsdInputMessageDecoder.h ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #include "mdsd_input_reflection.h" #include #include "MdsdInputSchemaCache.h" #include "bond/core/bond.h" #include "bond/stream/input_buffer.h" #include "bond/protocol/simple_binary.h" #include #include namespace mdsdinput { class MessageDecoder { public: MessageDecoder() : _schema_cache(std::make_shared()) {} MessageDecoder(std::shared_ptr& schemaCache) : _schema_cache(schemaCache) {} template ResponseCode Decode(const Message& msg, FieldReceiver& receiver) { std::shared_ptr schema; if (!msg.schema.empty()) { schema = std::make_shared(msg.schema.value()); if (!_schema_cache->AddSchemaWithId(schema, msg.schemaId)) { return ACK_DUPLICATE_SCHEMA_ID; } } else { try { schema = _schema_cache->GetSchema(msg.schemaId); } catch (std::out_of_range ex) { return ACK_UNKNOWN_SCHEMA_ID; } } bond::SimpleBinaryReader reader(msg.data); int32_t idx = 0; for (auto it = schema->fields.begin(); it != schema->fields.end(); ++it, ++idx) { try { switch (it->fieldType) { case FT_INVALID: return ACK_DECODE_ERROR; case FT_BOOL: { bool b; reader.Read(b); receiver.BoolField(it->name, b); break; } case FT_INT32: { int32_t i; reader.Read(i); receiver.Int32Field(it->name, i); break; } case FT_INT64: { int64_t i; reader.Read(i); receiver.Int64Field(it->name, i); break; } case FT_DOUBLE: { double d; reader.Read(d); receiver.DoubleField(it->name, d); break; } case FT_TIME: { Time t; reader.Read(t.sec); reader.Read(t.nsec); receiver.TimeField(it->name, t, (!schema->timestampFieldIdx.empty() && *(schema->timestampFieldIdx) == static_cast(idx))); break; } case FT_STRING: { std::string str; reader.Read(str); receiver.StringField(it->name, str); break; } } } catch (bond::StreamException& ex) { return ACK_DECODE_ERROR; } } return ACK_SUCCESS; } std::shared_ptr GetSchema(uint64_t id) { return _schema_cache->GetSchema(id); } std::string GetSchemaKey(uint64_t id) { return _schema_cache->GetSchemaKey(id); } protected: std::shared_ptr _schema_cache; }; } 
================================================ FILE: Diagnostic/mdsd/mdsdinput/MdsdInputMessageIO.cpp ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "MdsdInputMessageIO.h" #include extern "C" { #include } #include #include #include namespace mdsdinput { void FDIO::Write(const bond::blob& blob) { Write(blob.data(), blob.size()); } void FDIO::Read(bond::blob& blob, uint32_t size) { auto data = boost::allocate_shared_noinit(std::allocator(), size); Read(data.get(), size); blob.assign(data, size); } void FDIO::Read(void *buffer, uint32_t size) { assert(buffer != nullptr); size_t nleft = size; do { errno = 0; ssize_t nr = read(_fd, reinterpret_cast(buffer) + (size - nleft), nleft); if (nr < 0) { if (EINTR != errno) { throw std::system_error(errno, std::system_category()); } } else { nleft -= nr; if (nleft > 0 && nr == 0) { throw eof_exception(); } } } while (nleft > 0); } void FDIO::Write(const void *buffer, uint32_t size) { assert(buffer != nullptr); size_t nleft = size; do { errno = 0; ssize_t nw = write(_fd, reinterpret_cast(buffer)+(size - nleft), nleft); if (nw < 0) { if (EINTR != errno) { throw std::system_error(errno, std::system_category()); } } else if (nw == 0) { throw std::runtime_error("write() returned 0"); } else { nleft -= nw; } } while (nleft > 0); } template void FDIO::Read(bool&); template void FDIO::Read(int32_t&); template void FDIO::Read(int64_t&); template void FDIO::Read(double&); template void FDIO::Write(bool); template void FDIO::Write(int32_t); template void FDIO::Write(int64_t); template void FDIO::Write(double); template class MessageIO; } ================================================ FILE: Diagnostic/mdsd/mdsdinput/MdsdInputMessageIO.h ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#pragma once #include "mdsd_input_reflection.h" #include "bond/core/bond.h" #include "bond/stream/input_buffer.h" #include "bond/stream/output_buffer.h" #include "bond/protocol/simple_binary.h" namespace mdsdinput { constexpr uint32_t MAX_MESSAGE_SIZE = 64 * 1024; class eof_exception : public std::runtime_error { public: eof_exception() : std::runtime_error("Connection Closed") {} }; class msg_too_large_error : public std::runtime_error { public: explicit msg_too_large_error(const std::string& msg) : std::runtime_error(msg) {} }; class FDIO { public: explicit FDIO(int fd) : _fd(fd) {} // Read overload(s) for arithmetic types template typename boost::enable_if >::type Read(T& value) { Read(reinterpret_cast(&value), sizeof(value)); } // Write overload(s) for arithmetic types template typename boost::enable_if >::type Write(T value) { Write(reinterpret_cast(&value), sizeof(value)); } // Read into a memory blob void Read(bond::blob& blob, uint32_t size); // Write a memory blob void Write(const bond::blob& blob); // Read into a memory buffer void Read(void *buffer, uint32_t size); // Write a memory buffer void Write(const void *buffer, uint32_t size); protected: int _fd; }; extern template void FDIO::Read(bool&); extern template void FDIO::Read(int32_t&); extern template void FDIO::Read(int64_t&); extern template void FDIO::Read(double&); extern template void FDIO::Write(bool); extern template void FDIO::Write(int32_t); extern template void FDIO::Write(int64_t); extern template void FDIO::Write(double); template class MessageIO { public: MessageIO(IO& io) : _io(io) {} void ReadMessage(Message& msg) { uint32_t size = 0; _io.Read(size); if (size > MAX_MESSAGE_SIZE) { throw msg_too_large_error(""); } bond::blob data; _io.Read(data, size); bond::SimpleBinaryReader input(data); bond::Deserialize(input, msg); } void WriteMessage(const Message& msg) { bond::OutputBuffer obuf; bond::SimpleBinaryWriter output(obuf); bond::Serialize(msg, output); bond::blob data = 
obuf.GetBuffer(); uint32_t size = data.size(); _io.Write(size); _io.Write(data); } void ReadAck(Ack& ack) { _io.Read(ack.msgId); uint32_t code = 0; _io.Read(code); ack.code = static_cast(code); } void WriteAck(const Ack& ack) { _io.Write(ack.msgId); _io.Write(static_cast(ack.code)); } protected: IO _io; }; extern template class MessageIO; } ================================================ FILE: Diagnostic/mdsd/mdsdinput/MdsdInputSchemaCache.cpp ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "MdsdInputSchemaCache.h" #include "bond/core/apply.h" #include "bond/core/runtime_schema.h" #include "bond/core/schema.h" #include "stdio.h" namespace mdsdinput { std::pair SchemaCache::AddSchema(const std::shared_ptr& schema) { auto key = schemaKey(schema); std::lock_guard lock(_lock); auto sk = _schema_ids.find(key); if (sk != _schema_ids.end()) { return std::make_pair((*sk).second, false); } uint64_t id = _next_id++; _schemas.insert(std::make_pair(id, schema)); _schema_ids.insert(std::make_pair(key, id)); _schema_keys.insert(std::make_pair(id, key)); return std::make_pair(id, true); } bool SchemaCache::AddSchemaWithId(const std::shared_ptr& schema, uint64_t id) { auto key = schemaKey(schema); std::lock_guard lock(_lock); auto it = _schema_keys.find(id); if (it != _schema_keys.end()) { if (it->second == key) { return true; } else { return false; } } _schemas.insert(std::make_pair(id, schema)); _schema_ids.insert(std::make_pair(key, id)); _schema_keys.insert(std::make_pair(id, key)); return true; } std::shared_ptr SchemaCache::GetSchema(uint64_t id) { std::lock_guard lock(_lock); return _schemas.at(id); } std::string SchemaCache::GetSchemaKey(uint64_t id) { std::lock_guard lock(_lock); return _schema_keys.at(id); } std::string SchemaCache::schemaKey(const std::shared_ptr& schema) { std::string key; if (!schema->timestampFieldIdx.empty()) { 
key.append(std::to_string(*(schema->timestampFieldIdx)));
        }
        for (const auto & it : schema->fields) {
            key.append(ToString(it.fieldType));
            key.append(it.name);
        }
        return key;
    }

    // Builds a bond runtime schema from an mdsdinput schema.
    // NOTE(review): this function is truncated in the extracted text -- the
    // switch below ends mid-statement and the text then jumps into a
    // different file (Utility.hh). Do not assume the visible code is complete.
    static boost::shared_ptr makeBondSchema(const std::shared_ptr& s)
    {
        auto bs = boost::make_shared ();
        bond::StructDef st;
        bool _time_added = false;
        uint16_t _time_id = 1; // Time is always the second struct and thus has ID 1.
        uint16_t id = 0;
        for (const auto & it : s->fields) {
            bond::FieldDef f;
            f.id = id;
            id++;
            f.metadata.name = it.name;
            switch (it.fieldType) {
            case FT_INVALID: throw std::runtime_error("FT_INVALID encountered!");
            case FT_BOOL: f.type.id = bond::BT_BOOL; break;
            case FT_INT32: f.type.id = bond::BT_INT32; break;
            case FT_INT64: f.type.id = bond::BT_INT64; break;
            case FT_DOUBLE: f.type.id = bond::BT_DOUBLE; break;
            case FT_TIME: f.type.id = bond::get_type_id
// NOTE(review): extraction discontinuity -- the remainder of makeBondSchema,
// the rest of MdsdInputSchemaCache.cpp, and the beginning of Utility.hh
// (including the RotateRight helper that RotateLeft calls, and this
// template's parameter list) are missing from this chunk.
template T RotateLeft(T n, unsigned int count)
{
    count = count%(sizeof(T)*8);
    if (count == 0) return n;
    return RotateRight(n, sizeof(T)*8 - count);
}

/// Compute 64-bit Murmur hash of a string, with initializer
unsigned long long MurmurHash64(const std::string&, unsigned long);

/// Convert a POSIX errno to a string
std::string GetErrnoStr(int errnum);

inline std::string ToString(bool b) { return b? "true" : "false"; }

// Exception used to signal a non-blocking operation that would block.
class would_block : public std::exception
{
public:
    virtual const char* what() const noexcept { return "EWOULDBLOCK"; }
};

/// Write a buffer, followed by a newline, to a POSIX file descriptor. Throw appropriate exceptions
/// for short writes or any error reported by writev.
void WriteBufferAndNewline(int fd, const char * buf, size_t len);
void WriteBufferAndNewline(int fd, const char * buf);
void WriteBufferAndNewline(int fd, const std::string& buf);

/// Convert a UTF-8 std::string to a std::wstring, encoded in UTF-16, relying on
/// the cpprest library to convert to utf16 in a u16string and copying characters.
std::wstring to_utf16(const std::string& s);

///
/// Create a directory given its path if it doesn't exist.
/// Throw exception if any error.
/// Return true if the directory doesn't exist and is created properly.
/// Return false if the directory is valid and already exists.
/// NOTE: the mode is used only when directory is created in this function.
///
bool CreateDirIfNotExists(const std::string& filepath, mode_t mode);

///
/// Extracts and returns the storage account name from the passed storage endpoint URL.
/// For example, returns "stgacct", given "https://stgacct.blob.core.windows.net/".
/// If no match is found, an empty string is returned.
///
std::string GetStorageAccountNameFromEndpointURL(const std::string& url);

///
/// Get the value of a variable from the process environment. Throw std::runtime_error
/// if the variable is not defined in the environment. This is different from the variable
/// being defined as an empty string; that latter case does not throw an error.
///
std::string GetEnvironmentVariable(const std::string &);

///
/// Get the value of a variable from the process environment. Does not throw an exception
/// if the variable is not defined in the environment; in that case it returns an empty string.
///
std::string GetEnvironmentVariableOrEmpty(const std::string &);

/// Returns the hostname of the running system
std::string GetHostname();

/// Get autokey table's 10-day suffix
std::string GetTenDaySuffix();

///
/// Return true if filepath exists and it is a regular file.
/// If filepath is an empty string, throw exception.
///
bool IsRegFileExists(const std::string & filepath);

///
/// Return true if filepath exists and it is a directory.
/// If filepath is an empty string, throw exception.
///
bool IsDirExists(const std::string & filepath);

///
/// Make sure that the filepath exists, is a dir, and the running process has
/// read/write/execute access to the dir.
/// Throw exception otherwise.
///
void ValidateDirRWXByUser(const std::string & filepath);

///
/// If 'filepath' exists, unlink it.
/// Return true if no error and the file is unlinked.
/// Return false if the file doesn't exist.
/// Throw exception for any error.
///
bool RemoveFileIfExists(const std::string & filepath);

///
/// Rename file from 'oldpath' to 'newpath' if 'oldpath' exists.
/// Return true if no error and the file is successfully renamed.
/// Return false if the file doesn't exist.
/// Throw exception if any error.
///
bool RenameFileIfExists(const std::string & oldpath, const std::string & newpath);

///
/// Copy file 'frompath' to 'topath'. If 'topath' exists, it will be overwritten.
/// It will throw exception for any error.
///
void CopyFile(const std::string & frompath, const std::string & topath);

time_t GetLastModificationTime(const std::string & filename);

///
/// Get the last modified file in a given file list.
/// If the list is empty, throw exception.
/// If there are more than one files that meet this criteria, return the first
/// one in the list.
///
// NOTE(review): the vector's element type was stripped during extraction;
// presumably std::vector<std::string> -- confirm against upstream.
std::string GetMostRecentlyModifiedFile(const std::vector & filelist);

///
/// change a file's last modification time to 'now' at micro-second precision.
///
void TouchFileUs(const std::string & filename);

/// Block or unblock a given signal.
void MaskSignal(bool isBlock, int signum);

/// Get the basename of a filepath
std::string GetFileBasename(const std::string & filepath);

/// Utility class to open a file with exclusive lock, allow
/// writing to it line-by-line, and let the destructor delete the file
class LockedFile
{
    std::string m_filepath;
    int m_fd;
public:
    LockedFile() : m_fd(-1) {}
    LockedFile(const std::string& filepath);
    ~LockedFile();
    LockedFile(const LockedFile&) = delete;
    LockedFile(LockedFile&&) = default;
    LockedFile& operator=(const LockedFile&) = delete;
    LockedFile& operator=(LockedFile&&) = default;

    void Open(const std::string& filepath);
    // Open() presumably records the path on success, making it the
    // "is open" indicator used here.
    bool IsOpen() const { return !m_filepath.empty(); }
    void WriteLine(const std::string& line) const;
    void Remove();
    void TruncateAndClose();

    // Thrown when the file is already locked by another holder.
    class AlreadyLocked : public std::runtime_error
    {
    public:
        AlreadyLocked(const std::string& msg) : std::runtime_error(msg) {}
    };
};

/// Copy maximum of 'maxbytes' from 'src' and return the result string.
/// If src is NULL or maxbytes is 0, return empty string.
/// If maxbytes > src's length, return a duplicate of src.
std::string StringNCopy(const char* src, size_t maxbytes);

/// Return current thread id as a string.
std::string GetTid();

// RAII holder for a POSIX file descriptor; Release() disarms the close.
class FdCloser
{
public:
    explicit FdCloser(int fd) : m_fd(fd) {}
    ~FdCloser();
    void Release();
private:
    int m_fd;
};

// RAII holder for a C stdio FILE*.
class FileCloser
{
public:
    FileCloser(FILE* fp) : m_fp(fp) {}
    ~FileCloser() { if (m_fp) { fclose(m_fp); m_fp = nullptr; } }
private:
    FILE* m_fp;
};

/// Get the resource limit for number of open files for current process.
/// Return 0 if infinity, return the actual number othwerwise.
int32_t GetNumFileResourceSoftLimit();

/// Get syslog severity string from numeric value. E.g., for 5, it's "Notice"
const char* GetSyslogSeverityStringFromValue(int severity);

///
/// Create a UNIX socket using given file path, then bind to it.
/// Throw exception if any error.
/// Return the socket fd.
///
int CreateAndBindUnixSocket(const std::string & sockFilePath);

///
/// Return the named environment variable value or the default_value if the variable isn't present.
/// If the path specified by the value doesn't exist, throw a runtime_error.
///
std::string GetEnvDirVar(const std::string& name, const std::string& default_value);

///
/// Parse an absolute https:// or http:// URL in the format of "http(s)://xxx/yyy"
/// Return "http(s)://xxx" as baseUrl, "/yyy" as params.
/// If URL format is "http(s)://xxx", return "http(s)://xxx" as baseUrl, "" as params.
/// Throw exception for invalid format absUrl.
///
void ParseHttpsOrHttpUrl(const std::string & absUrl, std::string& baseUrl, std::string& params);

}

#endif // _UTILITY_HH_

// vim: se sw=8

================================================
FILE: Diagnostic/mdsd/mdsrest/CMakeLists.txt
================================================
include_directories(
    ${CMAKE_SOURCE_DIR}/mdsdlog
    ${CMAKE_SOURCE_DIR}/mdsdutil
    ${CASABLANCA_INCLUDE_DIRS}
    ${STORAGE_INCLUDE_DIRS}
)

set(SOURCES
    GcsJsonData.cc
    GcsJsonParser.cc
    GcsServiceInfo.cc
    GcsUtil.cc
    MdsRest.cc
)

# static lib only
add_library(${MDSREST_LIB_NAME} STATIC ${SOURCES})

install(TARGETS ${MDSREST_LIB_NAME}
    ARCHIVE DESTINATION ${CMAKE_BINARY_DIR}/release/lib
)

================================================
FILE: Diagnostic/mdsd/mdsrest/GcsJsonData.cc
================================================
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "GcsJsonData.hh"
#include "GcsUtil.hh"
#include "Logger.hh"
#include "MdsConst.hh"

using namespace mdsd;

// Extract a String-typed JSON value as std::string.
// Throws JsonParseException (via ThrowIfInvalidType) if 'jsonObj' is not a String.
static std::string GetStringFromJson(
    const std::string & itemname,
    const web::json::value& jsonObj
    )
{
    GcsUtil::ThrowIfInvalidType(itemname, web::json::value::String, jsonObj.type());
    return jsonObj.as_string();
}

std::ostream& mdsd::operator<<(std::ostream & os, const EventHubKey& obj)
{
    os << " SasKey='" << obj.SasKey << "'; Uri='" << obj.Uri << "'.\n";
    return os;
}

// Map of JSON item name -> parse callback for EventHubKey.
// FIX: template arguments restored; they were stripped by text extraction.
std::unordered_map<std::string, itemparser_t<EventHubKey>> EventHubKey::ParserMap =
{
    { "SasKey", [](const std::string & name, const web::json::value & value, EventHubKey& result) {
        result.SasKey = GetStringFromJson(name, value); } },
    { "Uri", [](const std::string & name, const web::json::value & value, EventHubKey& result) {
        result.Uri = GetStringFromJson(name, value); } }
};

std::ostream& mdsd::operator<<(std::ostream & os, const ServiceBusAccountKey& obj)
{
    os << "AccountGroupName='" << obj.AccountGroupName
       << "'; AccountMonikerName='" << obj.AccountMonikerName << "'.\n";
    for (const auto & item : obj.EventHubKeys) {
        os << "EventHubsKeys: " << item.first << ":" << item.second;
    }
    return os;
}

// Instead of return when an invalid item is found, the validation will do
// as much validation as possible.
bool ServiceBusAccountKey::IsValid() const { bool retVal = true; if (AccountGroupName.empty() || AccountMonikerName.empty() || EventHubKeys.empty()) { Logger::LogError("Error: ServiceBusAccountKey has invalid empty field"); retVal = false; } for (const auto & item : EventHubKeys) { if (!item.second.IsValid()) { Logger::LogError("Error: EventHubKey '" + item.first + "' is invalid"); retVal = false; } } return retVal; } std::unordered_map> ServiceBusAccountKey::ParserMap = { { "AccountGroupName", [](const std::string & name, const web::json::value & value, ServiceBusAccountKey& result) { result.AccountGroupName = GetStringFromJson(name, value); } }, { "AccountMonikerName", [](const std::string & name, const web::json::value & value, ServiceBusAccountKey& result) { result.AccountMonikerName = GetStringFromJson(name, value); } }, { "EventHubKeys", [](const std::string & name, const web::json::value & value, ServiceBusAccountKey& result) { details::EventHubKeysParser ehkeysParser(name, value); ehkeysParser.Parse(result.EventHubKeys); } } }; std::unordered_map> StorageSasKey::ParserMap = { { "ResourceName", [](const std::string & name, const web::json::value & value, StorageSasKey& result) { result.ResourceName = GetStringFromJson(name, value); } }, { "SasKey", [](const std::string & name, const web::json::value & value, StorageSasKey& result) { result.SasKey = GetStringFromJson(name, value); } }, { "SasKeyType", [](const std::string & name, const web::json::value & value, StorageSasKey& result) { result.SasKeyType = GetStringFromJson(name, value); } } }; std::ostream& mdsd::operator<<(std::ostream & os, const StorageSasKey& obj) { os << "ResourceName='" << obj.ResourceName << "'; SasKey='" << obj.SasKey << "'; SasKeyType='" << obj.SasKeyType << "'\n"; return os; } std::ostream& mdsd::operator<<(std::ostream & os, const StorageAccountKey& obj) { os << "StorageAccountName='" << obj.StorageAccountName << "'; " << "AccountGroupName='" << obj.AccountGroupName << "'; " << 
"AccountMonikerName='" << obj.AccountMonikerName << "'; " << "BlobEndpoint='" << obj.BlobEndpoint << "'; " << "QueueEndpoint='" << obj.QueueEndpoint << "'; " << "TableEndpoint='" << obj.TableEndpoint << "'.\n"; for (const auto & item : obj.SasKeys) { os << item; } return os; } // Return true if equal, false if not equal. // Log error if not equal. static inline bool ValidateEqual( int expected, int actual, const std::string & msg ) { if (expected != actual) { std::ostringstream ostr; ostr << "Error: " << msg << ": expected=" << expected << "; actual=" << actual; Logger::LogError(ostr); return false; } return true; } bool StorageAccountKey::IsValid() const { bool retVal = true; if (StorageAccountName.empty() || AccountGroupName.empty() || AccountMonikerName.empty() || BlobEndpoint.empty() || QueueEndpoint.empty() || TableEndpoint.empty() || SasKeys.empty()) { Logger::LogError("Error: StorageAccountKey has invalid empty field"); retVal = false; } // The Blob and Table SAS keys must be defined exactly once int nBlobSas = 0; int nTableSas = 0; const int nexpected = 1; for (const auto & item : SasKeys) { if (!item.IsValid()) { retVal = false; } if ("BlobService" == item.SasKeyType) { nBlobSas++; } else if ("TableService" == item.SasKeyType) { nTableSas++; } } retVal &= ValidateEqual(nexpected, nBlobSas, "# of BlobService SasKeys"); retVal &= ValidateEqual(nexpected, nTableSas, "# of TableService SasKeys"); return retVal; } std::unordered_map> StorageAccountKey::ParserMap = { { "StorageAccountName", [](const std::string & name, const web::json::value & value, StorageAccountKey& result) { result.StorageAccountName = GetStringFromJson(name, value); } }, { "AccountGroupName", [](const std::string & name, const web::json::value & value, StorageAccountKey& result) { result.AccountGroupName = GetStringFromJson(name, value); } } , { "AccountMonikerName", [](const std::string & name, const web::json::value & value, StorageAccountKey& result) { result.AccountMonikerName = 
GetStringFromJson(name, value); } } , { "BlobEndpoint", [](const std::string & name, const web::json::value & value, StorageAccountKey& result) { result.BlobEndpoint = GetStringFromJson(name, value); } } , { "QueueEndpoint", [](const std::string & name, const web::json::value & value, StorageAccountKey& result) { result.QueueEndpoint = GetStringFromJson(name, value); } } , { "TableEndpoint", [](const std::string & name, const web::json::value & value, StorageAccountKey& result) { result.TableEndpoint = GetStringFromJson(name, value); } }, { "SasKeys", [](const std::string & name, const web::json::value & value, StorageAccountKey& result) { details::ObjectArrayParser arrayParser(name, value); arrayParser.Parse(result.SasKeys); } } }; std::ostream& mdsd::operator<<(std::ostream & os, const GcsAccount& obj) { os << "\nMaSigningPublicKeys: " << obj.MaSigningPublicKeys.size() << "\n"; for (const auto & item : obj.MaSigningPublicKeys) { os << item; } os << "SasKeysExpireTimeUtc='" << obj.SasKeysExpireTimeUtc << "';\n"; for (const auto & item: obj.ServiceBusAccountKeys) { os << item; } for (const auto & item: obj.StorageAccountKeys) { os << item; } os << "TagId='" << obj.TagId << "'."; return os; } static bool ValidateMaSigningPublicKeys( const std::vector & MaSigningPublicKeys ) { bool retVal = true; if (MaSigningPublicKeys.empty()) { Logger::LogError("Error: unexpected empty MaSigningPublicKey array"); retVal = false; } for (const auto & item: MaSigningPublicKeys) { if (item.empty()) { Logger::LogError("Error: unexpected invalid MaSigningPublicKey"); retVal = false; } } return retVal; } static bool ValidateServiceBusAccountKeys( const std::vector & ServiceBusAccountKeys ) { bool retVal = true; if (ServiceBusAccountKeys.empty()) { Logger::LogError("Error: unexpected empty ServiceBusAccountKeys array"); retVal = false; } size_t i = 0; for (const auto & item : ServiceBusAccountKeys) { if (!item.IsValid()) { Logger::LogError("Error: ServiceBusAccountKeys[" + 
std::to_string(i) + "] is invalid"); retVal = false; } i++; } // Validate that required EventHub keys exist if (!ServiceBusAccountKeys.empty()) { int nEHNoticeKeys = 0; int nEHPublishKeys = 0; const int nexpected = 1; for (auto & item : ServiceBusAccountKeys) { if (item.EventHubKeys.count(gcs::c_EventHub_notice)) { nEHNoticeKeys++; } if (item.EventHubKeys.count(gcs::c_EventHub_publish)) { nEHPublishKeys++; } } retVal &= ValidateEqual(nexpected, nEHNoticeKeys, "# EventHubKey for '" + gcs::c_EventHub_notice + "'"); retVal &= ValidateEqual(nexpected, nEHPublishKeys, "# EventHubKey for '" + gcs::c_EventHub_publish + "'"); } return retVal; } static bool ValidateStorageAccountKeys( const std::vector & StorageAccountKeys ) { bool retVal = true; if (StorageAccountKeys.empty()) { Logger::LogError("Error: unexpected empty StorageAccountKeys array"); retVal = false; } size_t i = 0; for (const auto & item : StorageAccountKeys) { if (!item.IsValid()) { Logger::LogError("Error: StorageAccountKeys[" + std::to_string(i) + "] is invalid"); retVal = false; } i++; } return retVal; } bool GcsAccount::IsValid() const { bool retVal = true; if (TagId.empty()) { retVal = false; } if (IsEmpty()) { return retVal; } retVal &= ValidateMaSigningPublicKeys(MaSigningPublicKeys); if (SasKeysExpireTimeUtc.empty()) { Logger::LogError("Error: unexpected empty SasKeysExpireTimeUtc"); retVal = false; } retVal &= ValidateServiceBusAccountKeys(ServiceBusAccountKeys); retVal &= ValidateStorageAccountKeys(StorageAccountKeys); return retVal; } bool GcsAccount::IsEmpty() const { return ( MaSigningPublicKeys.empty() && SasKeysExpireTimeUtc.empty() && ServiceBusAccountKeys.empty() && StorageAccountKeys.empty() ); } std::unordered_map> GcsAccount::ParserMap = { { "MaSigningPublicKeys", [](const std::string & name, const web::json::value & value, GcsAccount& result) { if (!value.is_null()) { details::StringArrayParser parser(name, value); parser.Parse(result.MaSigningPublicKeys); } } }, { 
"SasKeysExpireTimeUtc", [](const std::string & name, const web::json::value & value, GcsAccount& result) { if (!value.is_null()) { result.SasKeysExpireTimeUtc = GetStringFromJson(name, value); } } }, { "ServiceBusAccountKeys" , [](const std::string & name, const web::json::value & value, GcsAccount& result) { if (!value.is_null()) { details::ObjectArrayParser parser(name, value); parser.Parse(result.ServiceBusAccountKeys); } } }, { "StorageAccountKeys", [](const std::string & name, const web::json::value & value, GcsAccount& result) { if (!value.is_null()) { details::ObjectArrayParser parser(name, value); parser.Parse(result.StorageAccountKeys); } } }, { "TagId", [](const std::string & name, const web::json::value & value, GcsAccount& result) { result.TagId = GetStringFromJson(name, value); } } }; ================================================ FILE: Diagnostic/mdsd/mdsrest/GcsJsonData.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __GCSJSONDATA_HH__ #define __GCSJSONDATA_HH__ #include #include #include #include #include #include #include "GcsJsonParser.hh" namespace mdsd { template using itemparser_t = std::function; struct EventHubKey { std::string SasKey; std::string Uri; bool IsValid() const { return !SasKey.empty() && !Uri.empty(); } static std::unordered_map> ParserMap; }; std::ostream& operator<<(std::ostream & os, const EventHubKey& obj); struct ServiceBusAccountKey { std::string AccountGroupName; // Geneva moniker name std::string AccountMonikerName; // Mapped moniker name // This map stores all EventHub Keys. // map key: Event Hub name. In GCS, each name is a hard-coded name for // different scenario: // "raw" -> Event Hub notification for CentralBond store type., // "error" -> Top N service. // "distributedtracing" -> Distributed tracing service. // "eventpublisher" -> Event Hub data publisher. 
std::unordered_map EventHubKeys; bool IsValid() const; using parser_type = details::JsonObjectParser; static std::unordered_map> ParserMap; }; std::ostream& operator<<(std::ostream & os, const ServiceBusAccountKey& obj); struct StorageSasKey { std::string ResourceName; std::string SasKey; std::string SasKeyType; bool IsValid() const { return !ResourceName.empty() && !SasKey.empty() && !SasKeyType.empty(); } using parser_type = details::JsonObjectParser; static std::unordered_map> ParserMap; }; std::ostream& operator<<(std::ostream & os, const StorageSasKey& obj); struct StorageAccountKey { std::string StorageAccountName; std::string AccountGroupName; std::string AccountMonikerName; std::string BlobEndpoint; std::string QueueEndpoint; std::string TableEndpoint; std::vector SasKeys; bool IsValid() const; using parser_type = details::JsonObjectParser; static std::unordered_map> ParserMap; }; std::ostream& operator<<(std::ostream & os, const StorageAccountKey& obj); // GcsAccount contains GCS account data. Its tagId should never be empty. // Its other values can be of two kinds: // 1) none of the values are empty. // 2) all the values are empty. struct GcsAccount { std::vector MaSigningPublicKeys; std::string SasKeysExpireTimeUtc; std::vector ServiceBusAccountKeys; std::vector StorageAccountKeys; std::string TagId; bool IsValid() const; // Return true if all values (ignoring TagId) are empty; return false otherwise. bool IsEmpty() const; static std::unordered_map> ParserMap; }; std::ostream& operator<<(std::ostream & os, const GcsAccount& obj); } #endif // __GCSJSONDATA_HH__ ================================================ FILE: Diagnostic/mdsd/mdsrest/GcsJsonParser.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include #include "GcsJsonParser.hh" #include "GcsUtil.hh" #include "Logger.hh" #include "Trace.hh" #include "GcsJsonData.hh" using namespace mdsd; using namespace mdsd::details; bool GcsJsonParser::Parse( GcsAccount & gcsAccount ) { Trace trace(Trace::MdsCmd, "GcsJsonParser::Parse"); if (!m_jsonStr.empty()) { try { m_jsonObj = web::json::value::parse(m_jsonStr); } catch(const std::exception & ex) { Logger::LogError("Error: failed to parse JSON string '" + m_jsonStr + "': " + ex.what()); return false; } } if (!m_jsonObj.is_null()) { try { JsonObjectParser rootParser("", m_jsonObj); rootParser.Parse(gcsAccount); if (trace.IsActive()) { std::ostringstream ostr; ostr << gcsAccount; TRACEINFO(trace, ostr.str()); } } catch(const std::exception & ex) { Logger::LogError(std::string("Error: failed to parse JSON object: ") + ex.what()); return false; } } return true; } void GcsJsonBaseParser::CheckType() const { GcsUtil::ThrowIfInvalidType(GetPath(), GetExpectedType(), GetActualType()); } void GcsJsonBaseParser::LogMsgIfUnrecognized( const std::string & itemname ) const { std::ostringstream msg; msg << "Ignore unrecognized item: '" << itemname << "'"; // Because future GCS may add additional JSON key/value pairs, only log unrecognized // name as information only. 
Logger::LogInfo(msg.str()); } void EventHubKeysParser::Parse( std::unordered_map& ehkeymap ) { CheckType(); auto & jsonObj = GetJson().as_object(); for (auto iter = jsonObj.cbegin(); iter != jsonObj.cend(); ++iter) { const auto & name = iter->first; const auto & value = iter->second; if (ehkeymap.find(name) != ehkeymap.end()) { throw JsonParseException("Found duplicate item: " + GetPath() + "/" + name); } EventHubKey ehkey; JsonObjectParser ehkeyParser(GetPath() + "/" + name, value); ehkeyParser.Parse(ehkey); ehkeymap[name] = std::move(ehkey); } } void StringArrayParser::Parse( std::vector& resultList ) { CheckType(); auto & array = GetJson().as_array(); for (size_t i = 0; i < array.size(); i++) { auto jsontype = array.at(i).type(); if (web::json::value::String == jsontype) { resultList.push_back(array.at(i).as_string()); } else { throw JsonParseException("StringArrayParser: unsupported JSON type '" + GcsUtil::GetJsonTypeStr(jsontype) + "'"); } } } ================================================ FILE: Diagnostic/mdsd/mdsrest/GcsJsonParser.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef _GCSJSONPARSER_HH__ #define _GCSJSONPARSER_HH__ #include #include #include #include namespace mdsd { struct EventHubKey; struct ServiceBusAccountKey; struct StorageSasKey; struct StorageAccountKey; struct GcsAccount; class GcsJsonParser { public: GcsJsonParser(const std::string & jsonStr) : m_jsonStr(jsonStr) {} GcsJsonParser(const web::json::value & jsonObj) : m_jsonObj(jsonObj) {} // // Parse JSON string or JSON object and store the results to gcsAccount object. // // To store parsed account information // Return true if parsing succeeds; return false if any error. bool Parse(GcsAccount & gcsAccount); private: std::string m_jsonStr; web::json::value m_jsonObj; }; namespace details { // This is the base class for all JSON parser classes. 
class GcsJsonBaseParser { public: // Constructor. // JSON path string. e.g. "/root/ServiceBusAccountKeys/EventHubKeys". // It is to locate items in JSON parsing. // JSON object to be parsed GcsJsonBaseParser( const std::string & path, const web::json::value & jsonObj ) : m_path(path), m_jsonObj(jsonObj) { } virtual ~GcsJsonBaseParser() = default; protected: // Get actual JSON value type web::json::value::value_type GetActualType() const { return m_jsonObj.type(); } // Get expected JSON value type virtual web::json::value::value_type GetExpectedType() const { return web::json::value::Object; } const web::json::value& GetJson() const { return m_jsonObj; } virtual std::string GetPath() const { return m_path; } // Get path assuming the object is an array type. virtual std::string GetArrayPath(size_t i) const { return GetPath() + "[" + std::to_string(i) + "]"; } bool IsNull() const { return m_jsonObj.is_null(); } // Validate whether the JSON object has expected type. Throw exception if not. void CheckType() const; // Log message if unrecognized JSON name is found in JSON string. void LogMsgIfUnrecognized(const std::string & itemname) const; private: std::string m_path; web::json::value m_jsonObj; }; class EventHubKeysParser : public GcsJsonBaseParser { public: EventHubKeysParser(const std::string & path, const web::json::value & jsonObj) : GcsJsonBaseParser(path, jsonObj) {} void Parse(std::unordered_map& ehkeys); }; // To parse an array of json strings class StringArrayParser : public GcsJsonBaseParser { public: StringArrayParser(const std::string & path, const web::json::value & jsonObj) : GcsJsonBaseParser(path, jsonObj) {} void Parse(std::vector& resultList); protected: web::json::value::value_type GetExpectedType() const override { return web::json::value::Array; } }; // A template to parse an array of json object type. // The object type 'T' must have 'parser_type' defined. 
template class ObjectArrayParser : public GcsJsonBaseParser { public: ObjectArrayParser(const std::string & path, const web::json::value & jsonObj) : GcsJsonBaseParser(path, jsonObj) {} void Parse(std::vector& resultList) { CheckType(); auto & array = GetJson().as_array(); for (size_t i = 0; i < array.size(); i++) { typename T::parser_type parser(GetArrayPath(i), array.at(i)); T item; parser.Parse(item); resultList.push_back(std::move(item)); } } protected: web::json::value::value_type GetExpectedType() const override { return web::json::value::Array; } }; // Parse a JSON object with type T template class JsonObjectParser : public GcsJsonBaseParser { public: JsonObjectParser(const std::string & path, const web::json::value & jsonObj) : GcsJsonBaseParser(path, jsonObj) {} void Parse(T& result) { CheckType(); auto & jsonObj = GetJson().as_object(); for (auto iter = jsonObj.cbegin(); iter != jsonObj.cend(); ++iter) { const auto & name = iter->first; const auto & value = iter->second; auto itempath = GetPath() + "/" + name; auto parserIter = T::ParserMap.find(name); if (parserIter == T::ParserMap.end()) { LogMsgIfUnrecognized(itempath); } else { parserIter->second(itempath, value, result); } } } }; } // namespace details } // namespace mdsd #endif // _GCSJSONPARSER_HH__ ================================================ FILE: Diagnostic/mdsd/mdsrest/GcsServiceInfo.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include "GcsServiceInfo.hh" #include "Logger.hh" #include "MdsConst.hh" using namespace mdsd; GcsServiceInfo GcsConfig::s_gcsInfo; // Read an environment variable and store the value to 'value'. // If given environment variable is invalid, do nothing. 
static void GetEnvVar(const std::string & name, std::string& value) { if (name.empty()) { return; } char* v = std::getenv(name.c_str()); if (!v) { Logger::LogInfo("Environment variable '" + name + "' is not defined."); } else { value = v; } } void GcsConfig::ReadFromEnvVars() { GetEnvVar(gcs::c_GcsEnv_EndPoint, s_gcsInfo.EndPoint); GetEnvVar(gcs::c_GcsEnv_Environment, s_gcsInfo.Environment); GetEnvVar(gcs::c_GcsEnv_Account, s_gcsInfo.GenevaAccount); GetEnvVar(gcs::c_GcsEnv_Region, s_gcsInfo.Region); GetEnvVar(gcs::c_GcsEnv_ThumbPrint, s_gcsInfo.ThumbPrint); GetEnvVar(gcs::c_GcsEnv_CertFile, s_gcsInfo.CertFile); GetEnvVar(gcs::c_GcsEnv_KeyFile, s_gcsInfo.KeyFile); GetEnvVar(gcs::c_GcsEnv_SslDigest, s_gcsInfo.SslDigest); } bool GcsConfig::IsSet() { return ( !s_gcsInfo.EndPoint.empty() && !s_gcsInfo.Environment.empty() && !s_gcsInfo.GenevaAccount.empty() && !s_gcsInfo.Region.empty() && !s_gcsInfo.ThumbPrint.empty() && !s_gcsInfo.CertFile.empty() && !s_gcsInfo.KeyFile.empty() && !s_gcsInfo.SslDigest.empty() ); } ================================================ FILE: Diagnostic/mdsd/mdsrest/GcsServiceInfo.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __GCSSERVICEINFO_HH__ #define __GCSSERVICEINFO_HH__ #include namespace mdsd { struct GcsServiceInfo { std::string EndPoint; std::string Environment; std::string GenevaAccount; std::string ConfigNamespace; std::string Region; std::string SpecifiedConfigVersion; std::string ActualConfigVersion; std::string ThumbPrint; std::string CertFile; std::string KeyFile; std::string SslDigest; }; class GcsConfig { static void ReadFromEnvVars(); // Return true if all required environmental variable settings are set // (may not be valid values). Return false otherwise. 
static bool IsSet(); static GcsServiceInfo& GetData() { return s_gcsInfo; } private: static GcsServiceInfo s_gcsInfo; }; } // namespace mdsd #endif // __GCSSERVICEINFO_HH__ ================================================ FILE: Diagnostic/mdsd/mdsrest/GcsUtil.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include #include #include #include "GcsUtil.hh" namespace mdsd { namespace GcsUtil { static std::map& GetJsonTypeMap() { static std::map m = { { web::json::value::Number, "Number" }, { web::json::value::Boolean, "Boolean" }, { web::json::value::String, "String" }, { web::json::value::Object, "Object" }, { web::json::value::Array, "Array" }, { web::json::value::Null, "Null" } }; return m; } std::string GetJsonTypeStr(web::json::value::value_type t) { auto & m = GetJsonTypeMap(); auto item = m.find(t); if (item != m.end()) { return item->second; } return "Unknown"; } void ThrowIfInvalidType( const std::string & itemName, web::json::value::value_type expectedType, web::json::value::value_type actualType ) { if (expectedType != actualType) { std::ostringstream ostr; ostr << "Json item '" << itemName << "' has invalid type:" << " expected=" << GetJsonTypeStr(expectedType) << " actual=" << GetJsonTypeStr(actualType); throw JsonParseException(ostr.str()); } } // key: Gcs Environment. e.g. "Test" // value: Gcs endpoing. e.g. 
"ppe.warmpath.msftcloudes.com" static std::unordered_map& GetGcsEnvEndPointMap() { static std::unordered_map m = { {"DiagnosticsProd", "prod.warmpath.msftcloudes.com"}, {"FirstPartyProd", "prod.warmpath.msftcloudes.com"}, {"Test", "ppe.warmpath.msftcloudes.com"}, {"Stage", "ppe.warmpath.msftcloudes.com"}, {"BillingProd", "prod.warmpath.msftcloudes.com"}, {"ExternalProd", "prod.warmpath.msftcloudes.com"}, {"CaMooncake", "mooncake.warmpath.chinacloudapi.cn"}, {"CaBlackforest", "blackforest.warmpath.cloudapi.de"}, {"CaFairfax", "fairfax.warmpath.usgovcloudapi.net"} }; return m; } std::string GetGcsEndpointFromEnvironment( const std::string & gcsEnvName ) { auto & m = GetGcsEnvEndPointMap(); auto item = m.find(gcsEnvName); if (item == m.end()) { return std::string(); } return item->second; } } // namespace GcsUtil } // namespace mdsd ================================================ FILE: Diagnostic/mdsd/mdsrest/GcsUtil.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __GCSUTIL_HH__ #define __GCSUTIL_HH__ #include #include #include "MdsRestException.hh" namespace mdsd { namespace GcsUtil { // // Get a string format of a JSON value type. // std::string GetJsonTypeStr(web::json::value::value_type t); // // Throw JsonParseException if actual type is not equal to expected type // for an item with name called itemName. // void ThrowIfInvalidType(const std::string & itemName, web::json::value::value_type expectedType, web::json::value::value_type actualType); // // Get GCS service endpoint given GCS environment value (e.g. "Test") // This function is used when GCS environment is defined but GCS endpoint // is empty. This can avoid customer to remember the exact endpoint. // Customer can still define endpoint if needed. 
// std::string GetGcsEndpointFromEnvironment(const std::string & gcsEnvName); } // namespace GcsUtil } // namespace mdsd #endif // __GCSUTIL_HH__ ================================================ FILE: Diagnostic/mdsd/mdsrest/MdsConst.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __MDSCONST_HH__ #define __MDSCONST_HH__ #include namespace mdsd { namespace gcs { const std::string c_GcsServiceName = "/api/agent/v2/"; const std::string c_GcsMonitoringStorageKeysApiName = "MonitoringStorageKeys"; const int c_HttpTimeInSeconds = 60; const std::string c_RequestIdHeader = "-request-id:"; const std::string c_GcsEnv_EndPoint = "MONITORING_GCS_ENDPOINT"; const std::string c_GcsEnv_Environment = "MONITORING_GCS_ENVIRONMENT"; const std::string c_GcsEnv_Account = "MONITORING_GCS_ACCOUNT"; const std::string c_GcsEnv_Namespace = "MONITORING_GCS_NAMESPACE"; const std::string c_GcsEnv_Region = "MONITORING_GCS_REGION"; const std::string c_GcsEnv_ConfigVersion = "MONITORING_CONFIG_VERSION"; const std::string c_GcsEnv_ThumbPrint = "MONITORING_GCS_THUMBPRINT"; const std::string c_GcsEnv_CertFile = "MONITORING_GCS_CERT_CertFile"; const std::string c_GcsEnv_KeyFile = "MONITORING_GCS_CERT_KeyFile"; const std::string c_GcsEnv_SslDigest = "MONITORING_GCS_CERT_SSLDIGEST"; const std::string c_EventHub_notice = "raw"; const std::string c_EventHub_publish = "eventpublisher"; } } #endif ================================================ FILE: Diagnostic/mdsd/mdsrest/MdsRest.cc ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
#include #include #include #include #include #include #include #include "GcsJsonData.hh" #include "GcsJsonParser.hh" #include "GcsUtil.hh" #include "Logger.hh" #include "MdsRest.hh" #include "MdsConst.hh" #include "OpensslCert.hh" #include "OpensslCertStore.hh" #include "Trace.hh" using namespace mdsd; using namespace web::http; using namespace web::http::client; static inline void ThrowIfEmpty( const std::string & apiName, const std::string & argName, const std::string & argVal ) { if (argVal.empty()) { throw std::invalid_argument(apiName + ": unexpected empty string for " + argName); } } MdsRestInterface::MdsRestInterface( const std::string & endPoint, const std::string & gcsEnvironment, const std::string & thumbPrint, const std::string & certFile, const std::string & keyFile, const std::string & sslDigest ) : m_endPoint(endPoint), m_gcsEnv(gcsEnvironment), m_thumbPrint(thumbPrint), m_certFile(certFile), m_keyFile(keyFile), m_sslDigest(sslDigest) { ThrowIfEmpty("MdsRestInterface", "gcsEnvironment", gcsEnvironment); ThrowIfEmpty("MdsRestInterface", "thumbPrint", thumbPrint); ThrowIfEmpty("MdsRestInterface", "certFile", certFile); ThrowIfEmpty("MdsRestInterface", "keyFile", keyFile); ThrowIfEmpty("MdsRestInterface", "sslDigest", sslDigest); } bool MdsRestInterface::Initialize() { Trace trace(Trace::MdsCmd, "MdsRestInterface::Initialize"); try { if (m_endPoint.empty()) { m_endPoint = GcsUtil::GetGcsEndpointFromEnvironment(m_gcsEnv); if (m_endPoint.empty()) { Logger::LogError("Error: unexpected empty value for GCS endpoint."); return false; } } m_initialized = InitCert(); } catch(const std::exception & ex) { Logger::LogError(std::string("Error: MdsRestInterface::Initialize() exception: ") + ex.what()); m_initialized = false; } return m_initialized; } bool MdsRestInterface::InitCert() { Trace trace(Trace::MdsCmd, "MdsRestInterface::InitCert"); bool retVal = true; try { OpensslCertStore certStore(m_certFile, m_keyFile, m_sslDigest); m_cert = 
certStore.LoadCertificate(m_thumbPrint); if (!m_cert->IsValid()) { Logger::LogError("Error: initializing certificate failed: certificate is invalid"); retVal = false; m_cert = nullptr; } } catch(const std::exception& ex) { Logger::LogError(std::string("Error: initializing certificate failed: ") + ex.what()); retVal = false; } return retVal; } void MdsRestInterface::ResetClient() { Trace trace(Trace::MdsCmd, "MdsRestInterface::ResetClient"); if (m_client) { TRACEINFO(trace, "Http client will be reset due to previous failure."); m_client.reset(); m_resetHttpClient = false; } http_client_config httpClientConfig; httpClientConfig.set_validate_certificates(true); httpClientConfig.set_timeout(utility::seconds(gcs::c_HttpTimeInSeconds)); httpClientConfig.set_nativehandle_options([this](web::http::client::native_handle handle)->void { SetNativeHandleOptions(handle); }); auto fullEndpoint = "https://" + m_endPoint; m_client = std::move(std::unique_ptr(new http_client(fullEndpoint.c_str(), httpClientConfig))); } pplx::task MdsRestInterface::QueryGcsAccountInfo( const std::string & mdsAccount, const std::string & mdsNamespace, const std::string & configVersion, const std::string & region, const std::string & agentIdentity, const std::string & tagId ) { Trace trace(Trace::MdsCmd, "MdsRestInterface::QueryGcsAccountInfo"); ThrowIfEmpty("GcsAccountInfo", "mdsAccount", mdsAccount); ThrowIfEmpty("GcsAccountInfo", "mdsNamespace", mdsNamespace); ThrowIfEmpty("GcsAccountInfo", "configVersion", configVersion); ThrowIfEmpty("GcsAccountInfo", "region", region); ThrowIfEmpty("GcsAccountInfo", "agentIdentity", agentIdentity); if (!m_initialized) { if (!Initialize()) { return pplx::task_from_result(false); } } try { auto apicall = BuildGcsApiCall(mdsAccount); auto args = BuildGcsAcountArgs(mdsNamespace, configVersion, region, agentIdentity, tagId); return ExecuteGcsGetCall(apicall, args); } catch(const std::exception & ex) { Logger::LogError(std::string("Error: QueryGcsAccountInfo() 
exception: ") + ex.what()); } return pplx::task_from_result(false); } std::string MdsRestInterface::BuildGcsApiCall( const std::string & mdsAccount ) { std::ostringstream apicall; apicall << gcs::c_GcsServiceName << m_gcsEnv << "/" << mdsAccount << "/" << gcs::c_GcsMonitoringStorageKeysApiName << "/"; return apicall.str(); } std::string MdsRestInterface::BuildGcsAcountArgs( const std::string & mdsNamespace, const std::string & configVersion, const std::string & region, const std::string & agentIdentity, const std::string & tagId ) { // Encode agentIdentity so that no special character like '/' is used in URI. std::vector vec(agentIdentity.begin(), agentIdentity.end()); auto encodedAgentId = utility::conversions::to_base64(vec); std::ostringstream args; args << "Namespace=" << mdsNamespace << "&ConfigMajorVersion=" << configVersion << "&Region=" << region << "&Identity=" << encodedAgentId; if (!tagId.empty()) { args << "&TagId=" + tagId; } return args.str(); } pplx::task MdsRestInterface::ExecuteGcsGetCall( const std::string & contractApi, const std::string & arguments ) { Trace trace(Trace::MdsCmd, "MdsRestInterface::ExecuteGcsGetCall"); TRACEINFO(trace, "contractApi='" << contractApi << "'; arguments='" << arguments << "'"); ThrowIfEmpty("ExecuteGcsGetCall", "contractApi", contractApi); ThrowIfEmpty("ExecuteGcsGetCall", "arguments", arguments); if (!m_client || m_resetHttpClient) { ResetClient(); } web::http::uri_builder request_uri; request_uri.append_path(contractApi, false); request_uri.append_query(arguments, true); http_request request; auto requestId = utility::uuid_to_string(utility::new_uuid()); request.headers().add(_XPLATSTR("x-ms-client-request-id"), requestId.c_str()); request.set_request_uri(request_uri.to_uri()); request.set_method(methods::GET); auto shThis = shared_from_this(); TRACEINFO(trace, "Start to send request {" << requestId << "} to GCS: " << request.absolute_uri().to_string()); return m_client->request(request) .then([shThis](pplx::task 
task) { return shThis->HandleServerResponse(task); }); TRACEINFO(trace, "ExecuteGcsGetCall returns false"); return pplx::task_from_result(false); } void MdsRestInterface::SetNativeHandleOptions( web::http::client::native_handle handle ) { Trace trace(Trace::MdsCmd, "MdsRestInterface::SetNativeHandleOptions"); auto streamobj = static_cast* >(handle); if (!streamobj) { throw std::runtime_error("SetNativeHandleOptions() failed: unexpected NULL tcp::socket handle"); } auto ssl = streamobj->native_handle(); if (!ssl) { throw std::runtime_error("SetNativeHandleOptions() failed: unexpected NULL ssl handle"); } const int isOK = 1; auto errorcode = ::SSL_use_certificate(ssl, m_cert->GetCert()); if (isOK != errorcode) { throw std::runtime_error("SSL_use_certificate() failed with error " + std::to_string(errorcode)); } errorcode = ::SSL_use_PrivateKey(ssl, m_cert->GetPrivateKey()); if (isOK != errorcode) { throw std::runtime_error("SSL_use_PrivateKey() failed with error " + std::to_string(errorcode)); } // Disable weak ssl ciphers const std::string cipherList = "HIGH:!DSS:!RC4:!aNULL@STRENGTH"; errorcode = ::SSL_set_cipher_list(ssl, cipherList.c_str()); if (isOK != errorcode) { throw std::runtime_error("SSL_set_cipher_list() failed with error " + std::to_string(errorcode)); } } std::string MdsRestInterface::GetRequestIdFromResponse( const std::string & responseString ) { Trace trace(Trace::MdsCmd, "MdsRestInterface::GetRequestIdFromResponse"); if (responseString.empty()) { TRACEINFO(trace, "ResponseString is empty. 
No request id is found."); return std::string(); } auto ptr = responseString.find(mdsd::gcs::c_RequestIdHeader); if (ptr == std::string::npos) { TRACEINFO(trace, "No request id is found from response string."); return std::string(); } ptr += mdsd::gcs::c_RequestIdHeader.size(); auto index = responseString.find_first_not_of(' ', ptr); std::string requestId; while(isalnum(responseString[index]) || responseString[index] == '-') { requestId.append(1, responseString[index]); index++; } TRACEINFO(trace, "RequestId from response: '" << requestId << "'"); return requestId; } static inline bool IsHttpStatusOK(web::http::status_code statusCode) { return (status_codes::OK == statusCode || status_codes::Created == statusCode); // 201. According to MSDN, 201 means success. } bool MdsRestInterface::HandleServerResponse( pplx::task responseTask ) { Trace trace(Trace::MdsCmd, "MdsRestInterface::HandleServerResponse"); bool retVal = false; try { auto response = responseTask.get(); auto statusCode = response.status_code(); auto responseString = response.to_string(); if (trace.IsActive()) { TRACEINFO(trace, "Response Code: " << statusCode << "; Response: " << responseString); } if (!IsHttpStatusOK(statusCode)) { auto requestId = GetRequestIdFromResponse(responseString); std::ostringstream ostr; ostr << "Error: request to Geneva failed with status code=" << statusCode << "; requestId=" << requestId << "; Response: " << responseString; Logger::LogError(ostr.str()); // Only reset http client when the GCS service is not available and need reconnect later. if (status_codes::ServiceUnavailable == statusCode) { m_resetHttpClient = true; } } else { m_responseJsonVal = response.extract_json().get(); // As long as the json object has the expected type, it is OK for http request. // Detailed data and validation need to be parsed from this json object. 
if (web::json::value::Object == m_responseJsonVal.type()) { retVal = true; } else { auto requestId = GetRequestIdFromResponse(responseString); auto jsonType = m_responseJsonVal.type(); auto jsonTypeStr = mdsd::GcsUtil::GetJsonTypeStr(jsonType); std::ostringstream ostr; ostr << "Error: received response, but an unexpected result was returned; " << "expected a JSON object, but received type " << jsonType << " " << jsonTypeStr << "; requestId=" << requestId; Logger::LogError(ostr.str()); } } } catch(const std::exception & ex) { Logger::LogError(std::string("Error: request failed with exception: ") + ex.what()); m_resetHttpClient = true; } TRACEINFO(trace, "HandleServerResponse returned " << (retVal? "true" : "false")); return retVal; } bool MdsRestInterface::GetGcsAccountData(GcsAccount & gcsAccount) const { Trace trace(Trace::MdsCmd, "MdsRestInterface::GetGcsAccountData()"); if (m_responseJsonVal.is_null()) { TRACEINFO(trace, "GCS account JSON object is null."); return false; } else { GcsJsonParser parser(m_responseJsonVal); return parser.Parse(gcsAccount); } } ================================================ FILE: Diagnostic/mdsd/mdsrest/MdsRest.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __MDSREST_HH__ #define __MDSREST_HH__ #include #include #include #include class OpensslCert; namespace mdsd { struct GcsAccount; /// This class defines APIs to call Geneva Configuration Service (GCS) REST APIs. /// NOTE: /// - This class is not thread-safe. class MdsRestInterface : public std::enable_shared_from_this { public: /// Construct a new MdsRestInterface. /// GCS endpoint. If empty, search its value /// using gcsEnvironment from pre-defined table. e.g. "ppe.warmpath.msftcloudes.com" /// Environment. e.g. "Test" /// Certificate thumb print /// full path to public certificate file. /// full path to private key file. /// certificate digest. e.g. 
"sha1" static std::shared_ptr Create( const std::string & endPoint, const std::string & gcsEnvironment, const std::string & thumbPrint, const std::string & certFile, const std::string & keyFile, const std::string & sslDigest ) { // Because the MdsRestInterface constructor is private, std::make_shared cannot be used. // std::make_shared requires public constructor. return std::shared_ptr( new MdsRestInterface(endPoint, gcsEnvironment, thumbPrint, certFile, keyFile, sslDigest)); } ~MdsRestInterface() = default; /// Initialize MdsRestInterface. /// Return true if success, false if any error. bool Initialize(); /// Query GCS account information. If successful, the result will be stored to /// json object m_responseJsonVal. /// /// MDS Account name /// MDS namespace /// configuration version. e.g. "Ver5v0" /// Region to get storage account credentials. e.g. "westus" /// An identification string, which is used for /// http query hashing. It can be built from mdsd IdentityColumns. /// GCS configuration tag id. GCS internally has a tag id, which /// is a combination of service configuration file md5 hash + account moniker versions. /// If the input tagId is equal to GCS's internal tag id, GCS will return null JSON objects. /// If the input tagId is not equal to GCS's internal tag id, GCS will return full /// account information. GCS account query will return its internal tagId in the returned JSON. /// /// Return true if success; return false if any error. pplx::task QueryGcsAccountInfo( const std::string & mdsAccount, const std::string & mdsNamespace, const std::string & configVersion, const std::string & region, const std::string & agentIdentity, const std::string & tagId); /// Get the account JSON object, which stores results from GCS account query. web::json::value GetGcsAccountJson() const { return m_responseJsonVal; } /// Parse GCS account JSON object and return the results in 'gcsAccount'. /// Return true if JSON object is successfully parsed. 
/// Return false if JSON object is null, or there is parsing error. bool GetGcsAccountData(GcsAccount & gcsAccount) const; private: /// Constructor. MdsRestInterface( const std::string & endPoint, const std::string & gcsEnvironment, const std::string & thumbPrint, const std::string & certFile, const std::string & keyFile, const std::string & sslDigest); /// Load certificates from files. /// Return true if success, false if any error. bool InitCert(); /// Reset http client if any. Then recreate it. void ResetClient(); /// Build the api string to call GCS service. std::string BuildGcsApiCall(const std::string & mdsAccount); /// Build the args to call GCS account service. std::string BuildGcsAcountArgs( const std::string & mdsNamespace, const std::string & configVersion, const std::string & region, const std::string & agentIdentity, const std::string & tagId); /// Execute GCS REST API call. /// Return true if success, false if any error. pplx::task ExecuteGcsGetCall(const std::string & contractApi, const std::string & arguments); /// Set certificates on native openssl handle void SetNativeHandleOptions(web::http::client::native_handle handle); /// Get http request id from http response. This is for logging purpose. std::string GetRequestIdFromResponse(const std::string & responseString); /// Handle GCS http response. Extract desired data from the response. /// Return true if success, false if any error. 
bool HandleServerResponse(pplx::task responseTask); private: bool m_initialized = false; std::string m_endPoint; std::string m_gcsEnv; std::string m_thumbPrint; std::string m_certFile; std::string m_keyFile; std::string m_sslDigest; std::shared_ptr m_cert; std::unique_ptr m_client; bool m_resetHttpClient = false; web::json::value m_responseJsonVal; }; } // namespace mdsd #endif // __MDSREST_HH__ ================================================ FILE: Diagnostic/mdsd/mdsrest/MdsRestException.hh ================================================ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #pragma once #ifndef __MDSRESTEXCEPTION__HH__ #define __MDSRESTEXCEPTION__HH__ #include #include namespace mdsd { class JsonParseException : public std::exception { private: std::string m_msg; public: JsonParseException(std::string message) noexcept : std::exception(), m_msg(std::move(message)) {} virtual const char * what() const noexcept { return m_msg.c_str(); } }; } #endif // __MDSRESTEXCEPTION__HH__ ================================================ FILE: Diagnostic/mdsd/parseglibc.py ================================================ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT license. # This script is to parse glibc-based binary files and print out # symbols whose GLIBC versions are higher than given version. # Report error if any such symbol is found. 
import argparse
import glob
import os
import sys
import time

# Running total of errors found across all parsed files.
# Also used as the process exit code (see __main__ guard later in the file).
totalErrors = 0

def LogError(msg):
    """Print an error message prefixed with the script name and bump the
    global error counter."""
    global totalErrors
    totalErrors = totalErrors + 1
    msg2 = "%s: Error: %s" % (sys.argv[0], msg)
    # print(...) with a single argument is valid in both Python 2 and 3;
    # the old "print msg2" statement form is a syntax error under Python 3.
    print(msg2)

def LogInfo(msg):
    """Print an informational message."""
    print(msg)

def ParseCmdLine():
    """Parse command-line arguments.

    '-v' (max allowed GLIBC version) is mandatory. One of '-d' (directory)
    or '-f' (file) must be supplied; if neither is, the problem is reported
    via LogError but parsing still returns, so the caller ends up with an
    empty file list and aborts there.
    """
    parser = argparse.ArgumentParser(sys.argv[0])
    parser.add_argument("-d", "--dir", type=str, required=False,
                        help="directory where all its files are parsed.")
    parser.add_argument("-f", "--filepath", type=str, required=False,
                        help="binary filepath.")
    parser.add_argument("-v", "--glibcver", type=str, required=True,
                        help="max GLIBC ver. ex: 2.14")
    args = parser.parse_args()
    if not args.dir and not args.filepath:
        LogError("either '-d' or '-f' is required.")
    return args

def GetFilesToParse(filepath, dirname):
    """Return the list of files to parse.

    A single file ('-f') takes precedence over a directory ('-d').
    Invalid paths are reported via LogError and yield an empty list.
    """
    files = []
    if filepath:
        if not os.path.isfile(filepath):
            LogError("%s is not a regular file." % (filepath))
        else:
            files.append(filepath)
    elif dirname:
        if not os.path.isdir(dirname):
            LogError("%s is not a directory." % (dirname))
        else:
            files = GetAllFiles(dirname)
    return files

# Get all files in a directory. This doesn't include subdirectories and symbolic links.
def GetAllFiles(dirname): if not dirname: return [] filepattern = dirname + "/*" filedirs = glob.glob(filepattern) files = [] for f in filedirs: if os.path.isfile(f) and (not os.path.islink(f)): files.append(f) return files # Get symbol file by running 'nm' def GetSymbols(filepath): outputfile = "testfile-" + str(time.time()) + ".txt" cmdline = "nm " + filepath + " 1>" + outputfile + " 2>&1" errCode = os.system(cmdline) if errCode != 0: LogError("cmd: '%s' failed with error %d" % (cmdline, errCode)) return "" return outputfile # Parse symbol file created by 'nm' def ParseSymbols(symbolfile, glibcver): with open(symbolfile, "r") as fh: lines = fh.readlines() for line in lines: if "@GLIBC_" in line: line = line.strip() ParseLine(line, glibcver) # libstdc++ should be statically linked starting from version 1.4 for unexpected_symbol in ["GLIBCXX", "CXXABI"]: if unexpected_symbol in line: LogError("Unexpected symbol {0}".format(unexpected_symbol)) # Parse one line to check for higher GLIBC version. # Report error if found def ParseLine(line, glibcver): global totalErrors items = line.split("GLIBC_") if len(items) != 2: LogError("unexpected symbol: %s" % (line)) else: if CompareVer(items[1], glibcver): totalErrors = totalErrors + 1 LogInfo(line) # Return True if ver1 > ver2. # Return False otherwise. def CompareVer(ver1, ver2): v1list = ver1.split(".") v2list = ver2.split(".") n = min(len(v1list), len(v2list)) for i in range(n): x = int(v1list[i]) y = int(v2list[i]) if x > y: return True elif x < y: return False if len(v1list) > len(v2list): return True return False def RunTest(filepath, dirname, glibcver): LogInfo("Parse GLIBC versions ...") files = GetFilesToParse(filepath, dirname) if len(files) == 0: LogError("no file to parse. Abort.") return for binfile in files: LogInfo("\nStart to parse file '%s' ..." 
% (binfile)) symbolfile = GetSymbols(binfile) if symbolfile: ParseSymbols(symbolfile, glibcver) os.remove(symbolfile) if totalErrors == 0: LogInfo("\nNo error is found. Test passed successfully.") else: LogInfo("\nTest failed. Total errors found: %d" % (totalErrors)) if __name__ == "__main__": args = ParseCmdLine() RunTest(args.filepath, args.dir, args.glibcver) sys.exit(totalErrors) ================================================ FILE: Diagnostic/mocks/Readme.txt ================================================ These three modules contain minimal mocks to allow the waagent code to load up on a non-Unix (e.g. windows) platform. They're just enough to allow the import statements to be executed; if you try to actually exercise the waagent functionality that relies upon them, you won't be happy. In order to make these visible in the correct way, you'll need to add the full path of this directory to the PYTHONPATH environment variable. Obviously, you shouldn't do this on Unix systems (including Linux and FreeBSD); the real modules are visible already, and you don't need these mocks. 
================================================ FILE: Diagnostic/mocks/__init__.py ================================================ ================================================ FILE: Diagnostic/mocks/crypt.py ================================================ def crypt(password, salt): pass ================================================ FILE: Diagnostic/mocks/fcntl.py ================================================ def ioctl(fileid, ioctl_num, arg): pass ================================================ FILE: Diagnostic/mocks/pwd.py ================================================ def getpwnam(name): pass ================================================ FILE: Diagnostic/run_unittests.sh ================================================ #!/bin/bash for test in watchertests test_commonActions test_lad_logging_config test_lad_config_all test_LadDiagnosticUtil \ test_builtin test_lad_ext_settings; do python -m tests.$test done ================================================ FILE: Diagnostic/services/mdsd-lde.service ================================================ [Unit] Description=Azure Linux Diagnostic Extension After=network-online.target walinuxagent.service Wants=network-online.target walinuxagent.service ConditionFileIsExecutable={WORKDIR}/diagnostic.py [Service] Type=simple WorkingDirectory={WORKDIR}/ ExecStart=/usr/bin/python2 {WORKDIR}/diagnostic.py -daemon Restart=on-failure TimeoutSec=60 RestartSec=30 StartLimitBurst=10 StartLimitInterval=3600 [Install] WantedBy=multi-user.target ================================================ FILE: Diagnostic/services/metrics-extension.service ================================================ [Unit] Description=Metrics Extension service for Linux Agent metrics sourcing After=network.target [Service] ExecStart=%ME_BIN% -TokenSource MSI -Input influxdb_udp -InfluxDbHost 127.0.0.1 -InfluxDbUdpPort %ME_INFLUX_PORT% -DataDirectory %ME_DATA_DIRECTORY% -LocalControlChannel -MonitoringAccount %ME_MONITORING_ACCOUNT% -LogLevel 
Error ExecReload=/bin/kill -HUP $MAINPID Restart=on-failure RestartForceExitStatus=SIGPIPE KillMode=control-group [Install] WantedBy=multi-user.target ================================================ FILE: Diagnostic/services/metrics-sourcer.service ================================================ [Unit] Documentation=https://github.com/influxdata/telegraf/blob/master/README.md Description=Custom Modified Telegraf service for Linux Agent metrics sourcing After=network.target [Service] ExecStart=%TELEGRAF_BIN% --config %TELEGRAF_AGENT_CONFIG% --config-directory %TELEGRAF_CONFIG_DIR% ExecReload=/bin/kill -HUP $MAINPID Restart=on-failure RestartForceExitStatus=SIGPIPE KillMode=control-group [Install] WantedBy=multi-user.target ================================================ FILE: Diagnostic/shim.sh ================================================ #!/usr/bin/env bash # This is the main driver file for LAD extension. This file first checks if Python 2 is available on the VM and exits early if not # Control arguments passed to the shim are redirected to diagnostic.py without validation. COMMAND="./diagnostic.py" PYTHON="" ARG="$@" function find_python() { local python_exec_command=$1 if command -v python2 >/dev/null 2>&1 ; then eval ${python_exec_command}="python2" fi } find_python PYTHON if [ -z "$PYTHON" ] # If python2 is not installed, we will fail the install with the following error, requiring cx to have python pre-installed then echo "No Python 2 interpreter found, which is an LAD extension dependency. Please install Python 2 before retrying LAD extension deployment." >&2 exit 52 # Missing Dependency else ${PYTHON} --version 2>&1 fi ${PYTHON} ${COMMAND} ${ARG} exit $? 
================================================ FILE: Diagnostic/tests/.gitignore ================================================ lad_2_3_metric_definitions_sample.json ================================================ FILE: Diagnostic/tests/__init__.py ================================================ # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: Diagnostic/tests/lad_2_3_compatible_portal_pub_settings.json ================================================ { "StorageAccount": "__DIAGNOSTIC_STORAGE_ACCOUNT__", "ladCfg": { "diagnosticMonitorConfiguration": { "eventVolume": "Medium", "metrics": { "metricAggregation": [ { "scheduledTransferPeriod": "PT1H" }, { "scheduledTransferPeriod": "PT1M" } ], "resourceId": "__VM_RESOURCE_ID__" }, "performanceCounters": { "performanceCounterConfiguration": [ { "annotation": [ { "displayName": "Disk read guest OS", "locale": "en-us" } ], "class": "disk", "condition": "IsAggregate=TRUE", "counter": "readbytespersecond", "counterSpecifier": "/builtin/disk/readbytespersecond", "type": "builtin", "unit": "BytesPerSecond" }, { "annotation": [ { "displayName": "Disk writes", "locale": "en-us" } ], "class": "disk", "condition": "IsAggregate=TRUE", "counter": "writespersecond", "counterSpecifier": "/builtin/disk/writespersecond", "type": "builtin", "unit": "CountPerSecond" }, { "annotation": [ { "displayName": "Disk transfer time", "locale": "en-us" } ], 
"class": "disk", "condition": "IsAggregate=TRUE", "counter": "averagetransfertime", "counterSpecifier": "/builtin/disk/averagetransfertime", "type": "builtin", "unit": "Seconds" }, { "annotation": [ { "displayName": "Disk transfers", "locale": "en-us" } ], "class": "disk", "condition": "IsAggregate=TRUE", "counter": "transferspersecond", "counterSpecifier": "/builtin/disk/transferspersecond", "type": "builtin", "unit": "CountPerSecond" }, { "annotation": [ { "displayName": "Disk write guest OS", "locale": "en-us" } ], "class": "disk", "condition": "IsAggregate=TRUE", "counter": "writebytespersecond", "counterSpecifier": "/builtin/disk/writebytespersecond", "type": "builtin", "unit": "BytesPerSecond" }, { "annotation": [ { "displayName": "Disk read time", "locale": "en-us" } ], "class": "disk", "condition": "IsAggregate=TRUE", "counter": "averagereadtime", "counterSpecifier": "/builtin/disk/averagereadtime", "type": "builtin", "unit": "Seconds" }, { "annotation": [ { "displayName": "Disk write time", "locale": "en-us" } ], "class": "disk", "condition": "IsAggregate=TRUE", "counter": "averagewritetime", "counterSpecifier": "/builtin/disk/averagewritetime", "type": "builtin", "unit": "Seconds" }, { "annotation": [ { "displayName": "Disk total bytes", "locale": "en-us" } ], "class": "disk", "condition": "IsAggregate=TRUE", "counter": "bytespersecond", "counterSpecifier": "/builtin/disk/bytespersecond", "type": "builtin", "unit": "BytesPerSecond" }, { "annotation": [ { "displayName": "Disk reads", "locale": "en-us" } ], "class": "disk", "condition": "IsAggregate=TRUE", "counter": "readspersecond", "counterSpecifier": "/builtin/disk/readspersecond", "type": "builtin", "unit": "CountPerSecond" }, { "annotation": [ { "displayName": "Disk queue length", "locale": "en-us" } ], "class": "disk", "condition": "IsAggregate=TRUE", "counter": "averagediskqueuelength", "counterSpecifier": "/builtin/disk/averagediskqueuelength", "type": "builtin", "unit": "Count" }, { "annotation": 
[ { "displayName": "Network in guest OS", "locale": "en-us" } ], "class": "network", "counter": "bytesreceived", "counterSpecifier": "/builtin/network/bytesreceived", "type": "builtin", "unit": "Bytes" }, { "annotation": [ { "displayName": "Network total bytes", "locale": "en-us" } ], "class": "network", "counter": "bytestotal", "counterSpecifier": "/builtin/network/bytestotal", "type": "builtin", "unit": "Bytes" }, { "annotation": [ { "displayName": "Network out guest OS", "locale": "en-us" } ], "class": "network", "counter": "bytestransmitted", "counterSpecifier": "/builtin/network/bytestransmitted", "type": "builtin", "unit": "Bytes" }, { "annotation": [ { "displayName": "Network collisions", "locale": "en-us" } ], "class": "network", "counter": "totalcollisions", "counterSpecifier": "/builtin/network/totalcollisions", "type": "builtin", "unit": "Count" }, { "annotation": [ { "displayName": "Packets received errors", "locale": "en-us" } ], "class": "network", "counter": "totalrxerrors", "counterSpecifier": "/builtin/network/totalrxerrors", "type": "builtin", "unit": "Count" }, { "annotation": [ { "displayName": "Packets sent", "locale": "en-us" } ], "class": "network", "counter": "packetstransmitted", "counterSpecifier": "/builtin/network/packetstransmitted", "type": "builtin", "unit": "Count" }, { "annotation": [ { "displayName": "Packets received", "locale": "en-us" } ], "class": "network", "counter": "packetsreceived", "counterSpecifier": "/builtin/network/packetsreceived", "type": "builtin", "unit": "Count" }, { "annotation": [ { "displayName": "Packets sent errors", "locale": "en-us" } ], "class": "network", "counter": "totaltxerrors", "counterSpecifier": "/builtin/network/totaltxerrors", "type": "builtin", "unit": "Count" }, { "annotation": [ { "displayName": "Filesystem transfers/sec", "locale": "en-us" } ], "class": "filesystem", "condition": "IsAggregate=TRUE", "counter": "transferspersecond", "counterSpecifier": 
"/builtin/filesystem/transferspersecond", "type": "builtin", "unit": "CountPerSecond" }, { "annotation": [ { "displayName": "Filesystem % free space", "locale": "en-us" } ], "class": "filesystem", "condition": "IsAggregate=TRUE", "counter": "percentfreespace", "counterSpecifier": "/builtin/filesystem/percentfreespace", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "Filesystem % used space", "locale": "en-us" } ], "class": "filesystem", "condition": "IsAggregate=TRUE", "counter": "percentusedspace", "counterSpecifier": "/builtin/filesystem/percentusedspace", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "Filesystem used space", "locale": "en-us" } ], "class": "filesystem", "condition": "IsAggregate=TRUE", "counter": "usedspace", "counterSpecifier": "/builtin/filesystem/usedspace", "type": "builtin", "unit": "Bytes" }, { "annotation": [ { "displayName": "Filesystem read bytes/sec", "locale": "en-us" } ], "class": "filesystem", "condition": "IsAggregate=TRUE", "counter": "bytesreadpersecond", "counterSpecifier": "/builtin/filesystem/bytesreadpersecond", "type": "builtin", "unit": "CountPerSecond" }, { "annotation": [ { "displayName": "Filesystem free space", "locale": "en-us" } ], "class": "filesystem", "condition": "IsAggregate=TRUE", "counter": "freespace", "counterSpecifier": "/builtin/filesystem/freespace", "type": "builtin", "unit": "Bytes" }, { "annotation": [ { "displayName": "Filesystem % free inodes", "locale": "en-us" } ], "class": "filesystem", "condition": "IsAggregate=TRUE", "counter": "percentfreeinodes", "counterSpecifier": "/builtin/filesystem/percentfreeinodes", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "Filesystem bytes/sec", "locale": "en-us" } ], "class": "filesystem", "condition": "IsAggregate=TRUE", "counter": "bytespersecond", "counterSpecifier": "/builtin/filesystem/bytespersecond", "type": "builtin", "unit": "BytesPerSecond" }, { "annotation": [ { 
"displayName": "Filesystem reads/sec", "locale": "en-us" } ], "class": "filesystem", "condition": "IsAggregate=TRUE", "counter": "readspersecond", "counterSpecifier": "/builtin/filesystem/readspersecond", "type": "builtin", "unit": "CountPerSecond" }, { "annotation": [ { "displayName": "Filesystem write bytes/sec", "locale": "en-us" } ], "class": "filesystem", "condition": "IsAggregate=TRUE", "counter": "byteswrittenpersecond", "counterSpecifier": "/builtin/filesystem/byteswrittenpersecond", "type": "builtin", "unit": "CountPerSecond" }, { "annotation": [ { "displayName": "Filesystem writes/sec", "locale": "en-us" } ], "class": "filesystem", "condition": "IsAggregate=TRUE", "counter": "writespersecond", "counterSpecifier": "/builtin/filesystem/writespersecond", "type": "builtin", "unit": "CountPerSecond" }, { "annotation": [ { "displayName": "Filesystem % used inodes", "locale": "en-us" } ], "class": "filesystem", "condition": "IsAggregate=TRUE", "counter": "percentusedinodes", "counterSpecifier": "/builtin/filesystem/percentusedinodes", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "CPU IO wait time", "locale": "en-us" } ], "class": "processor", "condition": "IsAggregate=TRUE", "counter": "percentiowaittime", "counterSpecifier": "/builtin/processor/percentiowaittime", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "CPU user time", "locale": "en-us" } ], "class": "processor", "condition": "IsAggregate=TRUE", "counter": "percentusertime", "counterSpecifier": "/builtin/processor/percentusertime", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "CPU nice time", "locale": "en-us" } ], "class": "processor", "condition": "IsAggregate=TRUE", "counter": "percentnicetime", "counterSpecifier": "/builtin/processor/percentnicetime", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "CPU percentage guest OS", "locale": "en-us" } ], "class": "processor", "condition": 
"IsAggregate=TRUE", "counter": "percentprocessortime", "counterSpecifier": "/builtin/processor/percentprocessortime", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "CPU interrupt time", "locale": "en-us" } ], "class": "processor", "condition": "IsAggregate=TRUE", "counter": "percentinterrupttime", "counterSpecifier": "/builtin/processor/percentinterrupttime", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "CPU idle time", "locale": "en-us" } ], "class": "processor", "condition": "IsAggregate=TRUE", "counter": "percentidletime", "counterSpecifier": "/builtin/processor/percentidletime", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "CPU privileged time", "locale": "en-us" } ], "class": "processor", "condition": "IsAggregate=TRUE", "counter": "percentprivilegedtime", "counterSpecifier": "/builtin/processor/percentprivilegedtime", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "Memory available", "locale": "en-us" } ], "class": "memory", "counter": "availablememory", "counterSpecifier": "/builtin/memory/availablememory", "type": "builtin", "unit": "Bytes" }, { "annotation": [ { "displayName": "Swap percent used", "locale": "en-us" } ], "class": "memory", "counter": "percentusedswap", "counterSpecifier": "/builtin/memory/percentusedswap", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "Memory used", "locale": "en-us" } ], "class": "memory", "counter": "usedmemory", "counterSpecifier": "/builtin/memory/usedmemory", "type": "builtin", "unit": "Bytes" }, { "annotation": [ { "displayName": "Page reads", "locale": "en-us" } ], "class": "memory", "counter": "pagesreadpersec", "counterSpecifier": "/builtin/memory/pagesreadpersec", "type": "builtin", "unit": "CountPerSecond" }, { "annotation": [ { "displayName": "Swap available", "locale": "en-us" } ], "class": "memory", "counter": "availableswap", "counterSpecifier": 
"/builtin/memory/availableswap", "type": "builtin", "unit": "Bytes" }, { "annotation": [ { "displayName": "Swap percent available", "locale": "en-us" } ], "class": "memory", "counter": "percentavailableswap", "counterSpecifier": "/builtin/memory/percentavailableswap", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "Mem. percent available", "locale": "en-us" } ], "class": "memory", "counter": "percentavailablememory", "counterSpecifier": "/builtin/memory/percentavailablememory", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "Pages", "locale": "en-us" } ], "class": "memory", "counter": "pagespersec", "counterSpecifier": "/builtin/memory/pagespersec", "type": "builtin", "unit": "CountPerSecond" }, { "annotation": [ { "displayName": "Swap used", "locale": "en-us" } ], "class": "memory", "counter": "usedswap", "counterSpecifier": "/builtin/memory/usedswap", "type": "builtin", "unit": "Bytes" }, { "annotation": [ { "displayName": "Memory percentage", "locale": "en-us" } ], "class": "memory", "counter": "percentusedmemory", "counterSpecifier": "/builtin/memory/percentusedmemory", "type": "builtin", "unit": "Percent" }, { "annotation": [ { "displayName": "Page writes", "locale": "en-us" } ], "class": "memory", "counter": "pageswrittenpersec", "counterSpecifier": "/builtin/memory/pageswrittenpersec", "type": "builtin", "unit": "CountPerSecond" } ] }, "syslogEvents": { "syslogEventConfiguration": { "LOG_AUTH": "LOG_DEBUG", "LOG_AUTHPRIV": "LOG_DEBUG", "LOG_CRON": "LOG_DEBUG", "LOG_DAEMON": "LOG_DEBUG", "LOG_FTP": "LOG_DEBUG", "LOG_KERN": "LOG_DEBUG", "LOG_LOCAL0": "LOG_DEBUG", "LOG_LOCAL1": "LOG_DEBUG", "LOG_LOCAL2": "LOG_DEBUG", "LOG_LOCAL3": "LOG_DEBUG", "LOG_LOCAL4": "LOG_DEBUG", "LOG_LOCAL5": "LOG_DEBUG", "LOG_LOCAL6": "LOG_DEBUG", "LOG_LOCAL7": "LOG_DEBUG", "LOG_LPR": "LOG_DEBUG", "LOG_MAIL": "LOG_DEBUG", "LOG_NEWS": "LOG_DEBUG", "LOG_SYSLOG": "LOG_DEBUG", "LOG_USER": "LOG_DEBUG", "LOG_UUCP": "LOG_DEBUG" } } }, 
"sampleRateInSeconds": 15 } } ================================================ FILE: Diagnostic/tests/test_LadDiagnosticUtil.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Copyright (c) Microsoft Corporation # All rights reserved. # MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit # persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the # Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
import unittest import Utils.LadDiagnosticUtil as LadUtil class TestGetDiagnosticsMonitorConfigurationElement(unittest.TestCase): def setUp(self): self.empty_config = {} self.bogus_config = {"foo": "bar"} self.missing_from_config = {"diagnosticMonitorConfiguration": {"foo": "bar"}} self.valid_config = \ { "diagnosticMonitorConfiguration": { "foo": "bar", "eventVolume": "Large", "sinksConfig": { "Sink": [ { "name": "sink1", "type": "EventHub", "sasURL": "https://sbnamespace.servicebus.windows.net/raw?sr=https%3a%2f%2fsb" "namespace.servicebus.windows.net%2fraw%2f&sig=SIGNATURE%3d" "&se=1804371161&skn=writer" } ] }, "metrics": { "resourceId": "/subscriptions/1111-2222-3333-4444/resourcegroups/RG1/compute/foo", "metricAggregation": [ {"scheduledTransferPeriod": "PT5M"}, {"scheduledTransferPeriod": "PT1H"}, ] }, "performanceCounters": { "sinks": "sink1", "performanceCounterConfiguration": [ { "type": "builtin", "class": "Processor", "counter": "PercentIdleTime", "counterSpecifier": "/builtin/Processor/PercentIdleTime", "condition": "IsAggregate=TRUE", "sampleRate": "PT15S", "unit": "Percent", "annotation": [ { "displayName": "Aggregate CPU %idle time", "locale": "en-us" } ] } ] }, "syslogEvents": { "sinks": "sink2", "syslogEventConfiguration": { "LOG_LOCAL1": "LOG_INFO", "LOG_MAIL": "LOG_FATAL" } } }, "sampleRateInSeconds": 60 } def test_empty_config(self): self.assertIsNone(LadUtil.getDiagnosticsMonitorConfigurationElement(self.empty_config, "dummy")) def test_bogus_config(self): self.assertIsNone(LadUtil.getDiagnosticsMonitorConfigurationElement(self.bogus_config, "dummy")) def test_entry_not_present(self): self.assertIsNone(LadUtil.getDiagnosticsMonitorConfigurationElement(self.missing_from_config, "dummy")) def test_entry_is_present(self): self.assertEqual(LadUtil.getDiagnosticsMonitorConfigurationElement(self.valid_config, "foo"), "bar") def test_getDefaultSampleRateFromLadCfg(self): self.assertEqual(LadUtil.getDefaultSampleRateFromLadCfg(self.valid_config), 60) 
def test_getEventVolumeFromLadCfg(self): self.assertEqual(LadUtil.getEventVolumeFromLadCfg(self.valid_config), "Large") def test_getAggregationPeriodsFromLadCfg(self): periods = LadUtil.getAggregationPeriodsFromLadCfg(self.valid_config) self.assertEqual(len(periods), 2) self.assertIn('PT5M', periods) self.assertIn('PT1H', periods) def test_getPerformanceCounterCfgFromLadCfg(self): definitions = LadUtil.getPerformanceCounterCfgFromLadCfg(self.valid_config) self.assertEqual(1, len(definitions)) metric = definitions[0] self.assertIn('counterSpecifier', metric) self.assertEqual('/builtin/Processor/PercentIdleTime', metric['counterSpecifier']) def test_getResourceIdFromLadCfg(self): self.assertIsNone(LadUtil.getResourceIdFromLadCfg(self.missing_from_config)) res_id = LadUtil.getResourceIdFromLadCfg(self.valid_config) self.assertIsNotNone(res_id) self.assertIn("1111-2222-3333-4444", res_id) def test_getFeatureWideSinksFromLadCfg(self): self.assertEqual(LadUtil.getFeatureWideSinksFromLadCfg(self.valid_config, 'syslogEvents'), ['sink2']) self.assertEqual(LadUtil.getFeatureWideSinksFromLadCfg(self.valid_config, 'performanceCounters'), ['sink1']) class TestSinkConfiguration(unittest.TestCase): def setUp(self): self.config = \ { "sink": [ { "name": "sink1", "type": "EventHub", "sasURL": "https://sbnamespace.servicebus.windows.net/raw?sr=https%3a%2f%2fsb" "namespace.servicebus.windows.net%2fraw%2f&sig=SIGNATURE%3d" "&se=1804371161&skn=writer" }, { "name": "sink2", "type": "JsonBlob" }, { "name": "sink3", "type": "EventHub", "sasURL": "https://sbnamespace2.servicebus.windows.net/raw?sr=https%3a%2f%2fsb" "namespace.servicebus.windows.net%2fraw%2f&sig=SIGNATURE%3d" "&se=99999999999&skn=writer" } ] } self.sink_config = LadUtil.SinkConfiguration() self.sink_config.insert_from_config(self.config) def test_insert_from_config(self): json_config = {} sinks = LadUtil.SinkConfiguration() msgs = sinks.insert_from_config(json_config) self.assertEqual(msgs, '') json_config = {'sink': 
[{'Name': 'bad case'}]} sinks = LadUtil.SinkConfiguration() msgs = sinks.insert_from_config(json_config) self.assertEqual(msgs, "Ignoring invalid sink definition {'Name': 'bad case'}") def test_get_all_sink_names(self): sinks = self.sink_config.get_all_sink_names() self.assertEqual(len(sinks), len(self.config["sink"])) self.assertIn("sink1", sinks) for sink in self.config["sink"]: self.assertIn(sink["name"], sinks) def helper_get_sink_by_name(self, name, type, sasURL=False): sink = self.sink_config.get_sink_by_name(name) self.assertIsNotNone(sink) self.assertEqual(sink['name'], name) self.assertEqual(sink['type'], type) if sasURL: self.assertIn('sasURL', sink) def test_get_sink_by_name(self): self.assertIsNone(self.sink_config.get_sink_by_name("BogusSink")) self.helper_get_sink_by_name('sink1', 'EventHub', True) self.helper_get_sink_by_name('sink2', 'JsonBlob') self.helper_get_sink_by_name('sink3', 'EventHub', True) def helper_get_sinks_by_type(self, type, names): sink_list = self.sink_config.get_sinks_by_type(type) self.assertEqual(len(sink_list), len(names)) # Ugly nested loops... 
Please suggest any better Pythonic code names_from_sink_list = [sink['name'] for sink in sink_list] for name in names: self.assertIn(name, names_from_sink_list) def test_get_sinks_by_type(self): sink_list = self.sink_config.get_sinks_by_type("Bogus") self.assertEqual(len(sink_list), 0) self.helper_get_sinks_by_type('EventHub', ['sink1', 'sink3']) self.helper_get_sinks_by_type('JsonBlob', ['sink2']) if __name__ == '__main__': unittest.main() ================================================ FILE: Diagnostic/tests/test_builtin.py ================================================ import unittest import Providers.Builtin as BProvider import Utils.ProviderUtil as ProvUtil from Utils.mdsd_xml_templates import entire_xml_cfg_tmpl import xml.etree.ElementTree as ET import json import re class TestBuiltinMetric(unittest.TestCase): def setUp(self): self.basic_valid = { "type": "builtin", "class": "Processor", "counter": "PercentIdleTime", "counterSpecifier": "/builtin/Processor/PercentIdleTime", "condition": 'IsAggregate=TRUE', "sampleRate": "PT30S", "unit": "Percent", "annotation": [ { "displayName": "Aggregate CPU %idle time", "locale": "en-us" } ] } self.mapped = { "type": "builtin", "class": "filesystem", "counter": "Freespace", "counterSpecifier": "/builtin/Filesystem/Freespace(/)", "condition": 'Name="/"', "unit": "Bytes", "annotation": [ { "displayName": "Free space on /", "locale": "en-us" } ] } def test_IsType(self): try: item = BProvider.BuiltinMetric(self.basic_valid) self.assertTrue(item.is_type('builtin')) except Exception as ex: self.fail("BuiltinMetric Constructor raised exception: {0}".format(ex)) def test_Class(self): dupe = self.basic_valid.copy() del dupe['class'] self.assertRaises(ProvUtil.InvalidCounterSpecification, BProvider.BuiltinMetric, dupe) try: metric = BProvider.BuiltinMetric(self.basic_valid) self.assertEqual(metric.class_name(), 'processor') except Exception as ex: self.fail("BuiltinMetric Constructor raised exception: {0}".format(ex)) def 
test_Counter(self): dupe = self.basic_valid.copy() del dupe['counter'] self.assertRaises(ProvUtil.InvalidCounterSpecification, BProvider.BuiltinMetric, dupe) try: metric = BProvider.BuiltinMetric(self.basic_valid) self.assertEqual(metric.counter_name(), 'PercentIdleTime') except Exception as ex: self.fail("BuiltinMetric Constructor raised exception: {0}".format(ex)) try: metric = BProvider.BuiltinMetric(self.mapped) self.assertEqual(metric.counter_name(), 'FreeMegabytes') except Exception as ex: self.fail("BuiltinMetric Constructor raised exception: {0}".format(ex)) def test_condition(self): dupe = self.basic_valid.copy() del dupe['condition'] try: metric = BProvider.BuiltinMetric(dupe) self.assertIsNone(metric.condition()) except Exception as ex: self.fail("BuiltinMetric Constructor (dupe) raised exception: {0}".format(ex)) try: metric = BProvider.BuiltinMetric(self.mapped) self.assertEqual(metric.condition(), 'Name="/"') except Exception as ex: self.fail("BuiltinMetric Constructor (self.mapped) raised exception: {0}".format(ex)) try: metric = BProvider.BuiltinMetric(self.basic_valid) self.assertEqual(metric.condition(), 'IsAggregate=TRUE') except Exception as ex: self.fail("BuiltinMetric Constructor (self.basic_valid) raised exception: {0}".format(ex)) def test_label(self): dupe = self.basic_valid.copy() del dupe['counterSpecifier'] self.assertRaises(ProvUtil.InvalidCounterSpecification, BProvider.BuiltinMetric, dupe) try: metric = BProvider.BuiltinMetric(self.basic_valid) self.assertEqual(metric.label(), '/builtin/Processor/PercentIdleTime') except Exception as ex: self.fail("BuiltinMetric Constructor raised exception: {0}".format(ex)) def test_sample_rate(self): try: metric = BProvider.BuiltinMetric(self.basic_valid) self.assertEqual(metric.sample_rate(), 30) except Exception as ex: self.fail("BuiltinMetric Constructor raised exception: {0}".format(ex)) dupe = self.basic_valid.copy() del dupe['sampleRate'] try: metric = BProvider.BuiltinMetric(dupe) 
self.assertEqual(metric.sample_rate(), 15) except Exception as ex: self.fail("BuiltinMetric Constructor raised exception: {0}".format(ex)) class TestMakeXML(unittest.TestCase): def setUp(self): self.base_xml = entire_xml_cfg_tmpl def test_two_and_two(self): specs = [ { "type": "builtin", "class": "Processor", "counter": "PercentIdleTime", "counterSpecifier": "/builtin/Processor/PercentIdleTime", "condition": "IsAggregate=TRUE", "sampleRate": "PT30S", }, { "type": "builtin", "class": "filesystem", "counter": "Freespace", "counterSpecifier": "/builtin/Filesystem/Freespace(/)", "condition": "Name='/'", }, { "type": "builtin", "class": "Processor", "counter": "PercentProcessorTime", "counterSpecifier": "/builtin/Processor/PercentProcessorTime", "condition": "IsAggregate=TRUE", "sampleRate": "PT30S", }, { "type": "builtin", "class": "filesystem", "counter": "Freespace", "counterSpecifier": "/builtin/Filesystem/Freespace(/mnt)", "condition": "Name=\"/mnt\"", }, ] sink_names = set() for spec in specs: try: sink = BProvider.AddMetric(spec) self.assertIsNotNone(sink) sink_names.add(sink) except Exception as ex: self.fail("AddMetric({0}) raised exception: {1}".format(spec, ex)) self.assertEqual(len(sink_names), 3) doc = ET.ElementTree(ET.fromstring(self.base_xml)) BProvider.UpdateXML(doc) # xml_string = ET.tostring(doc.getroot()) # print xml_string class Lad2_3CompatiblePortalPublicSettingsGenerator(unittest.TestCase): @unittest.skip("Lad2_3Compat test needs redesign to be useful outside of internal development environment") def test_lad_2_3_compatible_portal_public_settings(self): """ This is rather a utility function that attempts to generate a standard LAD 3.0 protected settings JSON string for the Azure Portal charts experience. Unit, displayName, and condition are inferred/auto-filled from a sample Azure Insights metric definitions JSON pulled from ACIS. 
""" pub_settings = { "StorageAccount": "__DIAGNOSTIC_STORAGE_ACCOUNT__", "ladCfg": { "sampleRateInSeconds": 15, "diagnosticMonitorConfiguration": { "eventVolume": "Medium", "metrics": { "metricAggregation": [ { "scheduledTransferPeriod": "PT1H" }, { "scheduledTransferPeriod": "PT1M" } ], "resourceId": "__VM_RESOURCE_ID__" }, "performanceCounters": { "performanceCounterConfiguration": [] }, "syslogEvents": { "syslogEventConfiguration": { 'LOG_AUTH': 'LOG_DEBUG', 'LOG_AUTHPRIV': 'LOG_DEBUG', 'LOG_CRON': 'LOG_DEBUG', 'LOG_DAEMON': 'LOG_DEBUG', 'LOG_FTP': 'LOG_DEBUG', 'LOG_KERN': 'LOG_DEBUG', 'LOG_LOCAL0': 'LOG_DEBUG', 'LOG_LOCAL1': 'LOG_DEBUG', 'LOG_LOCAL2': 'LOG_DEBUG', 'LOG_LOCAL3': 'LOG_DEBUG', 'LOG_LOCAL4': 'LOG_DEBUG', 'LOG_LOCAL5': 'LOG_DEBUG', 'LOG_LOCAL6': 'LOG_DEBUG', 'LOG_LOCAL7': 'LOG_DEBUG', 'LOG_LPR': 'LOG_DEBUG', 'LOG_MAIL': 'LOG_DEBUG', 'LOG_NEWS': 'LOG_DEBUG', 'LOG_SYSLOG': 'LOG_DEBUG', 'LOG_USER': 'LOG_DEBUG', 'LOG_UUCP': 'LOG_DEBUG' } } } } } each_perf_counter_cfg_template = { "unit": "__TO_BE_FILLED__", "type": "builtin", "class": "__TO_BE_REPLACED_BY_CODE", "counter": "__TO_BE_REPLACED_BY_CODE__", "counterSpecifier": "__TO_BE_REPLACED_BY_CODE__", "annotation": "__TO_BE_FILLED__", # Needs to be assigned a new instance to avoid shallow copy # [ # { # "locale": "en-us", # "displayName": "__TO_BE_FILLED__" # } # ], "condition": "__TO_BE_FILLED__" } perf_counter_cfg_list = pub_settings['ladCfg']['diagnosticMonitorConfiguration']['performanceCounters']['performanceCounterConfiguration'] units_and_names = self.extract_perf_counter_units_and_names_from_metrics_def_sample() for class_name in BProvider._builtIns: for lad_counter_name, scx_counter_name in BProvider._builtIns[class_name].iteritems(): perf_counter_cfg = dict(each_perf_counter_cfg_template) perf_counter_cfg['class'] = class_name perf_counter_cfg['counter'] = lad_counter_name counter_specifier = '/builtin/{0}/{1}'.format(class_name, lad_counter_name) perf_counter_cfg['counterSpecifier'] = 
counter_specifier perf_counter_cfg['condition'] = BProvider.default_condition(class_name) if not perf_counter_cfg['condition']: del perf_counter_cfg['condition'] counter_specifier_with_scx_name = '/builtin/{0}/{1}'.format(class_name.title(), scx_counter_name) if counter_specifier_with_scx_name in units_and_names: perf_counter_cfg['unit'] = units_and_names[counter_specifier_with_scx_name]['unit'] perf_counter_cfg['annotation'] = [{ 'displayName': units_and_names[counter_specifier_with_scx_name]['displayName'], 'locale': 'en-us' }] else: # Use some ad hoc logic to auto-fill missing values (all from FileSystem class) perf_counter_cfg['unit'] = self.inferred_unit_name_from_counter_name(scx_counter_name) perf_counter_cfg['annotation'] = [{ 'displayName': self.inferred_display_name_from_class_counter_names(class_name, scx_counter_name), 'locale': 'en-us' }] perf_counter_cfg_list.append(perf_counter_cfg) actual = json.dumps(pub_settings, sort_keys=True, indent=2) print actual # Uncomment the following 2 lines when generating expected JSON file (of course after validating the actual) #with open('lad_2_3_compatible_portal_pub_settings.json', 'w') as f: # f.write(actual) with open('lad_2_3_compatible_portal_pub_settings.json') as f: expected = f.read() self.assertEqual(json.dumps(json.loads(expected), sort_keys=True), json.dumps(json.loads(actual), sort_keys=True)) to_be_filled = re.findall(r'"__.*?__"', actual) self.assertEqual(2, len(to_be_filled)) self.assertIn('"__DIAGNOSTIC_STORAGE_ACCOUNT__"', to_be_filled) self.assertIn('"__VM_RESOURCE_ID__"', to_be_filled) def inferred_unit_name_from_counter_name(self, scx_counter_name): if 'Percent' in scx_counter_name: return 'Percent' if re.match(r'Bytes.*PerSecond', scx_counter_name): return 'BytesPerSecond' # According to the ACIS-pulled metric definitions sample... if 'PerSecond' in scx_counter_name: return 'CountPerSecond' # Again according to the ACIS-pulled metric defs sample... 
if scx_counter_name in BProvider._scaling['memory'] or scx_counter_name in BProvider._scaling['filesystem']: return 'Bytes' # Scaled MiB to Bytes counters, so use Bytes as unit raise Exception("Can't infer unit name from scx counter name ({0})".format(scx_counter_name)) def inferred_display_name_from_class_counter_names(self, class_name, scx_counter_name): desc = scx_counter_name desc = desc.replace('PerSecond', '/sec') desc = ' '.join([word.lower() for word in re.findall('[A-Z]+[^A-Z]*', desc)]) desc = desc.replace('percent', '%').replace('megabytes', 'space') return '{0} {1}'.format(class_name.title(), desc) def extract_perf_counter_units_and_names_from_metrics_def_sample(self): """ Another utility function that extracts perf counter units and display names from an Azure metrics definition sample file (not included in the repo). Again this is to be used only manually under the desired environment when needed. :return: Dictionary of counter specifier to unit/displayName map. """ results = {} metric_definitions = {} with open('lad_2_3_metric_definitions_sample.json') as f: metric_definitions = json.load(f) for dict_item in metric_definitions['value']: # This is a list of dictionaries for all metrics # E.g., '\\Memory\\AvailableMemory' to '/builtin/Memory/AvailableMemory' # Also, Azure Insights uses 'PhysicalDisk' and 'NetworkInterface' instead of 'Disk' and 'Network', # so replace them as well. 
counter_specifier = '/builtin{0}'.format(dict_item['name']['value'].replace('\\', '/') .replace('PhysicalDisk', 'Disk') .replace('NetworkInterface', 'Network')) display_name = dict_item['name']['localizedValue'] # E.g., 'Memory available' unit = dict_item['unit'] # E.g., 'Bytes' results[counter_specifier] = { 'unit': unit, 'displayName': display_name } return results if __name__ == '__main__': unittest.main() ================================================ FILE: Diagnostic/tests/test_commonActions.py ================================================ import unittest import os import errno import platform import time import string import random import DistroSpecific from Utils.WAAgentUtil import waagent class TestCommonActions(unittest.TestCase): _pid = os.getpid() _sequence = 0 _messages = [] _distro = None def make_temp_filename(self): self._sequence += 1 return '/tmp/TestCommonActions_{0}_{1}_{2}'.format(self._pid, time.time(), self._sequence) def log(self, message): self._messages.append(message) @staticmethod def random_string(size, charset=string.ascii_uppercase + string.digits): return ''.join(random.SystemRandom().choice(charset) for _ in range(size)) def setUp(self): dist = platform.dist() self._messages = [] self._distro = DistroSpecific.get_distro_actions(dist[0], dist[1], self.log) def tearDown(self): pass def test_log_run_get_output_silent_success(self): (error, results) = self._distro.log_run_get_output('/bin/true') self.assertEqual(error, 0) self.assertEqual(results, '') def test_log_run_get_output_success(self): expected = TestCommonActions.random_string(50) + '\n' filename = self.make_temp_filename() with open(filename, 'w') as f: f.write(expected) (error, results) = self._distro.log_run_get_output('cat {0}'.format(filename)) os.remove(filename) self.assertEqual(results, expected) self.assertEqual(error, 0) def test_log_run_get_output_failure(self): bad_file= '/bin/ThIsDoEsNoTeXiSt' (error, results) = self._distro.log_run_get_output(bad_file) 
self.assertEqual(127, error) self.assertIn(bad_file, results) # Should be an error message talking about the non-existent file def test_log_run_ignore_output(self): filename = self.make_temp_filename() try: os.remove(filename) except OSError as e: if e.errno != errno.ENOENT: self.fail("Pre-test os.delete({0}) returned {1}".format(filename, errno.errorcode[e.errno])) error = self._distro.log_run_ignore_output("touch {0}".format(filename)) self.assertEqual(error, 0) try: os.remove(filename) except IOError as e: if e.errno == errno.ENOENT: self.fail("Test command did not properly execute") else: self.fail("Post-test os.delete({0}) returned {1}".format(filename, errno.errorcode[e.errno])) def test_log_run_with_timeout_force_timeout(self): (status, output) = self._distro.log_run_with_timeout("sleep 10; echo sleep done", timeout=5) self.assertEqual(output, 'Process timeout\n') self.assertEqual(status, 1) def test_log_run_with_timeout_without_timeout(self): (status, output) = self._distro.log_run_with_timeout("echo success; exit 2", timeout=5) self.assertEqual(output, 'success\n') self.assertEqual(status, 2) def test_log_run_multiple_cmds(self): expected = 'foo\nbar\n' cmds = ('echo foo', 'echo bar') error, output = self._distro.log_run_multiple_cmds(cmds, False) self.assertEqual(error, 0) self.assertEqual(output, expected) def test_log_run_multiple_cmds_no_timeout(self): expected = 'foo\nbar\n' cmds = ('echo foo', 'echo bar') error, output = self._distro.log_run_multiple_cmds(cmds, True) self.assertEqual(error, 0) self.assertEqual(output, expected) def test_log_run_multiple_cmds_partial_timeout(self): expected = 'Process timeout\nbar\n' cmds = ('sleep 30; echo foo', 'echo bar') error, output = self._distro.log_run_multiple_cmds(cmds, True, 5) self.assertEqual(error, 1) self.assertEqual(output, expected) if __name__ == '__main__': waagent.LoggerInit('waagent.verbose.log', None, True) unittest.main() ================================================ FILE: 
Diagnostic/tests/test_lad_config_all.py ================================================ # Make LadConfigAll class unittest-able here. # To achieve that, the following were done: # - Mock VM's cert/prv key files (w/ thumbprint) that's used for decrypting the extensions's protected settings # and for encrypting storage key/SAS token in mdsd XML file # - Mock a complete LAD extension's handler setting (that includes protected settings and public settings). # - Mock RunGetOutput for external command executions. # - Mock any other things that are necessary! # It'd be easiest to create a test VM w/ LAD enabled and copy out necessary files to here to be used for this test. # The test VM was destroyed immediately. A test storage account was used and deleted immediately. # TODO Try to generate priv key/cert/storage shared key dynamically here. import binascii import json import os import unittest from xml.etree import ElementTree as ET # This test suite uses xmlunittest package. Install it by running 'pip install xmlunittest'. # Documentation at http://python-xmlunittest.readthedocs.io/en/latest/ from xmlunittest import XmlTestMixin from Utils.lad_ext_settings import * # The following line will work on an Azure Linux VM (where waagent is installed), but fail on a non-Azure Linux VM # (because of no waagent). It's because lad_config_all.py will import misc_helpers.py, which will try to import # waagent from WAAgentUtil.py. # To work around this on a non-Azure Linux VM, define PYTHONPATH env var # with "azure-linux-extensions/Common/WALinuxAgent-2.0.16" included in it. # E.g., run 'export PYTHONPATH=/azure-linux-extensions/Common/WALinuxAgent-2.0.16' before running this test. # # Also, if you're trying to execute this test on a Windows system rather than under Linux, the waagent code relies on # three Linux-only modules you'll need to mock out: crypt(crypt()), pwd(getpwnam()), and fcntl(ioctl()). 
from lad_config_all import * # Mocked waagent/LAD dir/files test_waagent_dir = os.path.join(os.path.dirname(__file__), 'var_lib_waagent') test_lad_dir = os.path.join(test_waagent_dir, 'lad_dir') test_lad_settings_logging_json_file = os.path.join(test_lad_dir, 'config', 'lad_settings_logging.json') test_lad_settings_metric_json_file = os.path.join(test_lad_dir, 'config', 'lad_settings_metric.json') # Mocked functions # We're not really interested in testing the ability to decrypt the private settings; that's tested elsewhere. # Instead, we assume the test handlerSettings object contains the decrypted Private settings already, since we just # need to test our ability to read and manipulate those settings. def decrypt_protected_settings(handlerSettings): pass def print_content_with_header(header_text, content): header = '>>>>> ' + header_text + ' >>>>>' print header print content print '<' * len(header) print def mock_fetch_uuid(): return "DEADBEEF-0000-1111-2222-77DEADBEEF77" def mock_encrypt_secret(cert, secret): # Encode secret w/ binascii.hex() to avoid invalid chars in XML. # The actual/real return value of the non-mocked encrypt_secret() is in that form. # We still keep the "ENCRYPTED(...)" part here to show that clearly in our test outputs. 
secret = binascii.b2a_hex(secret).upper() return "ENCRYPTED({0},{1})".format(cert, secret) def mock_log_info(msg): print 'LOG:', msg def mock_log_error(msg): print 'ERROR:', msg def load_test_config(filename): """ Load a test configuration into a LadConfigAll object :param filename: Name of config file :rtype: LadConfigAll :return: Loaded configuration """ with open(filename) as f: handler_settings = json.loads(f.read())['runtimeSettings'][0]['handlerSettings'] decrypt_protected_settings(handler_settings) lad_settings = LadExtSettings(handler_settings) return LadConfigAll(lad_settings, test_lad_dir, '', 'test_lad_deployment_id', mock_fetch_uuid, mock_encrypt_secret, mock_log_info, mock_log_error) class LadConfigAllTest(unittest.TestCase, XmlTestMixin): def test_lad_config_all_logging_only(self): """ Perform basic LadConfigAll object tests with logging-only configs, like generating various configs and validating them. """ lad_cfg = load_test_config(test_lad_settings_logging_json_file) result, msg = lad_cfg.generate_all_configs() self.assertTrue(result, 'Config generation failed: ' + msg) with open(os.path.join(test_lad_dir, 'xmlCfg.xml')) as f: mdsd_xml_cfg = f.read() print_content_with_header('Generated mdsd XML cfg for logging-only LAD settings', mdsd_xml_cfg) self.assertTrue(mdsd_xml_cfg, 'Empty mdsd XML config is invalid!') rsyslog_cfg = lad_cfg.get_rsyslog_config() print_content_with_header('Generated rsyslog cfg', rsyslog_cfg) self.assertTrue(rsyslog_cfg, 'Empty rsyslog cfg is invalid') syslog_ng_cfg = lad_cfg.get_syslog_ng_config() print_content_with_header('Generated syslog-ng cfg', syslog_ng_cfg) self.assertTrue(syslog_ng_cfg, 'Empty syslog-ng cfg is invalid') fluentd_out_mdsd_cfg = lad_cfg.get_fluentd_out_mdsd_config() print_content_with_header('Generated fluentd out_mdsd cfg', fluentd_out_mdsd_cfg) self.assertTrue(fluentd_out_mdsd_cfg, 'Empty fluentd out_mdsd cfg is invalid') fluentd_syslog_src_cfg = lad_cfg.get_fluentd_syslog_src_config() 
print_content_with_header('Generated fluentd syslog src cfg', fluentd_syslog_src_cfg) self.assertTrue(fluentd_syslog_src_cfg, 'Empty fluentd syslog src cfg is invalid') fluentd_tail_src_cfg = lad_cfg.get_fluentd_tail_src_config() print_content_with_header('Generated fluentd tail src cfg', fluentd_tail_src_cfg) self.assertTrue(fluentd_tail_src_cfg, 'Empty fluentd tail src cfg is invalid') def test_lad_config_all_metric_only(self): """ Perform basic LadConfigAll object tests with metric-only configs, like generating various configs and validating them. """ lad_cfg = load_test_config(test_lad_settings_metric_json_file) result, msg = lad_cfg.generate_all_configs() self.assertTrue(result, 'Config generation failed: ' + msg) with open(os.path.join(test_lad_dir, 'xmlCfg.xml')) as f: mdsd_xml_cfg = f.read() print_content_with_header('Generated mdsd XML cfg for metric-only LAD settings', mdsd_xml_cfg) self.assertTrue(mdsd_xml_cfg, 'Empty mdsd XML config is invalid!') # Verify using xmlunittests root = self.assertXmlDocument(mdsd_xml_cfg) expected_xml_str = """ test_lad_deployment_id /builtin/filesystem/freespace(/mnt) /builtin/filesystem/usedspace /builtin/processor/PercentProcessorTime ENCRYPTED(B175B535DFE9F93659E5AFD893BF99BBF9DF28A5.crt,68747470733A2F2F66616B65267361732575726C3B31) ENCRYPTED(B175B535DFE9F93659E5AFD893BF99BBF9DF28A5.crt,68747470733A2F2F66616B65267361732575726C3B32) """ # The following is at least insensitive to whitespaces... Also it's way more complicated # to create XPaths for this, so just use the following API. 
self.assertXmlEquivalentOutputs(mdsd_xml_cfg, expected_xml_str) def test_update_metric_collection_settings(self): test_config = \ { "diagnosticMonitorConfiguration": { "foo": "bar", "eventVolume": "Large", "sinksConfig": { "sink": [ { "name": "sink1", "type": "EventHub", "sasURL": "https://sbnamespace.servicebus.windows.net/raw?sr=https%3a%2f%2fsb" "namespace.servicebus.windows.net%2fraw%2f&sig=SIGNATURE%3d" "&se=1804371161&skn=writer" } ] }, "metrics": { "resourceId": "/subscriptions/1111-2222-3333-4444/resourcegroups/RG1/compute/foo", "metricAggregation": [ {"scheduledTransferPeriod": "PT5M"}, {"scheduledTransferPeriod": "PT1H"}, ] }, "performanceCounters": { "sinks": "sink1", "performanceCounterConfiguration": [ { "type": "builtin", "class": "Processor", "counter": "PercentIdleTime", "counterSpecifier": "/builtin/Processor/PercentIdleTime", "condition": "IsAggregate=TRUE", "sampleRate": "PT15S", "unit": "Percent", "annotation": [ { "displayName": "Aggregate CPU %idle time", "locale": "en-us" } ] } ] }, "syslogEvents": { "syslogEventConfiguration": { "LOG_LOCAL1": "LOG_INFO", "LOG_MAIL": "LOG_FATAL" } } }, "sampleRateInSeconds": 60 } test_sinks_config = \ { "sink": [ { "name": "sink1", "type": "EventHub", "sasURL": "https://sbnamespace.servicebus.windows.net/raw?sr=https%3a%2f%2fsb" "namespace.servicebus.windows.net%2fraw%2f&sig=SIGNATURE%3d" "&se=1804371161&skn=writer" } ] } configurator = load_test_config(test_lad_settings_logging_json_file) configurator._sink_configs.insert_from_config(test_sinks_config) configurator._update_metric_collection_settings(test_config) print ET.tostring(configurator._mdsd_config_xml_tree.getroot()) if __name__ == '__main__': unittest.main() ================================================ FILE: Diagnostic/tests/test_lad_ext_settings.py ================================================ import json import unittest from Utils.lad_ext_settings import * class LadExtSettingsTest(unittest.TestCase): def setUp(self): 
handler_settings_sample_in_str = """ { "protectedSettings": { "storageAccountName": "mystgacct", "storageAccountSasToken": "SECRET", "sinksConfig": { "sink": [ { "type": "JsonBlob", "name": "JsonBlobSink1" }, { "type": "JsonBlob", "name": "JsonBlobSink2" }, { "type": "EventHub", "name": "EventHubSink1", "sasURL": "SECRET" }, { "type": "EventHub", "name": "EventHubSink2", "sasURL": "SECRET" } ] } }, "publicSettings": { "StorageAccount": "mystgacct", "sampleRateInSeconds": 15, "fileLogs": [ { "sinks": "EventHubSink1", "file": "/var/log/myladtestlog" } ] } } """ self._lad_settings = LadExtSettings(json.loads(handler_settings_sample_in_str)) def test_redacted_handler_settings(self): expected = """ { "protectedSettings": { "sinksConfig": { "sink": [ { "name": "JsonBlobSink1", "type": "JsonBlob" }, { "name": "JsonBlobSink2", "type": "JsonBlob" }, { "name": "EventHubSink1", "sasURL": "REDACTED_SECRET", "type": "EventHub" }, { "name": "EventHubSink2", "sasURL": "REDACTED_SECRET", "type": "EventHub" } ] }, "storageAccountName": "mystgacct", "storageAccountSasToken": "REDACTED_SECRET" }, "publicSettings": { "StorageAccount": "mystgacct", "fileLogs": [ { "file": "/var/log/myladtestlog", "sinks": "EventHubSink1" } ], "sampleRateInSeconds": 15 } } """ actual_json = json.loads(self._lad_settings.redacted_handler_settings()) print json.dumps(actual_json, sort_keys=True, indent=2) self.assertEqual(json.dumps(json.loads(expected), sort_keys=True), json.dumps(actual_json, sort_keys=True)) # Validate that the original wasn't modified (that is, redaction should be on a deep copy) print "===== Original handler setting (shouldn't be redacted, must be different from the deep copy) =====" print json.dumps(self._lad_settings.get_handler_settings(), sort_keys=True, indent=2) self.assertNotEqual(json.dumps(self._lad_settings.get_handler_settings(), sort_keys=True), json.dumps(actual_json, sort_keys=True)) if __name__ == '__main__': unittest.main() 
================================================ FILE: Diagnostic/tests/test_lad_logging_config.py ================================================ import unittest import json from xml.etree import ElementTree as ET # This test suite uses xmlunittest package. Install it by running 'pip install xmlunittest'. # Documentation at http://python-xmlunittest.readthedocs.io/en/latest/ from xmlunittest import XmlTestMixin from Utils.lad_logging_config import * from Utils.omsagent_util import get_syslog_ng_src_name from Utils.mdsd_xml_templates import entire_xml_cfg_tmpl import Utils.LadDiagnosticUtil as LadUtil from tests.test_lad_config_all import mock_encrypt_secret class LadLoggingConfigTest(unittest.TestCase, XmlTestMixin): def setUp(self): """ Create LadLoggingConfig objects for use by test cases """ # "syslogEvents" LAD config example syslogEvents_json_ext_settings = """ { "sinks": "SyslogJsonBlob,SyslogEventHub", "syslogEventConfiguration": { "LOG_LOCAL0": "LOG_CRIT", "LOG_USER": "LOG_ERR" } } """ # "fileLogs" LAD config example fileLogs_json_ext_settings = """ [ { "file": "/var/log/mydaemonlog1", "table": "MyDaemon1Events", "sinks": "Filelog1JsonBlob,FilelogEventHub" }, { "file": "/var/log/mydaemonlog2", "table": "MyDaemon2Events", "sinks": "Filelog2JsonBlob" } ] """ # "sinksConfig" LAD config example sinksConfig_json_ext_settings = """ { "sink": [ { "name": "SyslogEventHub", "type": "EventHub", "sasURL": "https://fake&sas%url;for_syslog_eh" }, { "name": "SyslogJsonBlob", "type": "JsonBlob" }, { "name": "FilelogEventHub", "type": "EventHub", "sasURL": "https://fake&sas%url;for_filelog_eh" }, { "name": "Filelog1JsonBlob", "type": "JsonBlob" }, { "name": "Filelog2JsonBlob", "type": "JsonBlob" } ] } """ sinksConfig = LadUtil.SinkConfiguration() sinksConfig.insert_from_config(json.loads(sinksConfig_json_ext_settings)) syslogEvents = json.loads(syslogEvents_json_ext_settings) mock_pkey_path = "/waagent/dir/mock_pkey.prv" mock_cert_path = "/waagent/dir/mock_cert.crt" 
self.cfg_syslog = LadLoggingConfig(syslogEvents, None, sinksConfig, mock_pkey_path, mock_cert_path, mock_encrypt_secret) fileLogs = json.loads(fileLogs_json_ext_settings) self.cfg_filelog = LadLoggingConfig(None, fileLogs, sinksConfig, mock_pkey_path, mock_cert_path, mock_encrypt_secret) self.cfg_none = LadLoggingConfig(None, None, sinksConfig, mock_pkey_path, mock_cert_path, mock_encrypt_secret) # XPaths representations of expected XML outputs, for use with xmlunittests package self.oms_syslog_expected_xpaths = ('./Sources/Source[@name="mdsd.syslog" and @dynamic_schema="true"]', './Events/MdsdEvents/MdsdEventSource[@source="mdsd.syslog"]', './Events/MdsdEvents/MdsdEventSource[@source="mdsd.syslog"]/RouteEvent[@dontUsePerNDayTable="true" and @eventName="LinuxSyslog" and @priority="High"]', './Events/MdsdEvents/MdsdEventSource[@source="mdsd.syslog"]/RouteEvent[@dontUsePerNDayTable="true" and @eventName="SyslogJsonBlob" and @priority="High" and @storeType="JsonBlob"]', './EventStreamingAnnotations/EventStreamingAnnotation[@name="mdsd.syslog"]/EventPublisher/Key', # TODO Perform CDATA validation ) self.oms_filelog_expected_xpaths = ('./Sources/Source[@name="mdsd.filelog.var.log.mydaemonlog1" and @dynamic_schema="true"]', './Sources/Source[@name="mdsd.filelog.var.log.mydaemonlog2" and @dynamic_schema="true"]', './Events/MdsdEvents/MdsdEventSource[@source="mdsd.filelog.var.log.mydaemonlog1"]', './Events/MdsdEvents/MdsdEventSource[@source="mdsd.filelog.var.log.mydaemonlog1"]/RouteEvent[@dontUsePerNDayTable="true" and @eventName="MyDaemon1Events" and @priority="High"]', './Events/MdsdEvents/MdsdEventSource[@source="mdsd.filelog.var.log.mydaemonlog1"]/RouteEvent[@dontUsePerNDayTable="true" and @eventName="Filelog1JsonBlob" and @priority="High" and @storeType="JsonBlob"]', './Events/MdsdEvents/MdsdEventSource[@source="mdsd.filelog.var.log.mydaemonlog2"]', './Events/MdsdEvents/MdsdEventSource[@source="mdsd.filelog.var.log.mydaemonlog2"]/RouteEvent[@dontUsePerNDayTable="true" 
and @eventName="MyDaemon2Events" and @priority="High"]', './Events/MdsdEvents/MdsdEventSource[@source="mdsd.filelog.var.log.mydaemonlog2"]/RouteEvent[@dontUsePerNDayTable="true" and @eventName="Filelog2JsonBlob" and @priority="High" and @storeType="JsonBlob"]', './EventStreamingAnnotations/EventStreamingAnnotation[@name="mdsd.filelog.var.log.mydaemonlog1"]/EventPublisher/Key', # TODO Perform CDATA validation ) def test_oms_syslog_mdsd_configs(self): """ Test whether syslog/syslog-ng config (for use with omsagent) is correctly generated for both 'syslogEvents' and 'syslogCfg' settings. Also test whether the coresponding mdsd XML config is correctly generated. """ # Basic config (single dest table) self.__helper_test_oms_syslog_mdsd_configs(self.cfg_syslog, self.oms_syslog_expected_xpaths) # No syslog config case self.assertFalse(self.cfg_none.get_rsyslog_config()) self.assertFalse(self.cfg_none.get_syslog_ng_config()) self.assertFalse(self.cfg_none.get_mdsd_syslog_config()) def __helper_test_oms_syslog_mdsd_configs(self, cfg, expected_xpaths): """ Helper for test_oms_rsyslog(). :param cfg: SyslogMdsdConfig object containing syslog config """ print '=== Actual oms rsyslog config output ===' oms_rsyslog_config = cfg.get_rsyslog_config() print oms_rsyslog_config print '========================================' lines = oms_rsyslog_config.strip().split('\n') # Item (line) count should match self.assertEqual(len(cfg._fac_sev_map), len(lines)) # Each line should be correctly formatted for l in lines: self.assertRegexpMatches(l, r"\w+\.\w+\s+@127\.0\.0\.1:%SYSLOG_PORT%") # For each facility-severity, there should be corresponding line. 
for fac, sev in cfg._fac_sev_map.iteritems(): index = oms_rsyslog_config.find('{0}.{1}'.format(syslog_name_to_rsyslog_name(fac), syslog_name_to_rsyslog_name(sev))) self.assertGreaterEqual(index, 0) print "*** Actual output verified ***\n" print '=== Actual oms syslog-ng config output ===' oms_syslog_ng_config = cfg.get_syslog_ng_config() print oms_syslog_ng_config print '==========================================' lines = oms_syslog_ng_config.strip().split('\n') # Item (line) count should match self.assertGreaterEqual(len(lines), len(cfg._fac_sev_map)) # Each line should be correctly formatted for l in lines: self.assertRegexpMatches(l, r'log \{{ source\({0}\); filter\(f_LAD_oms_f_\w+\); ' r'filter\(f_LAD_oms_ml_\w+\); destination\(d_LAD_oms\); \}}' .format(get_syslog_ng_src_name())) # For each facility-severity, there should be corresponding line. for fac, sev in cfg._fac_sev_map.iteritems(): index = oms_syslog_ng_config.find('log {{ source({0}); filter(f_LAD_oms_f_{1}); filter(f_LAD_oms_ml_{2}); ' 'destination(d_LAD_oms); }}'.format(get_syslog_ng_src_name(), syslog_name_to_rsyslog_name(fac), syslog_name_to_rsyslog_name(sev))) self.assertGreaterEqual(index, 0) print "*** Actual output verified ***\n" print '=== Actual oms syslog mdsd XML output ===' xml = cfg.get_mdsd_syslog_config() print xml print '=========================================' root = self.assertXmlDocument(xml) self.assertXpathsOnlyOne(root, expected_xpaths) print "*** Actual output verified ***\n" def test_oms_filelog_mdsd_config(self): """ Test whether mdsd XML config for LAD fileLog settings is correctly generated. 
""" print '=== Actual oms filelog mdsd XML config output ===' xml = self.cfg_filelog.get_mdsd_filelog_config() print xml print '=================================================' root = self.assertXmlDocument(xml) self.assertXpathsOnlyOne(root, self.oms_filelog_expected_xpaths) print "*** Actual output verified ***\n" # Other configs should be all '' self.assertFalse(self.cfg_syslog.get_mdsd_filelog_config()) self.assertFalse(self.cfg_none.get_mdsd_filelog_config()) def __helper_test_oms_fluentd_config(self, header_text, expected, actual): header = "=== Actual output of {0} ===".format(header_text) print header print actual print '=' * len(header) # TODO BADBAD exact string matching... self.assertEqual(expected, actual) pass def test_oms_fluentd_configs(self): """ Test whether fluentd syslog/tail source configs & out_mdsd config are correctly generated. """ actual = self.cfg_syslog.get_fluentd_syslog_src_config() expected = """ type syslog port %SYSLOG_PORT% bind 127.0.0.1 protocol_type udp include_source_host true tag mdsd.syslog # Generate fields expected for existing mdsd syslog collection schema. 
type record_transformer enable_ruby # Fields for backward compatibility with Azure Shoebox V1 (Table storage) Ignore "syslog" Facility ${tag_parts[2]} Severity ${tag_parts[3]} EventTime ${time.strftime('%Y-%m-%dT%H:%M:%S%z')} SendingHost ${record["source_host"]} Msg ${record["message"]} # Rename 'host' key, as mdsd will add 'Host' for Azure Table and it'll be confusing hostname ${record["host"]} remove_keys host,message,source_host # Renamed (duplicated) fields, so just remove """ self.__helper_test_oms_fluentd_config('fluentd basic syslog src config', expected, actual) actual = self.cfg_filelog.get_fluentd_syslog_src_config() expected = '' self.__helper_test_oms_fluentd_config('fluentd syslog src config for no syslog', expected, actual) actual = self.cfg_syslog.get_fluentd_out_mdsd_config() expected_out_mdsd_cfg_template = r""" # Output to mdsd type mdsd log_level warn djsonsocket /var/run/mdsd/lad_mdsd_djson.socket # Full path to mdsd dynamic json socket file acktimeoutms 5000 # max time in milli-seconds to wait for mdsd acknowledge response. If 0, no wait. 
{optional_lines} num_threads 1 buffer_chunk_limit 1000k buffer_type file buffer_path /var/opt/microsoft/omsagent/LAD/state/out_mdsd*.buffer buffer_queue_limit 128 flush_interval 10s retry_limit 3 retry_wait 10s """ out_mdsd_optional_config_lines = r""" mdsd_tag_regex_patterns [ "^mdsd\\.syslog" ] # fluentd tag patterns whose match will be used as mdsd source name """ self.__helper_test_oms_fluentd_config('fluentd out_mdsd config for basic syslog cfg', expected_out_mdsd_cfg_template.format( optional_lines=out_mdsd_optional_config_lines), actual) actual = self.cfg_filelog.get_fluentd_filelog_src_config() expected = """ # For all monitored files @type tail path /var/log/mydaemonlog1,/var/log/mydaemonlog2 pos_file /var/opt/microsoft/omsagent/LAD/tmp/filelogs.pos tag mdsd.filelog.* format none message_key Msg # LAD uses "Msg" as the field name # Add FileTag field (existing LAD behavior) @type record_transformer FileTag ${tag_suffix[2]} """ self.__helper_test_oms_fluentd_config('fluentd tail src config for fileLogs', expected, actual) actual = self.cfg_filelog.get_fluentd_out_mdsd_config() self.__helper_test_oms_fluentd_config('fluentd out_mdsd config for filelog only (no syslog) cfg', expected_out_mdsd_cfg_template.format(optional_lines=''), actual) actual = self.cfg_none.get_fluentd_out_mdsd_config() self.__helper_test_oms_fluentd_config('fluentd out_mdsd config for blank cfg (syslog disabled)', expected_out_mdsd_cfg_template.format(optional_lines=''), actual) def test_copy_schema_source_mdsdevent_eh_url_elems(self): """ Tests whether copy_schema_source_mdsdevent_eh_url_elems() works fine. Uses oms_syslog_expected_xpaths and oms_filelog_expected_xpaths XPath lists to test the operation. 
""" xml_string_srcs = [ self.cfg_syslog.get_mdsd_syslog_config(), self.cfg_filelog.get_mdsd_filelog_config() ] dst_xml_tree = ET.ElementTree(ET.fromstring(entire_xml_cfg_tmpl)) map(lambda x: copy_source_mdsdevent_eh_url_elems(dst_xml_tree, x), xml_string_srcs) print '=== mdsd config XML after combining syslog/filelogs XML configs ===' xml = ET.tostring(dst_xml_tree.getroot()) print xml print '===================================================================' # Verify using xmlunittests root = self.assertXmlDocument(xml) self.assertXpathsOnlyOne(root, self.oms_syslog_expected_xpaths) self.assertXpathsOnlyOne(root, self.oms_filelog_expected_xpaths) print "*** Actual output verified ***\n" if __name__ == '__main__': unittest.main() ================================================ FILE: Diagnostic/tests/var_lib_waagent/lad_dir/config/lad_settings_logging.json ================================================ { "runtimeSettings": [ { "handlerSettings": { "publicSettings": { "StorageAccount": "ladunittestdiag487", "ladCfg": { "diagnosticMonitorConfiguration": { "syslogEvents": { "sinks": "SyslogJsonBlob,SyslogEventHub", "syslogEventConfiguration": { "LOG_USER": "LOG_ERR", "LOG_LOCAL0": "LOG_CRIT" } } } }, "fileLogs" : [ { "file": "/var/log/mydaemonlog1", "table": "MyDaemon1Events", "sinks": "Filelog1JsonBlob,FilelogEventHub" }, { "file": "/var/log/mydaemonlog2", "sinks": "Filelog2JsonBlob" } ], "perfCfg": [ {"query": "SELECT PercentAvailableMemory, AvailableMemory, UsedMemory, PercentUsedSwap FROM SCX_MemoryStatisticalInformation", "table": "LinuxMemory"}, {"query": "SELECT PercentProcessorTime, PercentIOWaitTime, PercentIdleTime FROM SCX_ProcessorStatisticalInformation WHERE Name='_TOTAL'", "table": "LinuxCpu"}, {"query": "SELECT AverageWriteTime,AverageReadTime,ReadBytesPerSecond,WriteBytesPerSecond FROM SCX_DiskDriveStatisticalInformation WHERE Name='_TOTAL'", "table": "LinuxDisk"} ] }, "protectedSettingsCertThumbprint": "B175B535DFE9F93659E5AFD893BF99BBF9DF28A5", 
"protectedSettings": { "storageAccountName":"ladunittestfakeaccount", "storageAccountSasToken":"NOT_A_REAL_TOKEN", "storageAccountEndPoint":"https://core.windows.net/", "sinksConfig": { "sink": [ { "sasURL": "https://fake_sas_url_1", "type": "EventHub", "name": "SyslogEventHub" }, { "type": "JsonBlob", "name": "SyslogJsonBlob" }, { "sasURL": "https://fake_sas_url_2", "type": "EventHub", "name": "FilelogEventHub" }, { "type": "JsonBlob", "name": "Filelog1JsonBlob" }, { "type": "JsonBlob", "name": "Filelog2JsonBlob" } ] } } } } ] } ================================================ FILE: Diagnostic/tests/var_lib_waagent/lad_dir/config/lad_settings_metric.json ================================================ { "runtimeSettings": [ { "handlerSettings": { "protectedSettings": { "storageAccountEndPoint": "https://core.windows.net/", "storageAccountSasToken": "?NOT_A_REAL_TOKEN", "storageAccountName": "ladunittestfakeaccount", "sinksConfig": { "sink": [ { "sasURL": "https://fake&sas%url;1", "type": "EventHub", "name": "LinuxMemoryEventHub" }, { "type": "JsonBlob", "name": "SyslogJsonBlob" }, { "sasURL": "https://fake&sas%url;2", "type": "EventHub", "name": "ProcessorInfoEventHub" }, { "type": "JsonBlob", "name": "ProcessorInfoJsonBlob" }, { "type": "JsonBlob", "name": "FileSystemJsonBlob" } ] } }, "protectedSettingsCertThumbprint": "B175B535DFE9F93659E5AFD893BF99BBF9DF28A5", "publicSettings": { "ladCfg": { "diagnosticMonitorConfiguration": { "eventVolume": "Large", "metrics": { "resourceId": "ladtest_resource_id", "metricAggregation": [ { "scheduledTransferPeriod": "PT1H" }, { "scheduledTransferPeriod": "PT1M" } ] }, "performanceCounters": { "performanceCounterConfiguration": [ { "class": "Processor", "condition": "IsAggregate=TRUE", "annotation": [ { "displayName": "Aggregate CPU %utilization", "locale": "en-us" } ], "counterSpecifier": "/builtin/processor/PercentProcessorTime", "counter": "percentprocessorTime", "type": "builtin", "unit": "Percent" }, { "class": 
"Filesystem", "condition": "Name=\"/\"", "annotation": [ { "displayName": "Used disk space on /", "locale": "en-us" } ], "counterSpecifier": "/builtin/filesystem/usedspace", "counter": "UsedSpace", "type": "builtin", "unit": "Bytes" }, { "class": "Filesystem", "condition": "Name='/'", "annotation": [ { "displayName": "Free disk space on /mnt", "locale": "en-us" } ], "counterSpecifier": "/builtin/filesystem/freespace(/mnt)", "counter": "FreeSpace", "type": "builtin", "unit": "Bytes" } ] } } }, "perfCfg": [ { "query": "SELECT PercentAvailableMemory, PercentUsedSwap FROM SCX_MemoryStatisticalInformation", "table": "LinuxMemory", "sinks": "LinuxMemoryEventHub" }, { "query": "SELECT PercentProcessorTime FROM SCX_ProcessorStatisticalInformation", "sinks": "ProcessorInfoJsonBlob,ProcessorInfoEventHub", "frequency": 60 }, { "query": "SELECT FreeMegabytes FROM SCX_FileSystemStatisticalInformation", "table": "LinuxFileSystem", "sinks": "FileSystemJsonBlob" } ], "sampleRateInSeconds": 15, "StorageAccount": "ladtest" } } } ] } ================================================ FILE: Diagnostic/tests/watchertests.py ================================================ import unittest import diagnostic import sys import subprocess import os import errno import watcherutil class FStabUnitTests(unittest.TestCase): _watcher = None _datapath = os.path.join(os.getcwd(), 'utdata') def setUp(self): self._watcher = watcherutil.Watcher(sys.stderr, sys.stdout) try: os.mkdir(self._datapath) except OSError as e: if e.errno != errno.EEXIST: raise pass # mount an overlay so that we can make changes to /etc/fstab subprocess.call(['sudo', 'mount', '-t', 'overlayfs', 'overlayfs', '-olowerdir=/etc,upperdir=' + self._datapath, '/etc']) pass def tearDown(self): subprocess.call(['sudo', 'umount', '/etc']) try: os.rmdir(self._datapath) except OSError as e: pass def test_fstab_basic(self): self.assertEqual(self._watcher.handle_fstab(ignore_time=True), 0) def test_fstab_touch(self): subprocess.call(['sudo', 
'touch', '/etc/fstab']) self.assertEqual(self._watcher.handle_fstab(ignore_time=True), 0) def addFstabEntry(self, fstabentry): with open(self._datapath + '/fstab', 'w') as f: f.write(fstabentry) f.write('\n') @unittest.skip('Skipping because mount -f fails to detect error') def test_fstab_baduuid(self): self.addFstabEntry('UUID=1111111-1111-1111-1111-111111111111 /test ext4 defaults 0 0') pdb.set_trace() self.assertNotEqual(self._watcher.handle_fstab(ignore_time=True), 0) @unittest.skip('Skipping because mount -f fails to detect error') def test_fstab_baddevicename(self): self.addFstabEntry('/dev/foobar /test ext4 defaults 0 0') self.assertNotEqual(self._watcher.handle_fstab(ignore_time=True), 0) @unittest.skip('Skipping because mount -f fails to detect error') def test_fstab_malformedentry(self): self.addFstabEntry('/test /dev/foobar ext4 defaults 0 0') self.assertNotEqual(self._watcher.handle_fstab(ignore_time=True), 0) def test_fstab_goodentry(self): self.addFstabEntry('/dev/sdb1 /test ext4 defaults 0 0') self.assertEqual(self._watcher.handle_fstab(ignore_time=True), 0) if __name__ == '__main__': unittest.main() ================================================ FILE: Diagnostic/virtual-machines-linux-diagnostic-extension-v3.md ================================================ --- title: Page title that displays in the browser tab and search results | Microsoft Docs description: Article description that will be displayed on landing pages and in most search results services: virtual-machines-linux documentationcenter: dev-center-name author: jasonzio manager: anandram ms.service: virtual-machines-linux ms.devlang: may be required ms.topic: article ms.tgt_pltfrm: vm-linux ms.workload: required ms.date: 04/21/2017 ms.author: jasonzio@microsoft.com --- # Use Linux Diagnostic Extension v3 to monitor metrics and logs ## Introduction The Linux Diagnostic Extension helps a user monitor the health of a Linux VM running on Microsoft Azure. 
It has the following capabilities: * Collects system performance metrics from the VM and stores them in a specific table in a designated storage account (usually the account in which the VM's boot vhd is stored). * Retrieves log events from syslog and stores them in a specific table in the designated storage account. * Enables users to customize the data metrics that will be collected and uploaded. * Enables users to customize the syslog facilities and severity levels of events that will be collected and uploaded. * Enables users to upload specified log files to a designated storage table. * Supports sending the above data to arbitrary EventHub endpoints and JSON-formatted blobs in the designated storage account. This extension works with both the classic and Resource Manager deployment models. ### Migration from previous versions of the extension The latest version of the extension is **3.0**. **Any old versions (2.x) will be deprecated and may be unpublished on or after 2018-07-31**. This extension introduces breaking changes to the configuration of the extension. One such change was made to improve the security of the extension; as a result, backwards compatibility with 2.x could not be maintained. Also, the Extension Publisher for this extension is different than the publisher for the 2.x versions. In order to migrate from 2.x to this new version of the extension, you must uninstall the old extension (under the old publisher name) and then install the new extension. We strongly recommend that you install the extension with automatic minor version upgrade enabled. On classic (ASM) VMs, you can achieve this by specifying '3.*' as the version if you are installing the extension through Azure XPLAT CLI or Powershell. On ARM VMs, you can achieve this by including '"autoUpgradeMinorVersion": true' in the VM deployment template. 
## Enable the extension You can enable this extension by using the [Azure portal](https://portal.azure.com/#), Azure PowerShell, or Azure CLI scripts. Use the Azure portal to view performance data directly from the Azure portal: ![image](./media/virtual-machines-linux-diagnostic-extension-v3/graph_metrics.png) This article focuses on how to enable and configure the extension by using Azure CLI commands. Only a subset of the features of the extension can be configured via the Azure portal, which will ignore (and leave unchanged) the parts of the configuration it does not address. ## Prerequisites * **Azure Linux Agent version 2.2.0 or later**. Note that most Azure VM Linux gallery images include version 2.2.7 or later. You can run **/usr/sbin/waagent -version** to confirm which version is installed on the VM. If the VM is running an older version of the guest agent, you can follow [these instructions on GitHub](https://github.com/Azure/WALinuxAgent "instructions") to update it. * **Azure CLI**. Follow [this guidance for installing CLI](../xplat-cli-install.md) to set up the Azure CLI environment on your machine. After Azure CLI is installed, you can use the **azure** command from your command-line interface (Bash, Terminal, or command prompt) to access the Azure CLI commands. For example: * Run **azure vm extension set --help** for detailed help information. * Run **azure login** to sign in to Azure. * Run **azure vm list** to list all the virtual machines that you have on Azure. * A storage account to store the data. You will need a storage account name that was created previously and an account SAS token to upload the data to your storage. ## Protected Settings This set of configuration information contains sensitive information which should be protected from public view, e.g. storage credentials. These settings are transmitted to and stored by the extension in encrypted form. 
```json { "storageAccountName" : "the storage account to receive data", "storageAccountEndPoint": "the URL prefix for the cloud for this account", "storageAccountSasToken": "SAS access token", "mdsdHttpProxy": "HTTP proxy settings", "sinksConfig": { ... } } ``` Name | Value ---- | ----- storageAccountName | The name of the storage account in which data will be written by the extension storageAccountEndPoint | (optional) The endpoint identifying the cloud in which the storage account exists. For the Azure public cloud (which is the default when this setting is not given), this would be [https://core.windows.net](https://core.windows.net); set this appropriately for a storage account in a national cloud. storageAccountSasToken | An [Account SAS token](https://azure.microsoft.com/en-us/blog/sas-update-account-sas-now-supports-all-storage-services/) for Blob and Table services (ss='bt'), containers and objects (srt='co'), which grants add, create, list, update, and write permissions (sp='acluw') mdsdHttpProxy | (optional) HTTP proxy information needed to enable the extension to connect to the specified storage account and endpoint. sinksConfig | (optional) Details of alternative destinations to which metrics and events can be delivered. The specific details of the various data sinks supported by the extension are covered below. You can easily construct the required SAS token through the Azure portal. Select the general-purpose storage account which you want the extension to write, then select "Shared access signature" from the Settings part of the left menu. Make the appropriate choices as described above and click the "Generate SAS" button. ![image](./media/virtual-machines-linux-diagnostic-extension-v3/makeSAS.png) Copy the generated SAS into the storageAccountSasToken field; remove the leading question-mark ("?"). ### sinksConfig ```json "sinksConfig": { "sink": [ { "name": "sinkname", "type": "sinktype", ... }, ... 
] }, ``` This section defines additional destinations to which the extension will deliver the information it collects. The "sink" array contains an object for each additional data sink. The object will contain additional attributes as determined by the "type" attribute. Element | Value ------- | ----- name | A string used to refer to this sink elsewhere in the extension configuration. type | The type of sink being defined. Determines the other values (if any) in instances of this type. Version 3.0 of the Linux Diagnostic Extension supports two sink types: EventHub, and JsonBlob. #### The EventHub sink ```json "sink": [ { "name": "sinkname", "type": "EventHub", "sasUrl": "https SAS URL" }, ... ] ``` The "sasURL" entry contains the full URL, including SAS token, for the EventHub endpoint to which data should be published. The SAS URL should be built using the EventHub endpoint (policy-level) shared key, not the root-level shared key for the entire EventHub subscription. Event Hubs SAS tokens are different from Storage SAS tokens; details can be found [on this web page](https://docs.microsoft.com/en-us/rest/api/eventhub/generate-sas-token). #### The JsonBlob sink ```json "sink": [ { "name": "sinkname", "type": "JsonBlob" }, ... ] ``` Data directed to a JsonBlob sink will be stored in blobs in a container with the same name as the sink. The Azure storage rules for blob container names apply to the names of JsonBlob sinks: between 3 and 63 lower-case alphanumeric ASCII characters or dashes. Individual blobs will be created every hour for each instance of the extension writing to the container. The blobs will always contain a syntactically-valid JSON object; new entries are added atomically. ## Public settings This structure contains various blocks of settings which control the information collected by the extension. ```json { "mdsdHttpProxy" : "", "ladCfg": { ... }, "perfCfg": { ... }, "fileLogs": { ... 
} } ``` Element | Value ------- | ----- mdsdHttpProxy | (optional) Same as in the Private Settings (see above). The public value is overridden by the private value, if set. If the proxy setting contains a secret (like a password), it shouldn't be specified here, but should be specified in the Private Settings. The remaining elements are described in detail, below. ### ladCfg ```json "ladCfg": { "diagnosticMonitorConfiguration": { "eventVolume": "Medium", "metrics": { ... }, "performanceCounters": { ... }, "syslogEvents": { ... } }, "sampleRateInSeconds": 15 } ``` Controls the gathering of metrics and logs for delivery to the Azure Metrics service and to other data destinations ("sinks"). All settings in this section, with the exception of eventVolume, can be controlled via the Azure portal as well as through PowerShell, CLI, or template. The Azure Metrics service requires metrics to be stored in a very particular Azure storage table. Similarly, log events must be stored in a different, but also very particular, table. All instances of the diagnostic extension configured (via Private Config) to use the same storage account name and endpoint will add their metrics and logs to the same table. If too many VMs are writing to the same table partition, Azure can throttle writes to that partition. The eventVolume setting changes how partition keys are constructed so that, across all instances of the extension writing to the same table, entries are spread across 1, 10, or 100 different partitions. Element | Value ------- | ----- eventVolume | Controls the number of partitions created within the storage table. Must be one of "Large", "Medium", or "Small". sampleRateInSeconds | The default interval between collection of raw (unaggregated) metrics. The smallest supported sample rate is 15 seconds. 
#### metrics ```json "metrics": { "resourceId": "/subscriptions/...", "metricAggregation" : [ { "scheduledTransferPeriod" : "PT1H" }, { "scheduledTransferPeriod" : "PT5M" } ] } ``` Samples of the metrics specified in the performanceCounters section are periodically collected. Those raw samples are aggregated to produce mean, minimum, maximum, and last-collected values, along with the count of raw samples used to compute the aggregate. If multiple scheduledTransferPeriod frequencies appear (as in the example), each aggregation is computed independently over the specified interval. The name of the storage table to which aggregated metrics are written (and from which Azure Metrics reads data) is based, in part, on the transfer period of the aggregated metrics stored within it. Element | Value ------- | ----- resourceId | The ARM resource ID of the VM or of the VM Scale Set to which the VM belongs. This setting must be also specified if any JsonBlob sink is used in the configuration. scheduledTransferPeriod | The frequency at which aggregate metrics are to be computed and transferred to Azure Metrics, expressed as an IS 8601 time interval. The smallest transfer period is 60 seconds, i.e. PT60S or PT1M. Samples of the metrics specified in the performanceCounters section are collected every 15 seconds or at the sample rate explicitly defined for the counter. If multiple scheduledTransferPeriod frequencies appear (as in the example), each aggregation is computed independently. The name of the storage table to which aggregated metrics are written (and from which Azure Metrics reads data) is based, in part, on the transfer period of the aggregated metrics stored within it. 
#### performanceCounters ```json "performanceCounters": { "sinks": "", "performanceCounterConfiguration": [ { "type": "builtin", "class": "Processor", "counter": "PercentIdleTime", "counterSpecifier": "/builtin/Processor/PercentIdleTime", "condition": "IsAggregate=TRUE", "sampleRate": "PT15S", "unit": "Percent", "annotation": [ { "displayName" : "Aggregate CPU %idle time", "locale" : "en-us" } ], }, ] } ``` Element | Value ------- | ----- sinks | A comma-separated list of names of sinks (as defined in the sinksConfig section of the Private configuration file) to which aggregated metric results should be published. All aggregated metrics will be published to each listed sink. Example: "EHsink1,myjsonsink" type | Identifies the actual provider of the metric. class | Together with "counter", identifies the specific metric within the provider's namespace. counter | Together with "class", identifies the specific metric within the provider's namespace. counterSpecifier | Identifies the specific metric within the Azure Metrics namespace. condition | Selects a specific instance of the object to which the metric applies or selects the aggregation across all instances of that object. See the metric definitions (below) for more information. sampleRate | IS 8601 interval which sets the rate at which raw samples for this metric are collected. If not set, the collection interval is set by the value of sampleRateInSeconds (see "ladCfg"). The shortest supported sample rate is 15 seconds, i.e. PT15S. unit | Should be one of these strings: "Count", "Bytes", "Seconds", "Percent", "CountPerSecond", "BytesPerSecond", "Millisecond". Defines the unit for the metric. The consumer of the collected data will expect the data LAD collects to match this unit. LAD ignores this field. displayName | The label (in the language specified by the associated locale setting) to be attached to this data in Azure Metrics. LAD ignores this field. 
#### syslogEvents ```json "syslogEvents": { "sinks": "", "syslogEventConfiguration": { "facilityName1": "minSeverity", "facilityName2": "minSeverity", ... } } ``` The syslogEventConfiguration collection has one entry for each syslog facility of interest. Setting a minSeverity of "NONE" for a particular facility behaves exactly as if that facility did not appear in the element at all; no events from that facility are captured. Element | Value ------- | ----- sinks | A comma-separated list of names of sinks to which individual log events should be published. All log events matching the restrictions in syslogEventConfiguration will be published to each listed sink. Example: "EHforsyslog" facilityName | A syslog facility name (e.g. "LOG\_USER" or "LOG\_LOCAL0"). See the "facility" section of the [syslog man page](http://man7.org/linux/man-pages/man3/syslog.3.html) for the full list. minSeverity | A syslog severity level (e.g. "LOG\_ERR" or "LOG\_INFO"). See the "level" section of the [syslog man page](http://man7.org/linux/man-pages/man3/syslog.3.html) for the full list. The extension will capture events sent to the facility at or above the specified level. ### perfCfg Controls execution of arbitrary [OMI](https://github.com/Microsoft/omi) queries. ```json "perfCfg": [ { "namespace": "root/scx", "query": "SELECT PercentAvailableMemory, PercentUsedSwap FROM SCX_MemoryStatisticalInformation", "table": "LinuxOldMemory", "frequency": 300, "sinks": "" } ] ``` Element | Value ------- | ----- namespace | (optional) The OMI namespace within which the query should be executed. If unspecified, the default value is "root/scx", implemented by the [System Center Cross-platform Providers](http://scx.codeplex.com/wikipage?title=xplatproviders&referringTitle=Documentation). query | The OMI query to be executed. table | (optional) The Azure storage table, in the designated storage account (see above) into which the results of the query will be placed. 
frequency | (optional) The number of seconds between execution of the query. Default value is 300 (5 minutes); minimum value is 15 seconds. sinks | (optional) A comma-separated list of names of additional sinks to which raw sample metric results should be published. No aggregation of these raw samples is computed by the extension or by Azure Metrics. Either "table" or "sinks", or both, must be specified. ### fileLogs Controls the capture of log files by rsyslogd or syslog-ng. As new text lines are written to the file, rsyslogd/syslog-ng captures them and passes them to the diagnostic extension, which in turn writes them as table rows or to the specified sinks (JsonBlob or EventHub). ```json "fileLogs": [ { "file": "/var/log/mydaemonlog", "table": "MyDaemonEvents", "sinks": "" } ] ``` Element | Value ------- | ----- file | The full pathname of the log file to be watched and captured. The pathname must name a single file; it cannot name a directory or contain wildcards. table | (optional) The Azure storage table, in the designated storage account (see above), into which new lines from the "tail" of the file will be placed. sinks | (optional) A comma-separated list of names of additional sinks to which log lines should be published. Either "table" or "sinks", or both, must be specified. ## Metrics supported by "builtin" The "builtin" metric provider is a source of metrics most interesting to a broad set of users. These metrics fall into five broad classes: * Processor * Memory * Network * Filesystem * Disk The available metrics are described in greater detail in the following sections. ### Builtin metrics for the Processor class The Processor class of metrics provides information about processor usage in the VM. When aggregating percentages, the result is the average across all CPUs. 
For example, given a VM with two cores, if one core was 100% busy for a given aggregation window and the other core was 100% idle, the reported PercentIdleTime would be 50; if each core was 50% busy for the same period, the reported result would also be 50. In a four core system, with one core 100% busy and the others completely idle, the reported PercentIdleTime would be 75. counter | Meaning ------- | ------- PercentIdleTime | Percentage of time during the aggregation window that processors were executing the kernel idle loop PercentProcessorTime | Percentage of time executing a non-idle thread PercentIOWaitTime | Percentage of time waiting for IO operations to complete PercentInterruptTime | Percentage of time executing hardware/software interrupts and DPCs (deferred procedure calls) PercentUserTime | Of non-idle time during the aggregation window, the percentage of time spent in user more at normal priority PercentNiceTime | Of non-idle time, the percentage spent at lowered (nice) priority PercentPrivilegedTime | Of non-idle time, the percentage spent in privileged (kernel) mode The first four counters should sum to 100%. The last three counters also sum to 100%; they subdivide the sum of PercentProcessorTime, PercentIOWaitTime, and PercentInterruptTime. To obtain a single metric aggregated across all processors, set "condition" to "IsAggregate=TRUE". To obtain a metric for a specific processor, set "condition" to "Name=\\"*nn*\\"" where *nn* is the logical processor number as known to the operating system, typically in the range 0..*n-1*. ### Builtin metrics for the Memory class The Memory class of metrics provide information about memory utilization, paging, and swapping. 
counter | Meaning ------- | ------- AvailableMemory | Available physical memory in MiB PercentAvailableMemory | Available physical memory as a percent of total memory UsedMemory | In-use physical memory (MiB) PercentUsedMemory | In-use physical memory as a percent of total memory PagesPerSec | Total paging (read/write) PagesReadPerSec | Pages read from backing store (pagefile, program file, mapped file, etc) PagesWrittenPerSec | Pages written to backing store (pagefile, mapped file, etc) AvailableSwap | Unused swap space (MiB) PercentAvailableSwap | Unused swap space as a percentage of total swap UsedSwap | In-use swap space (MiB) PercentUsedSwap | In-use swap space as a percentage of total swap This family of metrics has only a single instance; the "condition" attribute has no useful settings and should be omitted. ### Builtin metrics for the Network class The Network class of metrics provide information about network activity, aggregated across all network devices (eth0, eth1, etc.) since boot. Bandwidth information is not directly available; it is best retrieved from host metrics rather than from within the guest. counter | Meaning ------- | ------- BytesTransmitted | Total bytes sent since boot BytesReceived | Total bytes received since boot BytesTotal | Total bytes sent or received since boot PacketsTransmitted | Total packets sent since boot PacketsReceived | Total packets received since boot TotalRxErrors | Number of receive errors since boot TotalTxErrors | Number of transmit errors since boot TotalCollisions | Number of collisions reported by the network ports since boot This family of metrics has only a single instance; the "condition" attribute has no useful settings and should be omitted. ### Builtin metrics for the Filesystem class The Filesystem class of metrics provide information about filesystem usage. Absolute and percentage values are reported as they'd be displayed to an ordinary user (not root). 
counter | Meaning ------- | ------- FreeSpace | Available disk space in bytes UsedSpace | Used disk space in bytes PercentFreeSpace | Percentage free space PercentUsedSpace | Percentage used space PercentFreeInodes | Percentage of unused inodes PercentUsedInodes | Percentage of allocated (in use) inodes summed across all filesystems BytesReadPerSecond | Bytes read per second BytesWrittenPerSecond | Bytes written per second BytesPerSecond | Bytes read or written per second ReadsPerSecond | Read operations per second WritesPerSecond | Write operations per second TransfersPerSecond | Read or write operations per second Aggregated values across all file systems can be obtained by setting "condition" to "IsAggregate=True". Values for a specific mounted file system can be obtained by setting "condition" to 'Name="*mountpoint*"' where *mountpoint* is the path at which the filesystem was mounted ("/", "/mnt", etc.). ### Builtin metrics for the Disk class The Disk class of metrics provide information about disk device usage. These statistics apply to the drive itself without regard to the number of file systems that may exist on the device; if there are multiple file systems on a device, the counters for that device are, effectively, aggregated across of them. counter | Meaning ------- | ------- ReadsPerSecond | Read operations per second WritesPerSecond | Write operations per second TransfersPerSecond | Total operations per second AverageReadTime | Average seconds per read operation AverageWriteTime | Average seconds per write operation AverageTransferTime | Average seconds per operation AverageDiskQueueLength | Average number of queued disk operations ReadBytesPerSecond | Number of bytes read per second WriteBytesPerSecond | Number of bytes written per second BytesPerSecond | Number of bytes read or written per second Aggregated values across all disks can be obtained by setting "condition" to "IsAggregate=True". 
Values for a specific disk device can be obtained by setting "condition" to "Name=\\"*devicename*\\"" where *devicename* is the path of the device file for the disk ("/dev/sda1", "/dev/sdb1", etc.). ## Installing and configuring LAD 3.0 via CLI Assuming your protected settings are in the file PrivateConfig.json and your public configuration information is in PublicConfig.json, run this command: > azure vm extension set *resource_group_name* *vm_name* LinuxDiagnostic Microsoft.Azure.Diagnostics '3.*' --private-config-path PrivateConfig.json --public-config-path PublicConfig.json Please note that the above command assumes you are in the Azure Resource Management mode (arm) of the Azure CLI and applies only to the Azure ARM VMs, not to any classic Azure VMs. For classic (or ASM, Azure Service Management) VMs, you'll need to set the CLI mode to "asm" (run `azure config mode asm`) before running the above command, and you should also omit the resource group name in the command (there is no notion of resource groups in ASM). For more information on different modes of Azure CLI and how to use them, please refer to related documentation like [this](https://docs.microsoft.com/en-us/azure/xplat-cli-connect). ## An example LAD 3.0 configuration Based on the above definitions, here's a sample LAD 3.0 extension configuration with some explanation. Please note that in order to apply this sample to your case, you should use your own storage account name, account SAS token, and EventHubs SAS tokens. 
First, the following private settings (that should be saved in a file as PrivateConfig.json, if you want to use the above Azure CLI command to enable the extension) will configure a storage account, its account SAS token, and various sinks (JsonBlob or EventHubs with SAS tokens): ```json { "storageAccountName": "yourdiagstgacct", "storageAccountSasToken": "sv=xxxx-xx-xx&ss=bt&srt=co&sp=wlacu&st=yyyy-yy-yyT21%3A22%3A00Z&se=zzzz-zz-zzT21%3A22%3A00Z&sig=fake_signature", "sinksConfig": { "sink": [ { "name": "SyslogJsonBlob", "type": "JsonBlob" }, { "name": "FilelogJsonBlob", "type": "JsonBlob" }, { "name": "LinuxCpuJsonBlob", "type": "JsonBlob" }, { "name": "WADMetricJsonBlob", "type": "JsonBlob" }, { "name": "LinuxCpuEventHub", "type": "EventHub", "sasURL": "https://youreventhubnamespace.servicebus.windows.net/youreventhubpublisher?sr=https%3a%2f%2fyoureventhubnamespace.servicebus.windows.net%2fyoureventhubpublisher%2f&sig=fake_signature&se=1808096361&skn=yourehpolicy" }, { "name": "WADMetricEventHub", "type": "EventHub", "sasURL": "https://youreventhubnamespace.servicebus.windows.net/youreventhubpublisher?sr=https%3a%2f%2fyoureventhubnamespace.servicebus.windows.net%2fyoureventhubpublisher%2f&sig=yourehpolicy&skn=yourehpolicy" }, { "name": "LoggingEventHub", "type": "EventHub", "sasURL": "https://youreventhubnamespace.servicebus.windows.net/youreventhubpublisher?sr=https%3a%2f%2fyoureventhubnamespace.servicebus.windows.net%2fyoureventhubpublisher%2f&sig=yourehpolicy&se=1808096361&skn=yourehpolicy" } ] } } ``` Then the following public settings (that should be saved in a file as PublicConfig.json for the Azure CLI command above) will do the following: * Uploads percent-processor-time and used-disk-space to Azure Metric service table (this will allow you to view these metrics in the Azure Portal), and your EventHub (as specified in your sink `WADMetricEventHub`) and your Azure Blob storage (container name is `wadmetricjsonblob`). 
* Uploads messages from syslog facility "user" and severity "info" or above to your Azure Table storage (always on by default, and the Azure Table name is `LinuxSyslog*`), your Azure Blob storage (container name is `syslogjsonblob*`), and your EventHubs publisher (as specified in your sink name `LoggingEventHub`). * Uploads raw OMI query results (PercentProcessorTime and PercentIdleTime) to your Azure Table storage (table name is `LinuxCpu*`), your Azure Blob storage (container name is `linuxcpujsonblob*`) and your EventHubs publisher (as specified in your sink name `LinuxCpuEventHub`). * Uploads appended lines in file `/var/log/myladtestlog` to your Azure Table storage (table name is MyLadTestLog\*), your Azure Blob storage (container name is `filelogjsonblob*`), and to your EventHubs publisher (as specified in your sink name `LoggingEventHub`). ```json { "StorageAccount": "yourdiagstgacct", "sampleRateInSeconds": 15, "ladCfg": { "diagnosticMonitorConfiguration": { "performanceCounters": { "sinks": "WADMetricEventHub,WADMetricJsonBlob", "performanceCounterConfiguration": [ { "unit": "Percent", "type": "builtin", "counter": "PercentProcessorTime", "counterSpecifier": "/builtin/Processor/PercentProcessorTime", "annotation": [ { "locale": "en-us", "displayName": "Aggregate CPU %utilization" } ], "condition": "IsAggregate=TRUE", "class": "Processor" }, { "unit": "Bytes", "type": "builtin", "counter": "UsedSpace", "counterSpecifier": "/builtin/FileSystem/UsedSpace", "annotation": [ { "locale": "en-us", "displayName": "Used disk space on /" } ], "condition": "Name=\"/\"", "class": "Filesystem" } ] }, "metrics": { "metricAggregation": [ { "scheduledTransferPeriod": "PT1H" }, { "scheduledTransferPeriod": "PT1M" } ], "resourceId": "/subscriptions/your_azure_subscription_id/resourceGroups/your_resource_group_name/providers/Microsoft.Compute/virtualMachines/your_vm_name" }, "eventVolume": "Large", "syslogEvents": { "sinks": "SyslogJsonBlob,LoggingEventHub", 
"syslogEventConfiguration": { "LOG_USER": "LOG_INFO" } } } }, "perfCfg": [ { "query": "SELECT PercentProcessorTime, PercentIdleTime FROM SCX_ProcessorStatisticalInformation WHERE Name='_TOTAL'", "table": "LinuxCpu", "frequency": 60, "sinks": "LinuxCpuJsonBlob,LinuxCpuEventHub" } ], "fileLogs": [ { "file": "/var/log/myladtestlog", "table": "MyLadTestLog", "sinks": "FilelogJsonBlob,LoggingEventHub" } ] } ``` Please note that you must provide the correct `resourceId` in order for the Azure Metrics service to display your `performanceCounters` data correctly in the Azure Portal charts. The resource ID is also used by JsonBlob sinks as well when forming the names of blobs. ## Configuring and enabling the extension for Azure Portal metrics charting experiences Here's a sample configuration (provided in the `wget` URL below), and installation instructions, that will configure LAD 3.0 to capture and store exactly the same metrics (actually file system metrics are newly added in LAD 3.0) as were provided by LAD 2.3 for Azure Portal VM metrics charting experiences (and default syslog collection as enabled on LAD 2.3). You should consider this just an example; you'll want to modify the metrics to suit your own needs. If you'd like to proceed, please execute the following commands on your Azure CLI terminal after [installing Azure CLI 2.0](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli) and wget (run `sudo apt-get install wget` on a Debian-based Linux disro or `sudo yum install wget` on a Redhat-based Linux distro). Also make sure to provide correct values for your Azure VM diagnostic paremeters in the first 3 lines. ```bash # Set your Azure VM diagnostic parameters correctly below my_resource_group= my_linux_vm= my_diagnostic_storage_account= # Should login to Azure first before anything else az login # Get VM resource ID as well, and replace storage account name and resource ID in the public settings. 
my_vm_resource_id=$(az vm show -g $my_resource_group -n $my_linux_vm --query "id" -o tsv) wget https://raw.githubusercontent.com/Azure/azure-linux-extensions/master/Diagnostic/tests/lad_2_3_compatible_portal_pub_settings.json -O portal_public_settings.json sed -i "s#__DIAGNOSTIC_STORAGE_ACCOUNT__#$my_diagnostic_storage_account#g" portal_public_settings.json sed -i "s#__VM_RESOURCE_ID__#$my_vm_resource_id#g" portal_public_settings.json # Set protected settings (storage account SAS token) my_diagnostic_storage_account_sastoken=$(az storage account generate-sas --account-name $my_diagnostic_storage_account --expiry 9999-12-31T23:59Z --permissions wlacu --resource-types co --services bt -o tsv) my_lad_protected_settings="{'storageAccountName': '$my_diagnostic_storage_account', 'storageAccountSasToken': '$my_diagnostic_storage_account_sastoken'}" # Finallly enable (set) the extension for the Portal metrics charts experience az vm extension set --publisher Microsoft.Azure.Diagnostics --name LinuxDiagnostic --version 3.0 --resource-group $my_resource_group --vm-name $my_linux_vm --protected-settings "${my_lad_protected_settings}" --settings portal_public_settings.json # Done ``` The URL and its contents are subject to change. You should download a copy of the portal settings JSON file and customize it for your needs; any templates or automation you construct should use your own copy, rather than downloading that URL each time. ### Important notes on customizing the downloaded `portal_public_settings.json` After experimenting with the downloaded `portal_public_settings.json` configuration as is, you may want to customize it for your own fit. For example, you may want to remove the entire `syslogEvents` section of the downloaded `portal_public_settings.json` if you don't need to collect syslog events at all. 
You can also remove unneeded entries in the `performanceCounterConfiguration` section of the downloaded `portal_public_settings.json` if you are not interested in some metrics. However, you should not modify other settings without fully understanding what they are and how they work. Only recommended customization at this point is to remove unwanted metrics or syslog events, and possibly changing the `displayName` values for metrics of your interest. ### Important notes on upgrading to LAD 3.0 from LAD 2.3 **Please use a new/different storage account for LAD 3.0** if you are upgrading from LAD 2.3. As mentioned earlier, you should uninstall LAD 2.3 first in order to upgrade to LAD 3.0, and if you specify the same storage account for LAD 3.0 as used in LAD 2.3, the syslog events collection with the new LAD 3.0 may not work because of a small change in LAD 3.0's syslog Azure Table name. Therefore, you should use a new storage account for LAD 3.0 if you still want to collect syslog events. ## Review your data The performance and diagnostic data are stored in an Azure Storage table by default. Review [How to use Azure Table Storage from Ruby](../storage/storage-ruby-how-to-use-table-storage.md) to learn how to access the data in the storage table by using Azure Table Storage Ruby API. Note that Azure Storage APIs are available in many other languages and platforms. If you specified JsonBlob sinks for your LAD extension configuration, then the same storage account's blob containers will hold your performance and/or diagnostic data. You can consume the blob data using any Azure Blob Storage APIs. In addition, you can use following UI tools to access the data in Azure Storage: 1. [Microsoft Azure Storage Explorer](http://storageexplorer.com/) 1. Visual Studio Server Explorer. 1. [Azure Storage Explorer](https://azurestorageexplorer.codeplex.com/ "Azure Storage Explorer"). 
The following is a snapshot of a Microsoft Azure Storage Explorer session showing the generated Azure Storage tables and containers from a correctly configured LAD 3.0 extension on a test VM. Note that the snapshot doesn't match exactly with the sample LAD 3.0 configuration provided above. ![image](./media/virtual-machines-linux-diagnostic-extension-v3/stg_explorer.png) If you specified EventHubs sinks for your LAD extension configuraiton, then you'll want to consume the published EventHubs messages following related EventHubs documentation. You may want to start from [here](https://docs.microsoft.com/en-us/azure/event-hubs/event-hubs-what-is-event-hubs). ================================================ FILE: Diagnostic/watcherutil.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Linux Azure Diagnostic Extension (Current version is specified in manifest.xml) # Copyright (c) Microsoft Corporation # All rights reserved. # MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit # persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the # Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. import subprocess import os import datetime import time import string import traceback class Watcher: """ A class that handles periodic monitoring activities that are requested for LAD to perform. The first such activity is to watch /etc/fstab and report (log to console) if there's anything wrong with that. There might be other such monitoring activities that will be added later. """ def __init__(self, hutil_error, hutil_log, log_to_console=False): """ Constructor. :param hutil_error: Error logging function (e.g., hutil.error). This is not a stream. :param hutil_log: Normal logging function (e.g., hutil.log). This is not a stream. :param log_to_console: Indicates whether to log any issues to /dev/console or not. """ # This is only for the /etc/fstab watcher feature. self._fstab_last_mod_time = os.path.getmtime('/etc/fstab') self._hutil_error = hutil_error self._hutil_log = hutil_log self._log_to_console = log_to_console self._imds_logger = None def _do_log_to_console_if_enabled(self, message): """ Write 'message' to console. Stolen from waagent LogToCon(). """ if self._log_to_console: try: with open('/dev/console', 'w') as console: message = filter(lambda x: x in string.printable, message) console.write(message.encode('ascii', 'ignore') + '\n') except IOError as e: self._hutil_error('Error writing to console. Exception={0}'.format(e)) def handle_fstab(self, ignore_time=False): """ Watches if /etc/fstab is modified and verifies if it's OK. Otherwise, report it in logs or to /dev/console. :param ignore_time: Disable the default logic of delaying /etc/fstab verification by 1 minute. This is to allow any test code to avoid waiting 1 minute unnecessarily. 
:return: None """ try_mount = False if ignore_time: try_mount = True else: current_mod_time = os.path.getmtime('/etc/fstab') current_mod_date_time = datetime.datetime.fromtimestamp(current_mod_time) # Only try to mount if it's been at least 1 minute since the # change to fstab was done, to prevent spewing out erroneous spew if (current_mod_time != self._fstab_last_mod_time and datetime.datetime.now() > current_mod_date_time + datetime.timedelta(minutes=1)): try_mount = True self._fstab_last_mod_time = current_mod_time ret = 0 if try_mount: ret = subprocess.call(['sudo', 'mount', '-a', '-vf']) if ret != 0: # There was an error running mount, so log error_msg = 'fstab modification failed mount validation. Please correct before reboot.' self._hutil_error(error_msg) self._do_log_to_console_if_enabled(error_msg) else: # No errors self._hutil_log('fstab modification passed mount validation') return ret def set_imds_logger(self, imds_logger): self._imds_logger = imds_logger def watch(self): """ Main loop performing various monitoring activities periodically. Currently iterates every 5 minutes, and other periodic activities might be added in the loop later. :return: None """ while True: # /etc/fstab watcher self.handle_fstab() # IMDS probe (only sporadically, inside the function) if self._imds_logger: try: self._imds_logger.log_imds_data_if_right_time() except Exception as e: self._hutil_error('ImdsLogger exception: {0}\nStacktrace: {1}'.format(e, traceback.format_exc())) # Sleep 5 minutes time.sleep(60 * 5) pass ================================================ FILE: LAD-AMA-Common/metrics_ext_utils/__init__.py ================================================ # Metrics Extension helper script for LAD/AMA ================================================ FILE: LAD-AMA-Common/metrics_ext_utils/metrics_common_utils.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Copyright (c) Microsoft Corporation # All rights reserved. 
# MIT License
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import os


def is_systemd():
    """Return True when systemd is the active init system on this machine."""
    # The /run/systemd/system directory only exists while systemd is PID 1.
    return os.path.isdir("/run/systemd/system")


def is_arc_installed():
    """Return True when this is an on-prem machine running Azure Arc."""
    # Probe the himdsd unit via systemctl; Arc only supports VMs that have
    # systemd, so a zero exit status means the Arc agent service is present.
    status = os.system("systemctl status himdsd 1>/dev/null 2>&1")
    return status == 0


def get_arc_endpoint():
    """Return the Arc hybrid IMDS endpoint read from the azcmagent systemd drop-in."""
    conf_path = "/lib/systemd/system.conf.d/azcmagent.conf"
    with open(conf_path, "r") as conf_file:
        contents = conf_file.read()
    # The drop-in contains a line of the form ...\"IMDS_ENDPOINT=<url>\"\n;
    # take whatever sits between that marker and the closing quote.
    return contents.split("\"IMDS_ENDPOINT=")[1].split("\"\n")[0]
# MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit # persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the # Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
# This File contains constants used for Platform Metrics feature in LAD and Azure Monitor Extension metrics_extension_namespace = "Azure.VM.Linux.GuestMetrics" #AMA Constants ama_metrics_extension_bin = "/opt/microsoft/azuremonitoragent/bin/MetricsExtension" metrics_extension_service_name = "metrics-extension" metrics_extension_service_path = "/lib/systemd/system/metrics-extension.service" metrics_extension_service_path_usr_lib = "/usr/lib/systemd/system/metrics-extension.service" metrics_extension_service_path_etc = "/etc/systemd/system/metrics-extension.service" ama_telegraf_bin = "/opt/microsoft/azuremonitoragent/bin/telegraf" telegraf_service_name = "metrics-sourcer" telegraf_service_path = "/lib/systemd/system/metrics-sourcer.service" telegraf_service_path_usr_lib = "/usr/lib/systemd/system/metrics-sourcer.service" telegraf_service_path_etc = "/etc/systemd/system/metrics-sourcer.service" ama_metrics_extension_udp_port = "17659" #LAD Constants lad_metrics_extension_bin = "/usr/local/lad/bin/MetricsExtension" lad_metrics_extension_service_name = "metrics-extension-lad" lad_metrics_extension_service_path = "/lib/systemd/system/metrics-extension-lad.service" lad_metrics_extension_service_path_usr_lib = "/usr/lib/systemd/system/metrics-extension-lad.service" lad_telegraf_bin = "/usr/local/lad/bin/telegraf" lad_telegraf_service_name = "metrics-sourcer-lad" lad_telegraf_service_path = "/lib/systemd/system/metrics-sourcer-lad.service" lad_telegraf_service_path_usr_lib = "/usr/lib/systemd/system/metrics-sourcer-lad.service" lad_metrics_extension_udp_port = "13459" lad_metrics_extension_influx_udp_url = "udp://127.0.0.1:" + lad_metrics_extension_udp_port telegraf_influx_url = "unix:///var/run/mdsd/lad_mdsd_influx.socket" ================================================ FILE: LAD-AMA-Common/metrics_ext_utils/metrics_ext_handler.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Copyright (c) Microsoft Corporation # All 
rights reserved. # MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit # persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the # Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
import platform
import sys
import json
import os
from shutil import copyfile, rmtree
import stat
import grp
import pwd
import filecmp
import metrics_ext_utils.metrics_constants as metrics_constants
import subprocess
import time
import signal
import metrics_ext_utils.metrics_common_utils as metrics_utils

try:
    import urllib.request as urllib  # Python 3+
except ImportError:
    import urllib2 as urllib  # Python 2
try:
    import urllib.error as urlerror  # Python 3+
except ImportError:
    import urllib2 as urlerror  # Python 2
try:
    from urllib.parse import urlparse  # Python 3+
except ImportError:
    from urlparse import urlparse  # Python 2

# Cloud Environments
PublicCloudName = "azurepubliccloud"
FairfaxCloudName = "azureusgovernmentcloud"
MooncakeCloudName = "azurechinacloud"
USNatCloudName = "usnat"  # EX
USSecCloudName = "ussec"  # RX
ArcACloudName = "azurestackcloud"
DefaultCloudName = PublicCloudName  # Fallback

# Cloud environment name (lowercase) -> ARM management endpoint domain.
ARMDomainMap = {
    PublicCloudName: "management.azure.com",
    FairfaxCloudName: "management.usgovcloudapi.net",
    MooncakeCloudName: "management.chinacloudapi.cn",
    USNatCloudName: "management.azure.eaglex.ic.gov",
    USSecCloudName: "management.azure.microsoft.scloud",
    ArcACloudName: "armmanagement.autonomous.cloud.private"
}


def is_running(is_lad):
    """
    Check whether the MetricsExtension binary is currently running on the system,
    so the watcher daemon can decide if it needs to be restarted.
    :param is_lad: boolean whether the extension is LAD or not (AMA)
    :return: True when a MetricsExtension process whose command line contains the
             expected binary path is found in `ps` output.
    """
    if is_lad:
        metrics_bin = metrics_constants.lad_metrics_extension_bin
    else:
        metrics_bin = metrics_constants.ama_metrics_extension_bin

    proc = subprocess.Popen(["ps aux | grep MetricsExtension | grep -v grep"],
                            stdout=subprocess.PIPE, shell=True)
    output = proc.communicate()[0]
    return metrics_bin in output.decode('utf-8', 'ignore')


def stop_metrics_service(is_lad):
    """
    Stop the metrics service if VM is using systemd, otherwise check if the pid_file
    exists, and if the pid belongs to the MetricsExtension process; if yes, kill it.
    This method is called before remove_metrics_service by the main extension code.
    :param is_lad: boolean whether the extension is LAD or not (AMA)
    :return: (bool success, str message)
    """
    if is_lad:
        metrics_ext_bin = metrics_constants.lad_metrics_extension_bin
    else:
        metrics_ext_bin = metrics_constants.ama_metrics_extension_bin

    # If the VM has systemd, then we will use that to stop
    if metrics_utils.is_systemd():
        code = 1
        metrics_service_path = get_metrics_extension_service_path(is_lad)
        metrics_service_name = get_metrics_extension_service_name(is_lad)
        if os.path.isfile(metrics_service_path):
            code = os.system("systemctl stop {0}".format(metrics_service_name))
        else:
            return False, "Metrics Extension service file does not exist. Failed to stop ME service: {0}.service.".format(metrics_service_name)
        if code != 0:
            return False, "Unable to stop Metrics Extension service: {0}. Failed with code {1}".format(metrics_service_name, code)
    else:
        # This VM does not have systemd, so use the pid saved by the last started
        # metrics process and terminate it.
        _, configFolder = get_handler_vars()
        metrics_conf_dir = configFolder + "/metrics_configs/"
        metrics_pid_path = metrics_conf_dir + "metrics_pid.txt"
        if os.path.isfile(metrics_pid_path):
            pid = ""
            with open(metrics_pid_path, "r") as f:
                pid = f.read()
            if pid != "":
                # Check that the pid indeed belongs to MetricsExtension before killing;
                # pids get recycled, so never SIGKILL an unrelated process.
                proc = subprocess.Popen(["ps -o cmd= {0}".format(pid)],
                                        stdout=subprocess.PIPE, shell=True)
                output = proc.communicate()[0]
                if metrics_ext_bin in output.decode('utf-8', 'ignore'):
                    os.kill(int(pid), signal.SIGKILL)
                else:
                    return False, "Found a different process running with PID {0}. Failed to stop MetricsExtension.".format(pid)
            else:
                return False, "No pid found for a currently running Metrics Extension process in {0}. Failed to stop Metrics Extension.".format(metrics_pid_path)
        else:
            # BUG FIX: message typo "does not exit" -> "does not exist".
            return False, "File containing the pid for the running Metrics Extension process at {0} does not exist. Failed to stop Metrics Extension".format(metrics_pid_path)

    return True, "Successfully stopped metrics-extension service"


def remove_metrics_service(is_lad):
    """
    Remove the metrics service unit file as well as the MetricsExtension binary.
    This method is called after stop_metrics_service by the main extension code
    during Extension uninstall.
    :param is_lad: boolean whether the extension is LAD or not (AMA)
    :return: (True, str message)
    """
    metrics_service_path = get_metrics_extension_service_path(is_lad)
    if os.path.isfile(metrics_service_path):
        # os.remove returns None and raises on failure; no status code to capture.
        os.remove(metrics_service_path)

    if is_lad:
        metrics_ext_bin = metrics_constants.lad_metrics_extension_bin
    else:
        metrics_ext_bin = metrics_constants.ama_metrics_extension_bin

    if os.path.isfile(metrics_ext_bin):
        os.remove(metrics_ext_bin)

    return True, "Successfully removed metrics-extensions service and MetricsExtension binary."


def generate_Arc_MSI_token(resource="https://ingestion.monitor.azure.com/"):
    """
    Query the Hybrid metadata service of Arc to get the MSI auth token for the VM and
    write it to the ME config location. Called from the main extension code after
    config setup is complete.
    :param resource: AAD resource URI the token is requested for.
    :return: (bool success, str expiry_epoch_time, str log_messages)
    """
    _, configFolder = get_handler_vars()
    me_config_dir = configFolder + "/metrics_configs/"
    me_auth_file_path = me_config_dir + "AuthToken-MSI.json"
    expiry_epoch_time = ""
    log_messages = ""
    retries = 1
    max_retries = 3
    sleep_time = 5

    if not os.path.exists(me_config_dir):
        log_messages += "Metrics extension config directory - {0} does not exist. Failed to generate MSI auth token for ME.\n".format(me_config_dir)
        return False, expiry_epoch_time, log_messages
    try:
        data = None
        while retries <= max_retries:
            arc_endpoint = metrics_utils.get_arc_endpoint()
            try:
                msiauthurl = arc_endpoint + "/metadata/identity/oauth2/token?api-version=2019-11-01&resource=" + resource
                req = urllib.Request(msiauthurl, headers={'Metadata': 'true'})
                res = urllib.urlopen(req)
            except:
                # The above request is expected to fail (401) and drop a key file under
                # the tokens directory; retry the request authenticating with that key.
                authkey_dir = "/var/opt/azcmagent/tokens/"
                if not os.path.exists(authkey_dir):
                    log_messages += "Unable to find the auth key file at {0} returned from the arc msi auth request.".format(authkey_dir)
                    return False, expiry_epoch_time, log_messages
                keys_dir = []
                for filename in os.listdir(authkey_dir):
                    keys_dir.append(filename)
                authkey_path = authkey_dir + keys_dir[-1]
                auth = "basic "
                with open(authkey_path, "r") as f:
                    key = f.read()
                    auth += key
                req = urllib.Request(msiauthurl, headers={'Metadata': 'true', 'authorization': auth})
                res = urllib.urlopen(req)

            data = json.loads(res.read().decode('utf-8', 'ignore'))

            if not data or "access_token" not in data:
                retries += 1
            else:
                break

            # BUG FIX: message typo "Mmax Retries" -> "max Retries".
            log_messages += "Failed to fetch MSI Auth url. Retrying in {2} seconds. Retry Count - {0} out of max Retries - {1}\n".format(retries, max_retries, sleep_time)
            time.sleep(sleep_time)

        if retries > max_retries:
            log_messages += "Unable to generate a valid MSI auth token at {0}.\n".format(me_auth_file_path)
            return False, expiry_epoch_time, log_messages

        with open(me_auth_file_path, "w") as f:
            f.write(json.dumps(data))

        if "expires_on" in data:
            expiry_epoch_time = data["expires_on"]
        else:
            log_messages += "Error parsing the msi token at {0} for the token expiry time. Failed to generate the correct token\n".format(me_auth_file_path)
            return False, expiry_epoch_time, log_messages

    except Exception as e:
        log_messages += "Failed to get msi auth token. Please check if VM's system assigned Identity is enabled Failed with error {0}\n".format(e)
        return False, expiry_epoch_time, log_messages

    return True, expiry_epoch_time, log_messages


def generate_MSI_token(identifier_name='', identifier_value='', is_lad=True):
    """
    Query the metadata service to get the MSI auth token for the VM and write it to
    the ME config location. Called from the main extension code after config setup is
    complete. Delegates to generate_Arc_MSI_token on Arc machines.
    :param identifier_name: optional user-assigned identity query parameter name
    :param identifier_value: optional user-assigned identity query parameter value
    :param is_lad: boolean whether the extension is LAD or not (AMA)
    :return: (bool success, str expiry_epoch_time, str log_messages)
    """
    if metrics_utils.is_arc_installed():
        _, _, _, az_environment, _ = get_imds_values(is_lad)
        if az_environment.lower() == ArcACloudName:
            return generate_Arc_MSI_token("https://monitoring.azs")
        return generate_Arc_MSI_token()
    else:
        _, configFolder = get_handler_vars()
        me_config_dir = configFolder + "/metrics_configs/"
        me_auth_file_path = me_config_dir + "AuthToken-MSI.json"
        expiry_epoch_time = ""
        log_messages = ""
        retries = 1
        max_retries = 3
        sleep_time = 5

        if not os.path.exists(me_config_dir):
            log_messages += "Metrics extension config directory - {0} does not exist. Failed to generate MSI auth token for ME.\n".format(me_config_dir)
            return False, expiry_epoch_time, log_messages
        try:
            data = None
            while retries <= max_retries:
                msiauthurl = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https://ingestion.monitor.azure.com/"
                if identifier_name and identifier_value:
                    msiauthurl += '&{0}={1}'.format(identifier_name, identifier_value)
                req = urllib.Request(msiauthurl, headers={'Metadata': 'true', 'Content-Type': 'application/json'})
                res = urllib.urlopen(req)
                data = json.loads(res.read().decode('utf-8', 'ignore'))

                if not data or "access_token" not in data:
                    retries += 1
                else:
                    break

                # BUG FIX: message typo "Mmax Retries" -> "max Retries".
                log_messages += "Failed to fetch MSI Auth url. Retrying in {2} seconds. Retry Count - {0} out of max Retries - {1}\n".format(retries, max_retries, sleep_time)
                time.sleep(sleep_time)

            if retries > max_retries:
                log_messages += "Unable to generate a valid MSI auth token at {0}.\n".format(me_auth_file_path)
                return False, expiry_epoch_time, log_messages

            with open(me_auth_file_path, "w") as f:
                f.write(json.dumps(data))

            if "expires_on" in data:
                expiry_epoch_time = data["expires_on"]
            else:
                log_messages += "Error parsing the MSI token at {0} for the token expiry time. Failed to generate the correct token\n".format(me_auth_file_path)
                return False, expiry_epoch_time, log_messages

        except Exception as e:
            log_messages += "Failed to get MSI auth token. Please check if the VM's system assigned identity is enabled or the user assigned identity "
            log_messages += "passed in the extension settings exists and is assigned to this VM. Failed with error {0}\n".format(e)
            return False, expiry_epoch_time, log_messages

    return True, expiry_epoch_time, log_messages


def get_ArcA_MSI_token(resource="https://monitoring.azs"):
    """
    Query the Hybrid metadata service of ArcA to get the MSI auth token for the VM.
    :param resource: AAD resource URI the token is requested for.
    :return: (bool success, str token_string, str log_messages)
    """
    token_string = ""
    log_messages = ""
    retries = 1
    max_retries = 3
    sleep_time = 5
    try:
        data = None
        while retries <= max_retries:
            arc_endpoint = metrics_utils.get_arc_endpoint()
            try:
                msiauthurl = arc_endpoint + "/metadata/identity/oauth2/token?api-version=2019-11-01&resource=" + resource
                req = urllib.Request(msiauthurl, headers={'Metadata': 'true'})
                res = urllib.urlopen(req)
            except:
                # The above request is expected to fail and drop a key file; retry
                # authenticating with the newest key under the tokens directory.
                authkey_dir = "/var/opt/azcmagent/tokens/"
                if not os.path.exists(authkey_dir):
                    log_messages += "Unable to find the auth key file at {0} returned from the arc msi auth request.".format(authkey_dir)
                    return False, token_string, log_messages
                keys_dir = []
                for filename in os.listdir(authkey_dir):
                    keys_dir.append(filename)
                authkey_path = authkey_dir + keys_dir[-1]
                auth = "basic "
                with open(authkey_path, "r") as f:
                    key = f.read()
                    auth += key
                req = urllib.Request(msiauthurl, headers={'Metadata': 'true', 'authorization': auth})
                res = urllib.urlopen(req)

            data = json.loads(res.read().decode('utf-8', 'ignore'))

            if not data or "access_token" not in data:
                retries += 1
            else:
                break

            # BUG FIX: message typo "Mmax Retries" -> "max Retries".
            log_messages += "Failed to fetch MSI Auth url. Retrying in {2} seconds. Retry Count - {0} out of max Retries - {1}\n".format(retries, max_retries, sleep_time)
            time.sleep(sleep_time)

        if retries > max_retries:
            log_messages += "Unable to fetch a valid MSI auth token for {0}.\n".format(resource)
            return False, token_string, log_messages

        token_string = data["access_token"]

    except Exception as e:
        log_messages += "Failed to get msi auth token. Please check if VM's system assigned Identity is enabled Failed with error {0}\n".format(e)
        return False, token_string, log_messages

    return True, token_string, log_messages


def setup_me_service(is_lad, configFolder, monitoringAccount, metrics_ext_bin, me_influx_port, managed_identity="sai", HUtilObj=None):
    """
    Setup the metrics service if VM is using systemd.
    :param is_lad: boolean whether the extension is LAD or not (AMA)
    :param configFolder: Path for the config folder for metrics extension
    :param monitoringAccount: Monitoring Account name that ME will upload data to
    :param metrics_ext_bin: Path for the binary for metrics extension
    :param me_influx_port: Influxdb port that metrics extension will listen on
    :param managed_identity: managed identity selector passed to ME ("sai" by default)
    :param HUtilObj: optional handler-util logger; falls back to print()
    :return: True on success (raises on failure)
    """
    me_service_path = get_metrics_extension_service_path(is_lad)
    me_service_template_path = os.getcwd() + "/services/metrics-extension.service"
    daemon_reload_status = 1

    if not os.path.exists(configFolder):
        raise Exception("Metrics extension config directory does not exist. Failed to set up ME service.")

    me_influx_socket_path = configFolder + "/mdm_influxdb.socket"

    if os.path.isfile(me_service_template_path):
        copyfile(me_service_template_path, me_service_path)

        if os.path.isfile(me_service_path):
            # Substitute the placeholders in the copied unit file in place.
            os.system(r"sed -i 's+%ME_BIN%+{1}+' {0}".format(me_service_path, metrics_ext_bin))
            os.system(r"sed -i 's+%ME_INFLUX_PORT%+{1}+' {0}".format(me_service_path, me_influx_port))
            os.system(r"sed -i 's+%ME_DATA_DIRECTORY%+{1}+' {0}".format(me_service_path, configFolder))
            os.system(r"sed -i 's+%ME_MONITORING_ACCOUNT%+{1}+' {0}".format(me_service_path, monitoringAccount))
            os.system(r"sed -i 's+%ME_MANAGED_IDENTITY%+{1}+' {0}".format(me_service_path, managed_identity))
            os.system(r"sed -i 's+%ME_INFLUX_SOCKET_FILE_PATH%+{1}+' {0}".format(me_service_path, me_influx_socket_path))
            daemon_reload_status = os.system("systemctl daemon-reload")
            if daemon_reload_status != 0:
                message = "Unable to reload systemd after ME service file change. Failed to set up ME service. Check system for hardening. Exit code:" + str(daemon_reload_status)
                if HUtilObj is not None:
                    HUtilObj.log(message)
                else:
                    print('Info: {0}'.format(message))
        else:
            raise Exception("Unable to copy Metrics extension service file to {0}. Failed to set up ME service.".format(me_service_path))
    else:
        raise Exception("Metrics extension service template file does not exist at {0}. Failed to set up ME service.".format(me_service_template_path))

    return True


def start_metrics_cmv2():
    """
    Start (restart) the metrics service in CMv2 mode via systemd.
    :return: (bool success, str log_messages)
    """
    # Reusing the code to grab the config directories and imds values because start
    # will be called from the Enable process outside this script.
    log_messages = ""
    metrics_ext_bin = metrics_constants.ama_metrics_extension_bin
    if not os.path.isfile(metrics_ext_bin):
        log_messages += "Metrics Extension binary does not exist. Failed to start ME service."
        return False, log_messages

    # If the VM has systemd, then we use that to start/stop
    metrics_service_name = get_metrics_extension_service_name(False)
    if metrics_utils.is_systemd():
        service_restart_status = os.system("systemctl restart {0}".format(metrics_service_name))
        if service_restart_status != 0:
            log_messages += "Unable to start {0} using systemctl. Failed to start ME service. Check system for hardening.".format(metrics_service_name)
            return False, log_messages
        else:
            return True, log_messages

    # CMv2 mode requires systemd; anything else is a failure.
    return False, log_messages


def start_metrics(is_lad, managed_identity="sai"):
    """
    Start the metrics service if VM is using systemd; otherwise start the binary as a
    process and store the pid to a file in the MetricsExtension config directory.
    This method is called after config setup is completed by the main extension code.
    :param is_lad: boolean whether the extension is LAD or not (AMA)
    :param managed_identity: managed identity selector ("sai" by default)
    :return: (bool success, str log_messages)
    """
    # Reusing the code to grab the config directories and imds values because start
    # will be called from the Enable process outside this script.
    log_messages = ""
    if is_lad:
        metrics_ext_bin = metrics_constants.lad_metrics_extension_bin
    else:
        metrics_ext_bin = metrics_constants.ama_metrics_extension_bin
    if not os.path.isfile(metrics_ext_bin):
        log_messages += "Metrics Extension binary does not exist. Failed to start ME service."
        return False, log_messages

    if is_lad:
        me_influx_port = metrics_constants.lad_metrics_extension_udp_port
    else:
        me_influx_port = metrics_constants.ama_metrics_extension_udp_port

    # If the VM has systemd, then we use that to start/stop
    metrics_service_name = get_metrics_extension_service_name(is_lad)
    if metrics_utils.is_systemd():
        service_restart_status = os.system("systemctl restart {0}".format(metrics_service_name))
        if service_restart_status != 0:
            log_messages += "Unable to start {0} using systemctl. Failed to start ME service. Check system for hardening.".format(metrics_service_name)
            return False, log_messages

    # Else start ME as a process and save the pid to a file so that we can terminate
    # it while disabling/uninstalling.
    else:
        _, configFolder = get_handler_vars()
        me_config_dir = configFolder + "/metrics_configs/"

        # query imds to get the subscription id
        az_resource_id, subscription_id, location, az_environment, data = get_imds_values(is_lad)
        if is_lad:
            monitoringAccount = "CUSTOMMETRIC_" + subscription_id
        else:
            monitoringAccount = "CUSTOMMETRIC_" + subscription_id + "_" + location
        metrics_pid_path = me_config_dir + "metrics_pid.txt"

        # If LAD, use ME startup arguments for LAD; only LAD supports non-systemd start.
        if is_lad:
            binary_exec_command = "{0} -TokenSource MSI -Input influxdb_udp -InfluxDbHost 127.0.0.1 -InfluxDbUdpPort {1} -DataDirectory {2} -LocalControlChannel -MonitoringAccount {3} -LogLevel Error".format(metrics_ext_bin, me_influx_port, me_config_dir, monitoringAccount)
        else:
            log_messages += "MetricsExtension will not be started."
            return False, log_messages

        proc = subprocess.Popen(binary_exec_command.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # Sleep for 3 seconds before checking if the process is still running, to give
        # it ample time to relay crash info.
        time.sleep(3)
        p = proc.poll()

        if p is None:
            # Process is running successfully
            metrics_pid = proc.pid

            # Write this pid to a file for future use
            with open(metrics_pid_path, "w+") as f:
                f.write(str(metrics_pid))
        else:
            out, err = proc.communicate()
            log_messages += "Unable to run MetricsExtension binary as a process due to error - {0}. Failed to start MetricsExtension.".format(err)
            return False, log_messages

    return True, log_messages


def create_metrics_extension_conf(az_resource_id, aad_url):
    """
    Create the metrics extension config (MetricsExtensionV1 JSON settings blob).
    :param az_resource_id: Azure Resource ID for the VM
    :param aad_url: AAD auth url for the VM
    :return: str JSON configuration document
    """
    conf_json = '''{
  "timeToTerminateInMs": 4000,
  "configurationExpirationPeriodInMinutes": 1440,
  "configurationQueriesFrequencyInSec": 900,
  "configurationQueriesTimeoutInSec": 30,
  "maxAcceptedMetricAgeInSec": 1200,
  "maxDataEtwDelayInSec": 3,
  "maxPublicationAttemptsPerMinute": 5,
  "maxPublicationBytesPerMinute": 10000000,
  "maxPublicationMetricsPerMinute": 500000,
  "maxPublicationPackageSizeInBytes": 2500000,
  "maxRandomPublicationDelayInSec": 25,
  "metricsSerializationVersion": 4,
  "minGapBetweenPublicationAttemptsInSec": 5,
  "publicationTimeoutInSec": 30,
  "staleMonitoringAccountsPeriodInMinutes": 20,
  "internalMetricPublicationTimeoutInMinutes": 20,
  "dnsResolutionPeriodInSec": 180,
  "maxAggregationQueueSize": 500000,
  "initialAccountConfigurationLoadWaitPeriodInSec": 20,
  "etwMinBuffersPerCore": 2,
  "etwMaxBuffersPerCore": 16,
  "etwBufferSizeInKb": 1024,
  "internalQueueSizeManagementPeriodInSec": 900,
  "etwLateHeartbeatAllowedCycleCount": 24,
  "etwSampleRatio": 0,
  "maxAcceptedMetricFutureAgeInSec": 1200,
  "aggregatedMetricDiagnosticTracePeriod": 900,
  "aggregatedMetricDiagnosticTraceMaxSize": 100,
  "enableMetricMetadataPublication": true,
  "enableDimensionTrimming": true,
  "shutdownRequestedThreshold": 5,
  "internalMetricProductionLevel": 0,
  "maxPublicationWithoutResponseTimeoutInSec": 300,
  "maxConfigQueryWithoutResponseTimeoutInSec": 300,
  "maxThumbprintsPerAccountToLoad": 100,
  "maxPacketsToCaptureLocally": 0,
  "maxNumberOfRawEventsPerCycle": 1000000,
  "publicationSimulated": false,
  "maxAggregationTimeoutPerCycleInSec": 20,
  "maxRawEventInputQueueSize": 2000000,
  "publicationIntervalInSec": 60,
  "interningSwapPeriodInMin": 240,
  "interningClearPeriodInMin": 5,
  "enableParallelization": true,
  "enableDimensionSortingOnIngestion": true,
  "rawEtwEventProcessingParallelizationFactor": 1,
  "maxRandomConfigurationLoadingDelayInSec": 120,
  "aggregationProcessingParallelizationFactor": 1,
  "aggregationProcessingPerPartitionPeriodInSec": 20,
  "aggregationProcessingParallelizationVolumeThreshold": 500000,
  "useSharedHttpClients": true,
  "loadFromConfigurationCache": true,
  "restartByDateTimeUtc": "0001-01-01T00:00:00",
  "restartStableIdTarget": "",
  "enableIpV6": false,
  "disableCustomMetricAgeSupport": false,
  "globalPublicationCertificateThumbprint": "",
  "maxHllSerializationVersion": 2,
  "enableNodeOwnerMode": false,
  "performAdditionalAzureHostIpV6Checks": false,
  "compressMetricData": false,
  "publishMinMaxByDefault": true,
  "azureResourceId": "''' + az_resource_id + '''",
  "aadAuthority": "''' + aad_url + '''",
  "aadTokenEnvVariable": "MSIAuthToken"
}
'''
    return conf_json


def create_custom_metrics_conf(mds_gig_endpoint_region, gig_endpoint=""):
    """
    Create the custom metrics config pointing ME at the ingestion (gig) endpoint.
    :param mds_gig_endpoint_region: mds gig endpoint region for the VM
    :param gig_endpoint: optional explicit endpoint URL; when empty, the regional
                         3rd-party endpoint is derived from the region.
    :return: str JSON configuration document
    """
    # Note : mds gig endpoint url is only for 3rd party customers. 1st party endpoint is different
    if not gig_endpoint:
        gig_hostname = mds_gig_endpoint_region + ".monitoring.azure.com"
        gig_ingestion_endpoint = "https://" + gig_hostname + "/api/v1/ingestion/ingest"
    else:
        gig_hostname = urlparse(gig_endpoint).netloc
        gig_ingestion_endpoint = gig_endpoint + "/api/v1/ingestion/ingest"

    conf_json = '''{
  "version": 17,
  "maxMetricAgeInSeconds": 0,
  "endpointsForClientForking": [],
  "homeStampGslbHostname": "''' + gig_hostname + '''",
  "endpointsForClientPublication": [
    "''' + gig_ingestion_endpoint + '''"
  ]
}
'''
    return conf_json


def get_handler_vars():
    """
    This method is taken from the Waagent code. Grab the log and config file
    locations from HandlerEnvironment.json for the Extension.
    :return: (str logFolder, str configFolder) — empty strings when not found
    """
    logFolder = ""
    configFolder = ""
    handler_env_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'HandlerEnvironment.json'))
    if os.path.exists(handler_env_path):
        with open(handler_env_path, 'r') as handler_env_file:
            handler_env_txt = handler_env_file.read()
        handler_env = json.loads(handler_env_txt)
        if type(handler_env) == list:
            handler_env = handler_env[0]
        if "handlerEnvironment" in handler_env:
            if "logFolder" in handler_env["handlerEnvironment"]:
                logFolder = handler_env["handlerEnvironment"]["logFolder"]
            if "configFolder" in handler_env["handlerEnvironment"]:
                configFolder = handler_env["handlerEnvironment"]["configFolder"]

    return logFolder, configFolder


def get_imds_values(is_lad, HUtilObj=None):
    """
    Query imds to get required values for MetricsExtension config for this VM.
    :param is_lad: boolean whether the extension is LAD or not (AMA)
    :param HUtilObj: optional handler-util logger; falls back to print()
    :return: (az_resource_id, subscription_id, location, az_environment, data)
    :raises Exception: when IMDS is unreachable or the response lacks required keys
    """
    retries = 1
    max_retries = 3
    sleep_time = 5
    imds_url = ""
    is_arc = False

    if is_lad:
        imds_url = "http://169.254.169.254/metadata/instance?api-version=2019-03-11"
    else:
        if metrics_utils.is_arc_installed():
            imds_url = metrics_utils.get_arc_endpoint()
            imds_url += "/metadata/instance?api-version=2019-11-01"
            is_arc = True
        else:
            imds_url = "http://169.254.169.254/metadata/instance?api-version=2019-03-11"

    message = "IMDS url to query: " + imds_url
    if HUtilObj is not None:
        HUtilObj.log(message)
    else:
        print('Info: {0}'.format(message))

    data = None
    while retries <= max_retries:
        try:
            req = urllib.Request(imds_url, headers={'Metadata': 'true'})
            res = urllib.urlopen(req)
            data = json.loads(res.read().decode('utf-8', 'ignore'))
        except:
            pass

        # BUG FIX: when the request above fails, data is still None and the original
        # membership test ("compute" not in data) raised TypeError instead of retrying.
        if not data or "compute" not in data:
            retries += 1
        else:
            break

        time.sleep(sleep_time)

    if retries > max_retries:
        raise Exception("Unable to find 'compute' key in imds query response. Reached max retry limit of - {0} times. Failed to set up ME.".format(max_retries))

    if "resourceId" not in data["compute"]:
        raise Exception("Unable to find 'resourceId' key in imds query response. Failed to set up ME.")
    az_resource_id = data["compute"]["resourceId"]

    if "subscriptionId" not in data["compute"]:
        raise Exception("Unable to find 'subscriptionId' key in imds query response. Failed to set up ME.")
    subscription_id = data["compute"]["subscriptionId"]

    if "location" not in data["compute"]:
        raise Exception("Unable to find 'location' key in imds query response. Failed to set up ME.")
    location = data["compute"]["location"]

    if "azEnvironment" not in data["compute"]:
        raise Exception("Unable to find 'azEnvironment' key in imds query response. Failed to set up ME.")
    az_environment = data["compute"]["azEnvironment"]

    return az_resource_id, subscription_id, location, az_environment, data


def get_arca_endpoints_from_himds():
    """
    Query himds to get required arca endpoints for MetricsExtension config for this
    connected machine.
    :return: (arm_endpoint, mcs_endpoint)
    :raises Exception: when the endpoint metadata cannot be fetched
    """
    retries = 1
    max_retries = 3
    sleep_time = 5
    imds_url = "http://localhost:40342/metadata/endpoints?api-version=2019-11-01"
    if metrics_utils.is_arc_installed():
        imds_url = metrics_utils.get_arc_endpoint()
        imds_url += "/metadata/endpoints?api-version=2019-11-01"

    data = None
    while retries <= max_retries:
        try:
            req = urllib.Request(imds_url, headers={'Metadata': 'true'})
            res = urllib.urlopen(req)
            data = json.loads(res.read().decode('utf-8', 'ignore'))
        except:
            pass

        # BUG FIX: guard against data still being None after a swallowed request
        # failure; the original membership test raised TypeError instead of retrying.
        if not data or "dataplaneEndpoints" not in data or "resourceManager" not in data:
            retries += 1
        else:
            break

        time.sleep(sleep_time)

    if retries > max_retries:
        raise Exception("Unable to find 'dataplaneEndpoints' key in imds query response. Reached max retry limit of - {0} times. Failed to set up ME.".format(max_retries))

    if "arcMonitorControlServiceEndpoint" not in data["dataplaneEndpoints"]:
        raise Exception("Unable to find 'arcMonitorControlServiceEndpoint' key in imds query response. Failed to set up ME.")
    mcs_endpoint = data["dataplaneEndpoints"]["arcMonitorControlServiceEndpoint"]
    arm_endpoint = data["resourceManager"]

    return arm_endpoint, mcs_endpoint


def get_arca_ingestion_endpoint_from_mcs():
    """
    Query MCS (via himds endpoints + an ArcA MSI token) to get the ingestion endpoint
    for MetricsExtension config for this connected machine.
    :return: str ingestion endpoint URL
    :raises Exception: when the token or the agent configuration cannot be fetched
    """
    retries = 1
    max_retries = 3
    sleep_time = 5
    _, mcs_endpoint = get_arca_endpoints_from_himds()
    az_resource_id, _, _, _, _ = get_imds_values(False)
    msi_token_fetched, mcs_token, log_messages = get_ArcA_MSI_token()
    if not msi_token_fetched:
        raise Exception("Unable to fetch MCS token, error message: " + log_messages)

    mcs_config_query_url = mcs_endpoint + az_resource_id + "/agentConfigurations?platform=linux&includeMeConfig=true&api-version=2022-06-02"
    if not mcs_token.lower().startswith("bearer "):
        mcs_token = "Bearer " + mcs_token

    data = None
    while retries <= max_retries:
        # Query MCS to get the required information
        req = urllib.Request(mcs_config_query_url, headers={'Metadata': 'true', 'Authorization': mcs_token})
        res = urllib.urlopen(req)
        data = json.loads(res.read().decode('utf-8', 'ignore'))

        if not data or "configurations" not in data:
            retries += 1
        else:
            break

        time.sleep(sleep_time)

    if retries > max_retries:
        raise Exception("Unable to find 'configurations' key in amcs query response. Reached max retry limit of - {0} times. Failed to set up ME.".format(max_retries))

    if "content" not in data["configurations"][0]:
        raise Exception("Unable to find 'content' key in amcs query response. Failed to set up ME.")
    if "channels" not in data["configurations"][0]["content"]:
        raise Exception("Unable to find 'channels' key in amcs query response. Failed to set up ME.")
    if "endpoint" not in data["configurations"][0]["content"]["channels"][0]:
        raise Exception("Unable to find 'endpoint' key in amcs query response. Failed to set up ME.")
    ingestion_endpoint = data["configurations"][0]["content"]["channels"][0]["endpoint"]

    return ingestion_endpoint


def get_arm_domain(az_environment):
    """
    Return the ARM domain to use based on the Azure environment.
    :param az_environment: environment name as reported by IMDS (case-insensitive)
    :raises Exception: for an environment name not present in ARMDomainMap
    """
    try:
        if az_environment.lower() == ArcACloudName:
            # ArcA does not have a static domain; resolve it from himds at runtime.
            arm_endpoint, _ = get_arca_endpoints_from_himds()
            arm_endpoint_parsed = urlparse(arm_endpoint)
            domain = arm_endpoint_parsed.netloc
        else:
            domain = ARMDomainMap[az_environment.lower()]
    except KeyError:
        raise Exception("Unknown cloud environment \"{0}\". Failed to set up ME.".format(az_environment))

    return domain


def get_metrics_extension_service_path(is_lad):
    """
    Utility method to get the metrics-extension systemd unit file path.
    Note: the AMA branch prefers /etc/systemd/system; the LAD branch only
    probes /lib and /usr/lib (existing behavior, deliberately preserved).
    :param is_lad: boolean whether the extension is LAD or not (AMA)
    :raises Exception: when no known systemd unit directory exists
    """
    if (is_lad):
        if os.path.exists("/lib/systemd/system/"):
            return metrics_constants.lad_metrics_extension_service_path
        elif os.path.exists("/usr/lib/systemd/system/"):
            return metrics_constants.lad_metrics_extension_service_path_usr_lib
        else:
            raise Exception("Systemd unit files do not exist at /lib/systemd/system or /usr/lib/systemd/system/. Failed to setup Metrics Extension service.")
    else:
        if os.path.exists("/etc/systemd/system"):
            return metrics_constants.metrics_extension_service_path_etc
        if os.path.exists("/lib/systemd/system/"):
            return metrics_constants.metrics_extension_service_path
        elif os.path.exists("/usr/lib/systemd/system/"):
            return metrics_constants.metrics_extension_service_path_usr_lib
        else:
            raise Exception("Systemd unit files do not exist at /etc/systemd/system, /lib/systemd/system or /usr/lib/systemd/system/. Failed to setup Metrics Extension service.")


def get_metrics_extension_service_name(is_lad):
    """
    Utility method to get the metrics-extension systemd service name.
    :param is_lad: boolean whether the extension is LAD or not (AMA)
    """
    if (is_lad):
        return metrics_constants.lad_metrics_extension_service_name
    else:
        return metrics_constants.metrics_extension_service_name


# NOTE(review): the original file continues with setup_me(...) here, but its
# definition runs past the end of this excerpt and is not reproduced.
The url request will fail due to missing authentication header, but we get the auth url from the header of the request fail exception aad_auth_url = "" arm_url = "https://{0}/subscriptions/{1}?api-version=2014-04-01".format(arm_domain, subscription_id) try: req = urllib.Request(arm_url, headers={'Content-Type':'application/json'}) res = urllib.urlopen(req) except urlerror.HTTPError as e: err_res = e.headers["WWW-Authenticate"] for line in err_res.split(","): if "Bearer authorization_uri" in line: data = line.split("=") aad_auth_url = data[1][1:-1] # Removing the quotes from the front and back break except Exception as e: message = "Failed to retrieve AAD Authentication URL from " + arm_url + " with Exception='{0}'. ".format(e) message += "Continuing with metrics setup without AAD auth url." if HUtilObj is not None: HUtilObj.log(message) else: print('Info: {0}'.format(message)) #create metrics conf me_conf = create_metrics_extension_conf(az_resource_id, aad_auth_url) #create custom metrics conf if az_environment.lower() == ArcACloudName: ingestion_endpoint = get_arca_ingestion_endpoint_from_mcs() custom_conf = create_custom_metrics_conf(location, ingestion_endpoint) else: custom_conf = create_custom_metrics_conf(location) #write configs to disk me_conf_path = me_config_dir + "MetricsExtensionV1_Configuration.json" with open(me_conf_path, "w") as f: f.write(me_conf) if is_lad: me_monitoring_account = "CUSTOMMETRIC_"+ subscription_id else: me_monitoring_account = "CUSTOMMETRIC_"+ subscription_id + "_" +location custom_conf_path = me_config_dir + me_monitoring_account.lower() +"_MonitoringAccount_Configuration.json" with open(custom_conf_path, "w") as f: f.write(custom_conf) # Copy MetricsExtension Binary to the bin location me_bin_local_path = os.getcwd() + "/MetricsExtensionBin/MetricsExtension" if is_lad: metrics_ext_bin = metrics_constants.lad_metrics_extension_bin else: metrics_ext_bin = metrics_constants.ama_metrics_extension_bin if is_lad: lad_bin_path = 
"/usr/local/lad/bin/" # Checking if directory exists before copying ME bin over to /usr/local/lad/bin/ if not os.path.exists(lad_bin_path): os.makedirs(lad_bin_path) # Check if previous file exist at the location, compare the two binaries, # If the files are not same, remove the older file, and copy the new one # If they are the same, then we ignore it and don't copy if os.path.isfile(me_bin_local_path): if os.path.isfile(metrics_ext_bin): if not filecmp.cmp(me_bin_local_path, metrics_ext_bin): # Removing the file in case it is already being run in a process, # in which case we can get an error "text file busy" while copying os.remove(metrics_ext_bin) copyfile(me_bin_local_path, metrics_ext_bin) os.chmod(metrics_ext_bin, stat.S_IXGRP | stat.S_IRGRP | stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR | stat.S_IXOTH | stat.S_IROTH) else: # No previous binary exist, simply copy it and make it executable copyfile(me_bin_local_path, metrics_ext_bin) os.chmod(metrics_ext_bin, stat.S_IXGRP | stat.S_IRGRP | stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR | stat.S_IXOTH | stat.S_IROTH) else: raise Exception("Unable to copy MetricsExtension Binary, could not find file at the location {0} . Failed to set up ME.".format(me_bin_local_path)) if is_lad: me_influx_port = metrics_constants.lad_metrics_extension_udp_port else: me_influx_port = metrics_constants.ama_metrics_extension_udp_port # setup metrics extension service # If the VM has systemd, then we use that to start/stop if metrics_utils.is_systemd(): setup_me_service(is_lad, me_config_dir, me_monitoring_account, metrics_ext_bin, me_influx_port, managed_identity, HUtilObj) return True def remove_user(user, HUtilObj=None): """ Removes existing user. Note: This is important as the older MetricsExtension might have created the user which needs to be removed. This mechanism can be removed in the future, if the user and group are maintained from MetricsExtension package. 
:param user: linux user
    :param HUtilObj: utility object for logging
    """
    # First verify the user exists at all; pwd raises KeyError when it does not.
    try:
        pwd.getpwnam(user)
    except KeyError:
        if HUtilObj:
            HUtilObj.log('User {0} does not exist.'.format(user))
        return
    except Exception as e:
        if HUtilObj:
            HUtilObj.log('Error while checking user {0}: {1}'.format(user, e))
        return
    # Best-effort deletion: failures are logged but never raised, since removal
    # of a leftover user must not abort extension setup.
    try:
        process = subprocess.Popen(['userdel', "-r", user], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out, err = process.communicate()
        if process.returncode != 0:
            if HUtilObj:
                HUtilObj.log('Failed to delete user {0}. stderr: {1}'.format(user, err))
    except Exception as e:
        if HUtilObj:
            HUtilObj.log('Error while deleting user {0}: {1}'.format(user, e))


def ensure_user_and_group(user, group, create_if_missing=False, HUtilObj=None):
    """
    Ensures if the user and group exists, optionally creating them if it does not exist.
    Group is checked, user is checked and then user is added to the group.
    Returns True if all of them are available (or created), else returns False.
    :param user: linux user
    :param group: linux group
    :param create_if_missing: boolean if true, create the requested user and group, where user belongs to the group
    :param HUtilObj: utility object for logging
    """
    # Check/Create group if missing
    try:
        grp.getgrnam(group)
        if HUtilObj:
            HUtilObj.log('Group {0} exists.'.format(group))
    except KeyError:
        # KeyError from grp means the group does not exist yet.
        if create_if_missing:
            try:
                process = subprocess.Popen(['groupadd', group], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                out, err = process.communicate()
                if process.returncode != 0:
                    if HUtilObj:
                        HUtilObj.log('Failed to create group {0}. stderr: {1}'.format(group, err))
                    return False
                if HUtilObj:
                    HUtilObj.log('Group {0} created.'.format(group))
            except Exception as e:
                if HUtilObj:
                    HUtilObj.log('Error while creating group {0}: {1}'.format(group, e))
                return False
        else:
            if HUtilObj:
                HUtilObj.log('Group {0} does not exist.'.format(group))
            return False
    except Exception as e:
        if HUtilObj:
            HUtilObj.log('Error while checking group {0}: {1}'.format(group, e))
        return False
    # Check/Create user if missing
    try:
        pwd.getpwnam(user)
        if HUtilObj:
            HUtilObj.log('User {0} exists.'.format(user))
    except KeyError:
        if create_if_missing:
            try:
                # System account with no home directory and no login shell,
                # since it only exists to own the MetricsExtension process.
                process = subprocess.Popen([
                    'useradd',
                    '--no-create-home',
                    '--system',
                    '--shell', '/usr/sbin/nologin',
                    user
                ], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                out, err = process.communicate()
                if process.returncode != 0:
                    if HUtilObj:
                        HUtilObj.log('Failed to create user {0}. stderr: {1}'.format(user, err))
                    return False
                if HUtilObj:
                    HUtilObj.log('User {0} created.'.format(user))
            except Exception as e:
                if HUtilObj:
                    HUtilObj.log('Error while creating user {0}: {1}'.format(user, e))
                return False
        else:
            if HUtilObj:
                HUtilObj.log('User {0} does not exist.'.format(user))
            return False
    except Exception as e:
        if HUtilObj:
            HUtilObj.log('Error while checking user {0}: {1}'.format(user, e))
        return False
    # Add user to group
    try:
        # usermod -aG appends the group without dropping existing memberships.
        process = subprocess.Popen(['usermod', '-aG', group, user], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out, err = process.communicate()
        if process.returncode != 0:
            if HUtilObj:
                HUtilObj.log('Failed to add user {0} to group {1}. stderr: {2}'.format(user, group, err))
            return False
        if HUtilObj:
            HUtilObj.log('User {0} added to group {1}.'.format(user, group))
    except Exception as e:
        if HUtilObj:
            HUtilObj.log('Error while adding user {0} to group {1}: {2}'.format(user, group, e))
        return False
    if HUtilObj:
        HUtilObj.log('User {0} added to group {1} (or already a member).'.format(user, group))
    return True


def create_empty_data_directory(me_config_dir, user=None, group=None, mode=0o755, HUtilObj=None):
    '''
    Creates an empty data directory where MetricsExtension can store cached configurations.
    For CMv1, MetricsExtension requires mdsd to provide all configurations on disk.
    For CMv2, MetricsExtension requires an empty data directory where it can cache its configurations.

    :param me_config_dir: directory path to (re)create
    :param user: optional owner user name for the directory
    :param group: optional owner group name for the directory
    :param mode: permission bits for the created directory (default 0o755)
    :param HUtilObj: utility object for logging
    '''
    try:
        # Clear older config directory if exists.
        if os.path.exists(me_config_dir):
            rmtree(me_config_dir)
        os.makedirs(me_config_dir, mode=mode)
        if user and group:
            # Get UID and GID from user and group names
            uid = pwd.getpwnam(user).pw_uid
            gid = grp.getgrnam(group).gr_gid
            # Set the ownership
            os.chown(me_config_dir, uid, gid)
            # NOTE(review): log placement inferred from the extraction-mangled
            # source; it references the ownership so it is kept in this branch.
            if HUtilObj:
                HUtilObj.log('Directory {0} created with ownership {1}:{2}.'.format(me_config_dir, user, group))
    except Exception as e:
        # Best-effort by design: directory-creation failures are logged, not raised.
        if HUtilObj:
            HUtilObj.log('Failed to create directory: {0}'.format(e))
================================================ FILE: LAD-AMA-Common/telegraf_utils/__init__.py ================================================ # Telegraf config parser module package ================================================ FILE: LAD-AMA-Common/telegraf_utils/telegraf_config_handler.py ================================================ #!/usr/bin/env python # # Azure Linux extension # # Copyright (c) Microsoft Corporation # All rights reserved.
# MIT License # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the ""Software""), to deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit # persons to whom the Software is furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the # Software. # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
import sys import json import os from telegraf_utils.telegraf_name_map import name_map import subprocess import signal from shutil import copyfile, rmtree import time import metrics_ext_utils.metrics_constants as metrics_constants import metrics_ext_utils.metrics_common_utils as metrics_utils try: # Python 3+ import urllib.request as urllib except ImportError: # Python 2 import urllib2 as urllib """ Sample input data received by this script [ { "displayName" : "Network->Packets sent", "interval" : "15s", "sink" : ["mdsd" , "me"] }, { "displayName" : "Network->Packets recieved", "interval" : "15s", "sink" : ["mdsd" , "me"] } ] """ def parse_config(data, me_url, mdsd_url, is_lad, az_resource_id, subscription_id, resource_group, region, virtual_machine_name): """ Main parser method to convert Metrics config from extension configuration to telegraf configuration :param data: Parsed Metrics Configuration from which telegraf config is created :param me_url: The url to which telegraf will send metrics to for MetricsExtension :param mdsd_url: The url to which telegraf will send metrics to for MDSD :param is_lad: Boolean value for whether the extension is Lad or not (AMA) :param az_resource_id: Azure Resource ID value for the VM :param subscription_id: Azure Subscription ID value for the VM :param resource_group: Azure Resource Group value for the VM :param region: Azure Region value for the VM :param virtual_machine_name: Azure Virtual Machine Name value (Only in the case for VMSS) for the VM """ storage_namepass_list = [] storage_namepass_str = "" vmi_rate_counters_list = ["LogicalDisk\\BytesPerSecond", "LogicalDisk\\ReadBytesPerSecond", "LogicalDisk\\ReadsPerSecond", "LogicalDisk\\WriteBytesPerSecond", "LogicalDisk\\WritesPerSecond", "LogicalDisk\\TransfersPerSecond", "Network\\ReadBytesPerSecond", "Network\\WriteBytesPerSecond"] MetricsExtensionNamepsace = metrics_constants.metrics_extension_namespace has_mdsd_output = False has_me_output = False if len(data) == 0: 
raise Exception("Empty config data received.") if me_url is None or mdsd_url is None: raise Exception("No url provided for Influxdb output plugin to ME, AMA.") telegraf_json = {} counterConfigIdMap = {} for item in data: sink = item["sink"] if "mdsd" in sink: has_mdsd_output = True if "me" in sink: has_me_output = True counter = item["displayName"] if counter in name_map: plugin = name_map[counter]["plugin"] is_vmi = plugin.endswith("_vmi") telegraf_plugin = plugin if is_vmi: splitResult = plugin.split('_') telegraf_plugin = splitResult[0] if counter not in counterConfigIdMap: counterConfigIdMap[counter] = [] configIds = counterConfigIdMap[counter] configurationIds = item["configurationId"] for configId in configurationIds: if configId not in configIds: configIds.append(configId) omiclass = "" if is_lad: omiclass = counter.split("->")[0] else: omiclass = name_map[counter]["module"] if omiclass not in telegraf_json: telegraf_json[omiclass] = {} if plugin not in telegraf_json[omiclass]: telegraf_json[omiclass][plugin] = {} telegraf_json[omiclass][plugin][name_map[counter]["field"]] = {} if is_lad: telegraf_json[omiclass][plugin][name_map[counter]["field"]]["displayName"] = counter.split("->")[1] else: telegraf_json[omiclass][plugin][name_map[counter]["field"]]["displayName"] = counter telegraf_json[omiclass][plugin][name_map[counter]["field"]]["interval"] = item["interval"] if is_lad: telegraf_json[omiclass][plugin][name_map[counter]["field"]]["ladtablekey"] = name_map[counter]["ladtablekey"] if "op" in name_map[counter]: telegraf_json[omiclass][plugin][name_map[counter]["field"]]["op"] = name_map[counter]["op"] """ Sample converted telegraf conf dict - "network": { "net": { "bytes_total": {"interval": "15s","displayName": "Network total bytes","ladtablekey": "/builtin/network/bytestotal"}, "drop_total": {"interval": "15s","displayName": "Network collisions","ladtablekey": "/builtin/network/totalcollisions"}, "err_in": {"interval": "15s","displayName": "Packets 
received errors","ladtablekey": "/builtin/network/totalrxerrors"}, "packets_sent": {"interval": "15s","displayName": "Packets sent","ladtablekey": "/builtin/network/packetstransmitted"}, } }, "filesystem": { "disk": { "used_percent": {"interval": "15s","displayName": "Filesystem % used space","ladtablekey": "/builtin/filesystem/percentusedspace"}, "used": {"interval": "15s","displayName": "Filesystem used space","ladtablekey": "/builtin/filesystem/usedspace"}, "free": {"interval": "15s","displayName": "Filesystem free space","ladtablekey": "/builtin/filesystem/freespace"}, "inodes_free_percent": {"interval": "15s","displayName": "Filesystem % free inodes","ladtablekey": "/builtin/filesystem/percentfreeinodes"}, }, "diskio": { "writes_filesystem": {"interval": "15s","displayName": "Filesystem writes/sec","ladtablekey": "/builtin/filesystem/writespersecond","op": "rate"}, "total_transfers_filesystem": {"interval": "15s","displayName": "Filesystem transfers/sec","ladtablekey": "/builtin/filesystem/transferspersecond","op": "rate"}, "reads_filesystem": {"interval": "15s","displayName": "Filesystem reads/sec","ladtablekey": "/builtin/filesystem/readspersecond","op": "rate"}, } }, """ if len(telegraf_json) == 0: raise Exception("Unable to parse telegraf config into intermediate dictionary.") excess_diskio_plugin_list_lad = ["total_transfers_filesystem", "read_bytes_filesystem", "total_bytes_filesystem", "write_bytes_filesystem", "reads_filesystem", "writes_filesystem"] excess_diskio_field_drop_list_str = "" int_file = {"filename":"intermediate.json", "data": json.dumps(telegraf_json)} output = [] output.append(int_file) for omiclass in telegraf_json: input_str = "" ama_rename_str = "" metricsext_rename_str = "" lad_specific_rename_str = "" rate_specific_aggregator_str = "" aggregator_str = "" for plugin in telegraf_json[omiclass]: config_file = {"filename" : omiclass+".conf"} # Arbitrary max value for finding min min_interval = "999999999s" is_vmi = 
plugin.endswith("_vmi") is_vmi_rate_counter = False for field in telegraf_json[omiclass][plugin]: if not is_vmi_rate_counter: is_vmi_rate_counter = telegraf_json[omiclass][plugin][field]["displayName"] in vmi_rate_counters_list # if is_vmi_rate_counter: # min_interval = "1s" if is_vmi or is_vmi_rate_counter: splitResult = plugin.split('_') telegraf_plugin = splitResult[0] input_str += "[[inputs." + telegraf_plugin + "]]\n" # plugin = plugin[:-4] else: input_str += "[[inputs." + plugin + "]]\n" # input_str += " "*2 + "name_override = \"" + omiclass + "\"\n" # If it's a lad config then add the namepass fields for sending totals to storage # always skip lad plugin names as they should be dropped from ME lad_plugin_name = plugin + "_total" if lad_plugin_name not in storage_namepass_list: storage_namepass_list.append(lad_plugin_name) if is_lad: lad_specific_rename_str += "\n[[processors.rename]]\n" lad_specific_rename_str += " "*2 + "namepass = [\"" + lad_plugin_name + "\"]\n" elif is_vmi or is_vmi_rate_counter: if plugin not in storage_namepass_list: storage_namepass_list.append(plugin + "_mdsd") else: ama_plugin_name = plugin + "_mdsd_la_perf" ama_rename_str += "\n[[processors.rename]]\n" ama_rename_str += " "*2 + "namepass = [\"" + ama_plugin_name + "\"]\n" if ama_plugin_name not in storage_namepass_list: storage_namepass_list.append(ama_plugin_name) namespace = MetricsExtensionNamepsace if is_vmi or is_vmi_rate_counter: namespace = "insights.virtualmachine" if is_vmi_rate_counter: # Adding "_rated" as a substring for vmi rate metrics to avoid renaming collisions plugin_name = plugin + "_rated" else: plugin_name = plugin metricsext_rename_str += "\n[[processors.rename]]\n" metricsext_rename_str += " "*2 + "namepass = [\"" + plugin_name + "\"]\n" metricsext_rename_str += "\n" + " "*2 + "[[processors.rename.replace]]\n" metricsext_rename_str += " "*4 + "measurement = \"" + plugin_name + "\"\n" metricsext_rename_str += " "*4 + "dest = \"" + namespace + "\"\n" fields = 
"" ops_fields = "" non_ops_fields = "" non_rate_aggregate = False ops = "" rate_aggregate = False for field in telegraf_json[omiclass][plugin]: fields += "\"" + field + "\", " if is_vmi or is_vmi_rate_counter : if "MB" in field: fields += "\"" + field.replace('MB','Bytes') + "\", " #Use the shortest interval time for the whole plugin new_interval = telegraf_json[omiclass][plugin][field]["interval"] if int(new_interval[:-1]) < int(min_interval[:-1]): min_interval = new_interval #compute values for aggregator options if "op" in telegraf_json[omiclass][plugin][field]: if telegraf_json[omiclass][plugin][field]["op"] == "rate": rate_aggregate = True ops = "\"rate\", \"rate_min\", \"rate_max\", \"rate_count\", \"rate_sum\", \"rate_mean\"" if is_lad: ops_fields += "\"" + telegraf_json[omiclass][plugin][field]["ladtablekey"] + "\", " else: ops_fields += "\"" + telegraf_json[omiclass][plugin][field]["displayName"] + "\", " else: non_rate_aggregate = True if is_lad: non_ops_fields += "\"" + telegraf_json[omiclass][plugin][field]["ladtablekey"] + "\", " else: non_ops_fields += "\"" + telegraf_json[omiclass][plugin][field]["displayName"] + "\", " #Add respective rename processor plugin based on the displayname if is_lad: lad_specific_rename_str += "\n" + " "*2 + "[[processors.rename.replace]]\n" lad_specific_rename_str += " "*4 + "field = \"" + field + "\"\n" lad_specific_rename_str += " "*4 + "dest = \"" + telegraf_json[omiclass][plugin][field]["ladtablekey"] + "\"\n" elif not is_vmi and not is_vmi_rate_counter: # no rename of fields as they are set in telegraf directly ama_rename_str += "\n" + " "*2 + "[[processors.rename.replace]]\n" ama_rename_str += " "*4 + "field = \"" + field + "\"\n" ama_rename_str += " "*4 + "dest = \"" + telegraf_json[omiclass][plugin][field]["displayName"] + "\"\n" # Avoid adding the rename logic for the redundant *_filesystem fields for diskio which were added specifically for OMI parity in LAD # Had to re-use these six fields to avoid renaming 
issues since both Filesystem and Disk in OMI-LAD use them # AMA only uses them once so only need this for LAD if is_lad: if field in excess_diskio_plugin_list_lad: excess_diskio_field_drop_list_str += "\"" + field + "\", " else: metricsext_rename_str += "\n" + " "*2 + "[[processors.rename.replace]]\n" metricsext_rename_str += " "*4 + "field = \"" + field + "\"\n" metricsext_rename_str += " "*4 + "dest = \"" + plugin + "/" + field + "\"\n" elif not is_vmi and not is_vmi_rate_counter: # no rename of fields as they are set in telegraf directly metricsext_rename_str += "\n" + " "*2 + "[[processors.rename.replace]]\n" metricsext_rename_str += " "*4 + "field = \"" + field + "\"\n" metricsext_rename_str += " "*4 + "dest = \"" + plugin + "/" + field + "\"\n" #Add respective operations for aggregators # if is_lad: if not is_vmi and not is_vmi_rate_counter: suffix = "" if is_lad: suffix = "_total\"]\n" else: suffix = "_mdsd_la_perf\"]\n" if rate_aggregate: aggregator_str += "[[aggregators.basicstats]]\n" aggregator_str += " "*2 + "namepass = [\"" + plugin + suffix aggregator_str += " "*2 + "period = \"" + min_interval + "\"\n" aggregator_str += " "*2 + "drop_original = true\n" aggregator_str += " "*2 + "fieldpass = [" + ops_fields[:-2] + "]\n" #-2 to strip the last comma and space aggregator_str += " "*2 + "stats = [" + ops + "]\n" if non_rate_aggregate: aggregator_str += "[[aggregators.basicstats]]\n" aggregator_str += " "*2 + "namepass = [\"" + plugin + suffix aggregator_str += " "*2 + "period = \"" + min_interval + "\"\n" aggregator_str += " "*2 + "drop_original = true\n" aggregator_str += " "*2 + "fieldpass = [" + non_ops_fields[:-2] + "]\n" #-2 to strip the last comma and space aggregator_str += " "*2 + "stats = [\"mean\", \"max\", \"min\", \"sum\", \"count\"]\n\n" elif is_vmi_rate_counter: # Aggregator config for MDSD aggregator_str += "[[aggregators.basicstats]]\n" aggregator_str += " "*2 + "namepass = [\"" + plugin + "_mdsd\"]\n" aggregator_str += " "*2 + "period = 
\"" + min_interval + "\"\n" aggregator_str += " "*2 + "drop_original = true\n" aggregator_str += " "*2 + "fieldpass = [" + ops_fields[:-2].replace('\\','\\\\\\\\') + "]\n" #-2 to strip the last comma and space aggregator_str += " "*2 + "stats = [" + ops + "]\n\n" # Aggregator config for ME aggregator_str += "[[aggregators.mdmratemetrics]]\n" aggregator_str += " "*2 + "namepass = [\"" + plugin + "\"]\n" aggregator_str += " "*2 + "period = \"" + min_interval + "\"\n" aggregator_str += " "*2 + "drop_original = true\n" aggregator_str += " "*2 + "fieldpass = [" + ops_fields[:-2].replace('\\','\\\\\\\\') + "]\n" #-2 to strip the last comma and space aggregator_str += " "*2 + "stats = [\"rate\"]\n\n" if is_lad: lad_specific_rename_str += "\n" elif not is_vmi and not is_vmi_rate_counter: # no rename of fields as they are set in telegraf directly ama_rename_str += "\n" # Using fields[: -2] here to get rid of the last ", " at the end of the string input_str += " "*2 + "fieldpass = ["+fields[:-2]+"]\n" if plugin == "cpu": input_str += " "*2 + "report_active = true\n" # Rate interval needs to be atleast twice the regular sourcing interval for aggregation to work. # Since we want all the VMI metrics to be sent at the same interval as selected by the customer, To overcome the twice the min internval limitation, # We are sourcing the VMI metrics that need to be aggregated at half the selected frequency rated_min_interval = str(int(min_interval[:-1]) // 2) + "s" input_str += " "*2 + "interval = " + "\"" + rated_min_interval + "\"\n\n" telegraf_plugin = plugin if is_vmi: splitResult = plugin.split('_') telegraf_plugin = splitResult[0] if not is_lad: configIds = counterConfigIdMap[telegraf_json[omiclass][plugin][field]["displayName"]] for configId in configIds: input_str += "\n" input_str += " "*2 + "[inputs." 
+ telegraf_plugin + ".tags]\n" input_str += " "*4 + "configurationId=\"" + configId + "\"\n\n" break config_file["data"] = input_str + "\n" + metricsext_rename_str + "\n" + ama_rename_str + "\n" + lad_specific_rename_str + "\n" +aggregator_str output.append(config_file) config_file = {} """ Sample telegraf TOML file output [[inputs.net]] fieldpass = ["err_out", "packets_sent", "err_in", "bytes_sent", "packets_recv"] interval = "5s" [[inputs.cpu]] fieldpass = ["usage_nice", "usage_user", "usage_idle", "usage_active", "usage_irq", "usage_system"] interval = "15s" [[processors.rename]] [[processors.rename.replace]] measurement = "net" dest = "network" [[processors.rename.replace]] field = "err_out" dest = "Packets sent errors" [[aggregators.basicstats]] period = "30s" drop_original = false fieldpass = ["Disk reads", "Disk writes", "Filesystem write bytes/sec"] stats = ["rate"] """ ## Get the log folder directory from HandlerEnvironment.json and use that for the telegraf default logging logFolder, _ = get_handler_vars() for measurement in storage_namepass_list: storage_namepass_str += "\"" + measurement + "\", " # Telegraf basic agent and output config agentconf = "[agent]\n" agentconf += " interval = \"10s\"\n" agentconf += " round_interval = true\n" agentconf += " metric_batch_size = 1000\n" agentconf += " metric_buffer_limit = 1000000\n" agentconf += " collection_jitter = \"0s\"\n" agentconf += " flush_interval = \"10s\"\n" agentconf += " flush_jitter = \"0s\"\n" agentconf += " logtarget = \"file\"\n" agentconf += " quiet = true\n" agentconf += " logfile = \"" + logFolder + "/telegraf.log\"\n" agentconf += " logfile_rotation_max_size = \"100MB\"\n" agentconf += " logfile_rotation_max_archives = 5\n" agentconf += "\n# Configuration for adding gloabl tags\n" agentconf += "[global_tags]\n" if is_lad: agentconf += " DeploymentId= \"${DeploymentId}\"\n" agentconf += " \"microsoft.subscriptionId\"= \"" + subscription_id + "\"\n" agentconf += " 
\"microsoft.resourceGroupName\"= \"" + resource_group + "\"\n" agentconf += " \"microsoft.regionName\"= \"" + region + "\"\n" agentconf += " \"microsoft.resourceId\"= \"" + az_resource_id + "\"\n" if virtual_machine_name != "": agentconf += " \"VMInstanceId\"= \"" + virtual_machine_name + "\"\n" if has_me_output or is_lad: agentconf += "\n# Configuration for sending metrics to MetricsExtension\n" # for AMA we use Sockets to write to ME but for LAD we continue using UDP # because we support a lot more counters in AMA path and ME is not able to handle it with UDP if is_lad: agentconf += "[[outputs.influxdb]]\n" else: agentconf += "[[outputs.socket_writer]]\n" agentconf += " namedrop = [" + storage_namepass_str[:-2] + "]\n" if is_lad: agentconf += " fielddrop = [" + excess_diskio_field_drop_list_str[:-2] + "]\n" if is_lad: agentconf += " urls = [\"" + str(me_url) + "\"]\n\n" agentconf += " udp_payload = \"2048B\"\n\n" else: agentconf += " data_format = \"influx\"\n" agentconf += " address = \"" + str(me_url) + "\"\n\n" if has_mdsd_output: agentconf += "\n# Configuration for sending metrics to MDSD\n" agentconf += "[[outputs.socket_writer]]\n" agentconf += " namepass = [" + storage_namepass_str[:-2] + "]\n" agentconf += " data_format = \"influx\"\n" agentconf += " address = \"" + str(mdsd_url) + "\"\n\n" agentconf += "\n# Configuration for outputing metrics to file. 
Uncomment to enable.\n"
    # --- tail of parse_config(); the definition starts above this span ---
    agentconf += "#[[outputs.file]]\n"
    agentconf += "# files = [\"./metrics_to_file.out\"]\n\n"
    agent_file = {"filename":"telegraf.conf", "data": agentconf}
    output.append(agent_file)
    return output, storage_namepass_list


def write_configs(configs, telegraf_conf_dir, telegraf_d_conf_dir):
    """
    Write the telegraf config created by config parser method to disk at the telegraf config location
    :param configs: Telegraf config data parsed by the parse_config method above
    :param telegraf_conf_dir: Path where the telegraf.conf is written to on the disk
    :param telegraf_d_conf_dir: Path where the individual module telegraf configs are written to on the disk
    """
    # Delete the older config folder to prevent telegraf from loading older configs
    if os.path.exists(telegraf_conf_dir):
        rmtree(telegraf_conf_dir)
    os.mkdir(telegraf_conf_dir)
    os.mkdir(telegraf_d_conf_dir)
    for configfile in configs:
        # telegraf.conf and intermediate.json live in the top-level dir;
        # every per-module config goes into the .d subdirectory.
        if configfile["filename"] == "telegraf.conf" or configfile["filename"] == "intermediate.json":
            path = telegraf_conf_dir + configfile["filename"]
        else:
            path = telegraf_d_conf_dir + configfile["filename"]
        with open(path, "w") as f:
            f.write(configfile["data"])


def get_handler_vars():
    """
    This method is taken from the Waagent code. This is used to grab the log and config file location from the json public setting for the Extension

    :return: tuple (logFolder, configFolder); empty strings when
        HandlerEnvironment.json is absent or lacks the expected keys
    """
    logFolder = ""
    configFolder = ""
    handler_env_path = os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'HandlerEnvironment.json'))
    if os.path.exists(handler_env_path):
        with open(handler_env_path, 'r') as handler_env_file:
            handler_env_txt = handler_env_file.read()
        handler_env = json.loads(handler_env_txt)
        # The handler environment may be wrapped in a single-element list.
        if type(handler_env) == list:
            handler_env = handler_env[0]
        if "handlerEnvironment" in handler_env:
            if "logFolder" in handler_env["handlerEnvironment"]:
                logFolder = handler_env["handlerEnvironment"]["logFolder"]
            if "configFolder" in handler_env["handlerEnvironment"]:
                configFolder = handler_env["handlerEnvironment"]["configFolder"]
    return logFolder, configFolder


def is_running(is_lad):
    """
    This method is used to check if telegraf binary is currently running on the system or not.
    In order to check whether it needs to be restarted from the watcher daemon

    :param is_lad: boolean whether the extension is LAD or not (AMA)
    :return: True when the telegraf binary path appears in the process list
    """
    if is_lad:
        telegraf_bin = metrics_constants.lad_telegraf_bin
    else:
        telegraf_bin = metrics_constants.ama_telegraf_bin
    # grep -v grep filters the grep process itself out of the listing.
    proc = subprocess.Popen(["ps aux | grep telegraf | grep -v grep"], stdout=subprocess.PIPE, shell=True)
    output = proc.communicate()[0]
    if telegraf_bin in output.decode('utf-8', 'ignore'):
        return True
    else:
        return False


def stop_telegraf_service(is_lad):
    """
    Stop the telegraf service if VM is using is systemd, otherwise check if the pid_file exists,
    and if the pid belongs to the Telegraf process, if yes, then kill the process
    This method is called before remove_telegraf_service by the main extension code
    :param is_lad: boolean whether the extension is LAD or not (AMA)
    """
    if is_lad:
        telegraf_bin = metrics_constants.lad_telegraf_bin
    else:
        telegraf_bin = metrics_constants.ama_telegraf_bin
    # If the VM has systemd, then we will use that to stop
    if metrics_utils.is_systemd():
        code = 1
        telegraf_service_path = get_telegraf_service_path(is_lad)
        telegraf_service_name = 
def stop_telegraf_service(is_lad):
    """
    Stop the telegraf service if VM is using is systemd, otherwise check if the pid_file exists,
    and if the pid belongs to the Telegraf process, if yes, then kill the process

    This method is called before remove_telegraf_service by the main extension code

    :param is_lad: boolean whether the extension is LAD or not (AMA)
    :return: (bool, str) tuple of success flag and human-readable message
    """
    if is_lad:
        telegraf_bin = metrics_constants.lad_telegraf_bin
    else:
        telegraf_bin = metrics_constants.ama_telegraf_bin

    # If the VM has systemd, then we will use that to stop
    if metrics_utils.is_systemd():
        code = 1
        telegraf_service_path = get_telegraf_service_path(is_lad)
        telegraf_service_name = get_telegraf_service_name(is_lad)

        if os.path.isfile(telegraf_service_path):
            code = os.system("systemctl stop {0}".format(telegraf_service_name))
        else:
            return False, "Telegraf service file does not exist. Failed to stop telegraf service: {0}.service.".format(telegraf_service_name)

        if code != 0:
            return False, "Unable to stop telegraf service: {0}.service. Run systemctl status {0}.service for more info.".format(telegraf_service_name)

    # Whether or not VM has systemd, let's check if we have any telegraf pids saved and if so, terminate the associated process
    _, configFolder = get_handler_vars()
    telegraf_conf_dir = configFolder + "/telegraf_configs/"
    telegraf_pid_path = telegraf_conf_dir + "telegraf_pid.txt"
    if os.path.isfile(telegraf_pid_path):
        with open(telegraf_pid_path, "r") as f:
            for pid in f.readlines():
                # Verify the pid actually belongs to telegraf by inspecting
                # its /proc/<pid>/cmdline before sending SIGKILL
                cmd_path = os.path.join("/proc", str(pid.strip("\n")), "cmdline")
                if os.path.exists(cmd_path):
                    with open(cmd_path, "r") as cmd_f:
                        cmdline = cmd_f.readlines()
                        if cmdline[0].find(telegraf_bin) >= 0:
                            os.kill(int(pid), signal.SIGKILL)
        # pid file is stale after the kills above; remove it
        os.remove(telegraf_pid_path)
    elif not metrics_utils.is_systemd():
        # No systemd service and no recorded pid: nothing we know how to stop
        return False, "Could not find telegraf service nor process to stop."

    return True, "Successfully stopped metrics-sourcer service"


def remove_telegraf_service(is_lad):
    """
    Remove the telegraf service if the VM is using systemd as well as the telegraf Binary
    This method is called after stop_telegraf_service by the main extension code during Extension uninstall

    :param is_lad: boolean whether the extension is LAD or not (AMA)
    :return: (bool, str) tuple of success flag and human-readable message
    """
    telegraf_service_path = get_telegraf_service_path(is_lad)
    telegraf_service_name = get_telegraf_service_name(is_lad)

    if os.path.isfile(telegraf_service_path):
        os.remove(telegraf_service_path)
    else:
        # Nothing to remove is treated as success (idempotent uninstall)
        return True, "Unable to remove the Telegraf service as the file doesn't exist."

    # Checking To see if the file was successfully removed, since os.remove doesn't return an error code
    if os.path.isfile(telegraf_service_path):
        return False, "Unable to remove telegraf service: {0}.service at {1}.".format(telegraf_service_name, telegraf_service_path)

    return True, "Successfully removed {0} service".format(telegraf_service_name)
def setup_telegraf_service(is_lad, telegraf_bin, telegraf_d_conf_dir, telegraf_agent_conf, HUtilObj=None):
    """
    Add the metrics-sourcer service if the VM is using systemd
    This method is called in handle_config

    :param is_lad: boolean whether the extension is LAD or not (AMA)
    :param telegraf_bin: path to the telegraf binary
    :param telegraf_d_conf_dir: path to telegraf .d conf subdirectory
    :param telegraf_agent_conf: path to telegraf .conf file
    :param HUtilObj: optional handler-utility logger; falls back to print when None
    :raises Exception: when the config dir, agent conf, or service template is missing
    :return: True on completion
    """
    telegraf_service_path = get_telegraf_service_path(is_lad)
    telegraf_service_template_path = os.getcwd() + "/services/metrics-sourcer.service"

    if not os.path.exists(telegraf_d_conf_dir):
        raise Exception("Telegraf config directory does not exist. Failed to setup telegraf service.")

    if not os.path.isfile(telegraf_agent_conf):
        raise Exception("Telegraf agent config does not exist. Failed to setup telegraf service.")

    if os.path.isfile(telegraf_service_template_path):
        copyfile(telegraf_service_template_path, telegraf_service_path)

        if os.path.isfile(telegraf_service_path):
            # Substitute the placeholders in the copied unit file with the real
            # binary/config paths ('+' is the sed delimiter since paths contain '/')
            os.system(r"sed -i 's+%TELEGRAF_BIN%+{1}+' {0}".format(telegraf_service_path, telegraf_bin))
            os.system(r"sed -i 's+%TELEGRAF_AGENT_CONFIG%+{1}+' {0}".format(telegraf_service_path, telegraf_agent_conf))
            os.system(r"sed -i 's+%TELEGRAF_CONFIG_DIR%+{1}+' {0}".format(telegraf_service_path, telegraf_d_conf_dir))

            daemon_reload_status = os.system("systemctl daemon-reload")
            if daemon_reload_status != 0:
                # Logged as a warning only; setup still proceeds to return True
                message = "Unable to reload systemd after Telegraf service file change. Failed to setup telegraf service. Check system for hardening. Exit code:" + str(daemon_reload_status)
                if HUtilObj is not None:
                    HUtilObj.log(message)
                else:
                    print('Info: {0}'.format(message))
        else:
            raise Exception("Unable to copy Telegraf service template file to {0}. Failed to setup telegraf service.".format(telegraf_service_path))
    else:
        raise Exception("Telegraf service template file does not exist at {0}. Failed to setup telegraf service.".format(telegraf_service_template_path))

    return True


def start_telegraf(is_lad):
    """
    Start the telegraf service if VM is using is systemd, otherwise start the binary as a process and store the pid
    to a file in the telegraf config directory
    This method is called after config setup is completed by the main extension code

    :param is_lad: boolean whether the extension is LAD or not (AMA)
    :return: (bool, str) tuple of success flag and accumulated log messages
    """
    # Re using the code to grab the config directories and imds values because start will be called from Enable process outside this script
    log_messages = ""
    if is_lad:
        telegraf_bin = metrics_constants.lad_telegraf_bin
    else:
        telegraf_bin = metrics_constants.ama_telegraf_bin

    if not os.path.isfile(telegraf_bin):
        log_messages += "Telegraf binary does not exist. Failed to start telegraf service."
        return False, log_messages

    # Ensure that any old telegraf processes are cleaned up to avoid duplication
    stop_telegraf_service(is_lad)

    # If the VM has systemd, telegraf will be managed as a systemd service
    telegraf_service_name = get_telegraf_service_name(is_lad)
    if metrics_utils.is_systemd():
        service_restart_status = os.system("systemctl restart {0}".format(telegraf_service_name))
        if service_restart_status != 0:
            log_messages += "Unable to start Telegraf service using systemctl. Failed to start telegraf service. Check system for hardening."
            return False, log_messages

    # Otherwise, start telegraf as a process and save the pid to a file so that we can terminate it while disabling/uninstalling
    else:
        _, configFolder = get_handler_vars()
        telegraf_conf_dir = configFolder + "/telegraf_configs/"
        telegraf_agent_conf = telegraf_conf_dir + "telegraf.conf"
        telegraf_d_conf_dir = telegraf_conf_dir + "telegraf.d/"
        telegraf_pid_path = telegraf_conf_dir + "telegraf_pid.txt"
        binary_exec_command = "{0} --config {1} --config-directory {2}".format(telegraf_bin, telegraf_agent_conf, telegraf_d_conf_dir)
        proc = subprocess.Popen(binary_exec_command.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # Sleeping for 3 seconds before checking if the process is still running, to give it ample time to relay crash info
        time.sleep(3)
        p = proc.poll()

        # Process is running successfully
        if p is None:
            telegraf_pid = proc.pid

            # Write this pid to a file for future use
            try:
                with open(telegraf_pid_path, "a") as f:
                    f.write(str(telegraf_pid) + '\n')
            except Exception as e:
                # Non-fatal: telegraf is running, we just can't record the pid
                log_messages += "Successfully started telegraf binary, but could not save telegraf pidfile."
        else:
            out, err = proc.communicate()
            log_messages += "Unable to run telegraf binary as a process due to error - {0}. Failed to start telegraf.".format(err)
            return False, log_messages
    return True, log_messages
def get_telegraf_service_path(is_lad):
    """
    Utility method to get the service path in case /lib/systemd/system doesnt exist on the OS
    """
    # Prefer /lib/systemd/system, fall back to /usr/lib/systemd/system
    if os.path.exists("/lib/systemd/system/"):
        if is_lad:
            return metrics_constants.lad_telegraf_service_path
        return metrics_constants.telegraf_service_path
    if os.path.exists("/usr/lib/systemd/system/"):
        if is_lad:
            return metrics_constants.lad_telegraf_service_path_usr_lib
        return metrics_constants.telegraf_service_path_usr_lib
    raise Exception("Systemd unit files do not exist at /lib/systemd/system or /usr/lib/systemd/system/. Failed to setup telegraf service.")


def get_telegraf_service_name(is_lad):
    """
    Utility method to get the service name
    """
    return metrics_constants.lad_telegraf_service_name if is_lad else metrics_constants.telegraf_service_name


def handle_config(config_data, me_url, mdsd_url, is_lad):
    """
    The main method to perform the task of parsing the config, writing them to disk, setting up,
    stopping, removing and starting telegraf

    :param config_data: Parsed Metrics Configuration from which telegraf config is created
    :param me_url: The url to which telegraf will send metrics to for MetricsExtension
    :param mdsd_url: The url to which telegraf will send metrics to for MDSD
    :param is_lad: Boolean value for whether the extension is Lad or not (AMA)
    """
    # Making the imds call to get resource id, sub id, resource group and region for the dimensions for telegraf metrics
    max_retries = 3
    sleep_time = 5
    is_arc = False
    if is_lad:
        imdsurl = "http://169.254.169.254/metadata/instance?api-version=2019-03-11"
    elif metrics_utils.is_arc_installed():
        # Arc machines expose IMDS through the local HIMDS endpoint
        imdsurl = metrics_utils.get_arc_endpoint() + "/metadata/instance?api-version=2019-11-01"
        is_arc = True
    else:
        imdsurl = "http://169.254.169.254/metadata/instance?api-version=2019-03-11"

    data = None
    retries = 1
    while retries <= max_retries:
        req = urllib.Request(imdsurl, headers={'Metadata': 'true'})
        res = urllib.urlopen(req)
        data = json.loads(res.read().decode('utf-8', 'ignore'))
        if "compute" in data:
            break
        retries += 1
        time.sleep(sleep_time)

    if retries > max_retries:
        raise Exception("Unable to find 'compute' key in imds query response. Reached max retry limit of - {0} times. Failed to setup Telegraf.".format(max_retries))

    if "resourceId" not in data["compute"]:
        raise Exception("Unable to find 'resourceId' key in imds query response. Failed to setup Telegraf.")

    # resource id is needed for ME to show metrics on the metrics blade of the VM/VMSS
    az_resource_id = data["compute"]["resourceId"]

    # For a uniform VMSS instance IMDS returns .../virtualMachineScaleSets/<name>/virtualMachines/<n>,
    # while ME expects the scale-set level id, so the instance suffix is trimmed off
    if "virtualMachineScaleSets" in az_resource_id:
        az_resource_id = "/".join(az_resource_id.split("/")[:-2])

    virtual_machine_name = ""
    if "vmScaleSetName" in data["compute"] and data["compute"]["vmScaleSetName"] != "":
        virtual_machine_name = data["compute"]["name"]
        # Flexible VMSS instances report a plain VM resource id; rebuild the
        # scale-set level id that ME expects from vmScaleSetName
        if "virtualMachineScaleSets" not in az_resource_id:
            az_resource_id = "/".join(az_resource_id.split("/")[:-2]) + "/virtualMachineScaleSets/" + data["compute"]["vmScaleSetName"]

    if "subscriptionId" not in data["compute"]:
        raise Exception("Unable to find 'subscriptionId' key in imds query response. Failed to setup Telegraf.")
    subscription_id = data["compute"]["subscriptionId"]

    if "resourceGroupName" not in data["compute"]:
        raise Exception("Unable to find 'resourceGroupName' key in imds query response. Failed to setup Telegraf.")
    resource_group = data["compute"]["resourceGroupName"]

    if "location" not in data["compute"]:
        raise Exception("Unable to find 'location' key in imds query response. Failed to setup Telegraf.")
    region = data["compute"]["location"]

    # First parse the configs into per-plugin telegraf files
    output, namespaces = parse_config(config_data, me_url, mdsd_url, is_lad, az_resource_id, subscription_id, resource_group, region, virtual_machine_name)

    _, configFolder = get_handler_vars()
    telegraf_bin = metrics_constants.lad_telegraf_bin if is_lad else metrics_constants.ama_telegraf_bin
    telegraf_conf_dir = configFolder + "/telegraf_configs/"
    telegraf_agent_conf = telegraf_conf_dir + "telegraf.conf"
    telegraf_d_conf_dir = telegraf_conf_dir + "telegraf.d/"

    # Then write the parsed configs to disk
    write_configs(output, telegraf_conf_dir, telegraf_d_conf_dir)

    # Setup Telegraf service.
    # If the VM has systemd, then we will copy over the systemd unit file and use that to start/stop
    if metrics_utils.is_systemd():
        telegraf_service_setup = setup_telegraf_service(is_lad, telegraf_bin, telegraf_d_conf_dir, telegraf_agent_conf)
        if not telegraf_service_setup:
            return False, []

    return True, namespaces
# MIT License
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# Mapping from metric counter names (LAD "class->counter" keys, AMA display
# names, and VM Insights backslash-style names) to the telegraf plugin and
# field that supplies each metric. Value keys:
#   plugin      - telegraf input plugin name
#   field       - plugin field holding the value
#   ladtablekey - LAD storage-table key (LAD counters only)
#   module      - logical module grouping (AMA counters only)
#   op          - "rate" when the raw counter must be converted to a per-second rate
name_map = {
    ######These are the counter keys and telegraf plugins for LAD/AMA
    "processor->cpu io wait time" : {"plugin":"cpu", "field":"usage_iowait", "ladtablekey":"/builtin/processor/percentiowaittime"},
    "processor->cpu user time" : {"plugin":"cpu", "field":"usage_user", "ladtablekey":"/builtin/processor/percentusertime"},
    "processor->cpu nice time" : {"plugin":"cpu", "field":"usage_nice", "ladtablekey":"/builtin/processor/percentnicetime"},
    "processor->cpu percentage guest os" : {"plugin":"cpu", "field":"usage_active", "ladtablekey":"/builtin/processor/percentprocessortime"},
    "processor->cpu interrupt time" : {"plugin":"cpu", "field":"usage_irq", "ladtablekey":"/builtin/processor/percentinterrupttime"},
    "processor->cpu idle time" : {"plugin":"cpu", "field":"usage_idle", "ladtablekey":"/builtin/processor/percentidletime"},
    "processor->cpu privileged time" : {"plugin":"cpu", "field":"usage_system", "ladtablekey":"/builtin/processor/percentprivilegedtime"},

    "% IO Wait Time" : {"plugin":"cpu", "field":"usage_iowait", "module":"processor"},
    "% User Time" : {"plugin":"cpu", "field":"usage_user", "module":"processor"},
    "% Nice Time" : {"plugin":"cpu", "field":"usage_nice", "module":"processor"},
    "% Processor Time" : {"plugin":"cpu", "field":"usage_active", "module":"processor"},
    "% Interrupt Time" : {"plugin":"cpu", "field":"usage_irq", "module":"processor"},
    "% Idle Time" : {"plugin":"cpu", "field":"usage_idle", "module":"processor"},
    "% Privileged Time" : {"plugin":"cpu", "field":"usage_system", "module":"processor"},

    # VM Insights
    # 8 slashes because this goes from JSON -> Python -> Telegraf config -> Go -> C++ and each level does an escape
    "Processor\\UtilizationPercentage" : {"plugin":"cpu_vmi", "field":"Processor\\\\\\\\UtilizationPercentage", "module":"processor"},
    "Computer\\Heartbeat" : {"plugin":"cpu_heartbeat_vmi", "field":"Computer\\\\\\\\Heartbeat", "module":"processor"},

    "network->network in guest os" : {"plugin":"net", "field":"bytes_recv", "ladtablekey":"/builtin/network/bytesreceived"},
    "network->network total bytes" : {"plugin":"net", "field":"bytes_total", "ladtablekey":"/builtin/network/bytestotal"}, #Need to calculate sum
    "network->network out guest os" : {"plugin":"net", "field":"bytes_sent", "ladtablekey":"/builtin/network/bytestransmitted"},
    "network->network collisions" : {"plugin":"net", "field":"drop_total", "ladtablekey":"/builtin/network/totalcollisions"}, #Need to calculate sum
    "network->packets received errors" : {"plugin":"net", "field":"err_in", "ladtablekey":"/builtin/network/totalrxerrors"},
    "network->packets sent" : {"plugin":"net", "field":"packets_sent", "ladtablekey":"/builtin/network/packetstransmitted"},
    "network->packets received" : {"plugin":"net", "field":"packets_recv", "ladtablekey":"/builtin/network/packetsreceived"},
    "network->packets sent errors" : {"plugin":"net", "field":"err_out", "ladtablekey":"/builtin/network/totaltxerrors"},

    "Total Bytes Received" : {"plugin":"net", "field":"bytes_recv", "module":"network"},
    "Total Bytes" : {"plugin":"net", "field":"bytes_total", "module":"network"}, #Need to calculate sum
    "Total Bytes Transmitted" : {"plugin":"net", "field":"bytes_sent", "module":"network"},
    "Total Collisions" : {"plugin":"net", "field":"drop_total", "module":"network"}, #Need to calculate sum
    "Total Rx Errors" : {"plugin":"net", "field":"err_in", "module":"network"},
    "Total Packets Transmitted" : {"plugin":"net", "field":"packets_sent", "module":"network"},
    "Total Packets Received" : {"plugin":"net", "field":"packets_recv", "module":"network"},
    "Total Tx Errors" : {"plugin":"net", "field":"err_out", "module":"network"},

    # VM Insights
    # "Network\ReadBytesPerSecond", "Network\WriteBytesPerSecond"
    # 8 slashes because this goes from JSON -> Python -> Telegraf config -> Go -> C++ and each level does an escape
    "Network\\ReadBytesPerSecond" : {"plugin":"net_recv_vmi", "field":"Network\\\\\\\\ReadBytesPerSecond", "op":"rate", "module":"network"},
    "Network\\WriteBytesPerSecond" : {"plugin":"net_sent_vmi", "field":"Network\\\\\\\\WriteBytesPerSecond", "op":"rate", "module":"network"},

    "memory->memory available" : {"plugin":"mem", "field":"available", "ladtablekey":"/builtin/memory/availablememory"},
    "memory->mem. percent available" : {"plugin":"mem", "field":"available_percent", "ladtablekey":"/builtin/memory/percentavailablememory"},
    "memory->memory used" : {"plugin":"mem", "field":"used", "ladtablekey":"/builtin/memory/usedmemory"},
    "memory->memory percentage" : {"plugin":"mem", "field":"used_percent", "ladtablekey":"/builtin/memory/percentusedmemory"},
    "memory->swap available" : {"plugin":"swap", "field":"free", "ladtablekey":"/builtin/memory/availableswap"},
    "memory->swap percent available" : {"plugin":"swap", "field":"free_percent", "ladtablekey":"/builtin/memory/percentavailableswap"}, #Need to calculate percentage
    "memory->swap used" : {"plugin":"swap", "field":"used", "ladtablekey":"/builtin/memory/usedswap"},
    "memory->swap percent used" : {"plugin":"swap", "field":"used_percent", "ladtablekey":"/builtin/memory/percentusedswap"},
    "memory->page reads": {"plugin":"kernel_vmstat", "field":"pgpgin", "op":"rate", "ladtablekey":"/builtin/memory/pagesreadpersec"},
    "memory->page writes" : {"plugin":"kernel_vmstat", "field":"pgpgout", "op":"rate", "ladtablekey":"/builtin/memory/pageswrittenpersec"},
    "memory->pages" : {"plugin":"kernel_vmstat", "field":"total_pages", "op":"rate", "ladtablekey":"/builtin/memory/pagespersec"},

    "Available MBytes Memory" : {"plugin":"mem", "field":"available", "module":"memory"},
    "% Available Memory" : {"plugin":"mem", "field":"available_percent", "module":"memory"},
    "Used Memory MBytes" : {"plugin":"mem", "field":"used", "module":"memory"},
    "% Used Memory" : {"plugin":"mem", "field":"used_percent", "module":"memory"},
    "Available MBytes Swap" : {"plugin":"swap", "field":"free", "module":"memory"},
    "% Available Swap Space" : {"plugin":"swap", "field":"free_percent", "module":"memory"}, #Need to calculate percentage
    "Used MBytes Swap Space" : {"plugin":"swap", "field":"used", "module":"memory"},
    "% Used Swap Space" : {"plugin":"swap", "field":"used_percent", "module":"memory"},
    "Page Reads/sec": {"plugin":"kernel_vmstat", "field":"pgpgin", "op":"rate", "module":"memory"},
    "Page Writes/sec" : {"plugin":"kernel_vmstat", "field":"pgpgout", "op":"rate", "module":"memory"},
    "Pages/sec" : {"plugin":"kernel_vmstat", "field":"total_pages", "op":"rate", "module":"memory"},

    # VM Insights
    # 8 slashes because this goes from JSON -> Python -> Telegraf config -> Go -> C++ and each level does an escape
    "Memory\\AvailableMB" : {"plugin":"mem_vmi", "field":"Memory\\\\\\\\AvailableMB", "module":"memory"},
    "Memory\\AvailablePercentage" : {"plugin":"mem_vmi", "field":"Memory\\\\\\\\AvailablePercentage", "module":"memory"},

    #OMI Filesystem plugin
    "filesystem->filesystem used space" : {"plugin":"disk", "field":"used", "ladtablekey":"/builtin/filesystem/usedspace"},
    "filesystem->filesystem % used space" : {"plugin":"disk", "field":"used_percent", "ladtablekey":"/builtin/filesystem/percentusedspace"},
    "filesystem->filesystem free space" : {"plugin":"disk", "field":"free", "ladtablekey":"/builtin/filesystem/freespace"},
    "filesystem->filesystem % free space" : {"plugin":"disk", "field":"free_percent", "ladtablekey":"/builtin/filesystem/percentfreespace"}, #Need to calculate percentage
    "filesystem->filesystem % free inodes" : {"plugin":"disk", "field":"inodes_free_percent", "ladtablekey":"/builtin/filesystem/percentfreeinodes"}, #Need to calculate percentage
    "filesystem->filesystem % used inodes" : {"plugin":"disk", "field":"inodes_used_percent", "ladtablekey":"/builtin/filesystem/percentusedinodes"}, #Need to calculate percentage
    "filesystem->filesystem transfers/sec" : {"plugin":"diskio", "field":"total_transfers_filesystem", "op":"rate", "ladtablekey":"/builtin/filesystem/transferspersecond"}, #Need to calculate sum
    "filesystem->filesystem read bytes/sec" : {"plugin":"diskio", "field":"read_bytes_filesystem", "op":"rate", "ladtablekey":"/builtin/filesystem/bytesreadpersecond"}, #Need to calculate rate (but each second not each interval)
    "filesystem->filesystem bytes/sec" : {"plugin":"diskio", "field":"total_bytes_filesystem", "op":"rate", "ladtablekey":"/builtin/filesystem/bytespersecond"}, #Need to calculate rate and then sum
    "filesystem->filesystem write bytes/sec" : {"plugin":"diskio", "field":"write_bytes_filesystem", "op":"rate", "ladtablekey":"/builtin/filesystem/byteswrittenpersecond"}, #Need to calculate rate (but each second not each interval)
    "filesystem->filesystem reads/sec" : {"plugin":"diskio", "field":"reads_filesystem", "op":"rate", "ladtablekey":"/builtin/filesystem/readspersecond"}, #Need to calculate rate (but each second not each interval)
    "filesystem->filesystem writes/sec" : {"plugin":"diskio", "field":"writes_filesystem", "op":"rate", "ladtablekey":"/builtin/filesystem/writespersecond"}, #Need to calculate rate (but each second not each interval)

    "% Used Space" : {"plugin":"disk", "field":"used_percent", "module":"filesystem"},
    "Free Megabytes" : {"plugin":"disk", "field":"free", "module":"filesystem"},
    "% Free Space" : {"plugin":"disk", "field":"free_percent", "module":"filesystem"}, #Need to calculate percentage
    "% Free Inodes" : {"plugin":"disk", "field":"inodes_free_percent", "module":"filesystem"}, #Need to calculate percentage
    "% Used Inodes" : {"plugin":"disk", "field":"inodes_used_percent", "module":"filesystem"}, #Need to calculate percentage
    "Disk Transfers/sec" : {"plugin":"diskio", "field":"total_transfers", "op":"rate", "module":"filesystem"}, #Need to calculate sum
    "Disk Read Bytes/sec" : {"plugin":"diskio", "field":"read_bytes", "op":"rate", "module":"filesystem"}, #Need to calculate rate (but each second not each interval)
    "Logical Disk Bytes/sec" : {"plugin":"diskio", "field":"total_bytes", "op":"rate", "module":"filesystem"}, #Need to calculate rate and then sum
    "Disk Write Bytes/sec" : {"plugin":"diskio", "field":"write_bytes", "op":"rate", "module":"filesystem"}, #Need to calculate rate (but each second not each interval)
    "Disk Reads/sec" : {"plugin":"diskio", "field":"reads", "op":"rate", "module":"filesystem"}, #Need to calculate rate (but each second not each interval)
    "Disk Writes/sec" : {"plugin":"diskio", "field":"writes", "op":"rate", "module":"filesystem"}, #Need to calculate rate (but each second not each interval)

    # VM Insights
    # 8 slashes because this goes from JSON -> Python -> Telegraf config -> Go -> C++ and each level does an escape
    "LogicalDisk\\FreeSpaceMB" : {"plugin":"disk_vmi", "field":"LogicalDisk\\\\\\\\FreeSpaceMB", "module":"filesystem"},
    "LogicalDisk\\FreeSpacePercentage" : {"plugin":"disk_vmi", "field":"LogicalDisk\\\\\\\\FreeSpacePercentage", "module":"filesystem"}, #Need to calculate percentage
    "LogicalDisk\\Status" : {"plugin":"disk_vmi", "field":"LogicalDisk\\\\\\\\Status", "module":"filesystem"}, #Need to calculate percentage
    #"LogicalDisk\BytesPerSecond", "LogicalDisk\ReadBytesPerSecond", "LogicalDisk\ReadsPerSecond", "LogicalDisk\WriteBytesPerSecond", "LogicalDisk\WritesPerSecond", "LogicalDisk\TransfersPerSecond",
    "LogicalDisk\\TransfersPerSecond" : {"plugin":"diskio_vmi", "field":"LogicalDisk\\\\\\\\TransfersPerSecond", "op":"rate", "module":"filesystem"}, #Need to calculate sum
    "LogicalDisk\\ReadBytesPerSecond" : {"plugin":"diskio_vmi", "field":"LogicalDisk\\\\\\\\ReadBytesPerSecond", "op":"rate", "module":"filesystem"}, #Need to calculate rate (but each second not each interval)
    "LogicalDisk\\BytesPerSecond" : {"plugin":"diskio_vmi", "field":"LogicalDisk\\\\\\\\BytesPerSecond", "op":"rate", "module":"filesystem"}, #Need to calculate rate and then sum
    "LogicalDisk\\WriteBytesPerSecond" : {"plugin":"diskio_vmi", "field":"LogicalDisk\\\\\\\\WriteBytesPerSecond", "op":"rate", "module":"filesystem"}, #Need to calculate rate (but each second not each interval)
    "LogicalDisk\\ReadsPerSecond" : {"plugin":"diskio_vmi", "field":"LogicalDisk\\\\\\\\ReadsPerSecond", "op":"rate", "module":"filesystem"}, #Need to calculate rate (but each second not each interval)
    "LogicalDisk\\WritesPerSecond" : {"plugin":"diskio_vmi", "field":"LogicalDisk\\\\\\\\WritesPerSecond", "op":"rate", "module":"filesystem"}, #Need to calculate rate (but each second not each interval)

    # Process plugin
    "Pct User Time" : {"plugin":"procstat", "field":"cpu_time_user", "module":"process"},
    "Pct Privileged Time" : {"plugin":"procstat", "field":"cpu_time_system", "module":"process"},
    "Used Memory" : {"plugin":"procstat", "field":"memory_rss", "module":"process"},
    "Virtual Shared Memory" : {"plugin":"procstat", "field":"memory_vms", "module":"process"},

    # System plugin
    "Uptime" : {"plugin":"system", "field":"uptime", "module":"system"},
    "Load1" : {"plugin":"system", "field":"load1", "module":"system"},
    "Load5" : {"plugin":"system", "field":"load5", "module":"system"},
    "Load15" : {"plugin":"system", "field":"load15", "module":"system"},
    "Users" : {"plugin":"system", "field":"n_users", "module":"system"},
    "CPUs" : {"plugin":"system", "field":"n_cpus", "module":"system"},
    "Unique Users" : {"plugin":"system", "field":"n_unique_users", "module":"system"},

    # #OMI Disk plugin
    "disk->disk read guest os" : {"plugin":"diskio", "field":"read_bytes", "op":"rate", "ladtablekey":"/builtin/disk/readbytespersecond"},
    "disk->disk write guest os" : {"plugin":"diskio", "field":"write_bytes", "op":"rate", "ladtablekey":"/builtin/disk/writebytespersecond"},
    "disk->disk total bytes" : {"plugin":"diskio", "field":"total_bytes", "op":"rate", "ladtablekey":"/builtin/disk/bytespersecond"},
    "disk->disk reads" : {"plugin":"diskio", "field":"reads", "op":"rate", "ladtablekey":"/builtin/disk/readspersecond"}, #Need to calculate rate (but each second not each interval)
    "disk->disk writes" : {"plugin":"diskio", "field":"writes", "op":"rate", "ladtablekey":"/builtin/disk/writespersecond"},
    "disk->disk transfers" : {"plugin":"diskio", "field":"total_transfers", "op":"rate", "ladtablekey":"/builtin/disk/transferspersecond"},
    "disk->disk read time" : {"plugin":"diskio", "field":"read_time", "op":"rate", "ladtablekey":"/builtin/disk/averagereadtime"},
    "disk->disk write time" : {"plugin":"diskio", "field":"write_time", "op":"rate", "ladtablekey":"/builtin/disk/averagewritetime"},
    "disk->disk transfer time" : {"plugin":"diskio", "field":"io_time", "op":"rate", "ladtablekey":"/builtin/disk/averagetransfertime"},
    "disk->disk queue length" : {"plugin":"diskio", "field":"iops_in_progress", "ladtablekey":"/builtin/disk/averagediskqueuelength"}

    ##### These are the counter keys and telegraf plugins for Azure Monitor Agent
}
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. 
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright 2016 Microsoft Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: Makefile ================================================ default: clean init build EXTENSIONS = \ CustomScript \ DSC \ OSPatching \ VMBackup clean: rm -rf build init: @mkdir -p build build: init $(EXTENSIONS) buildVMAccess define make-extension-zip $(eval NAME = $(shell grep -Pom1 "(?<=)[^<]+" $@/manifest.xml)) $(eval VERSION = $(shell grep -Pom1 "(?<=)[^<]+" $@/manifest.xml)) @echo "Building '$(NAME)-$(VERSION).zip' ..." @cd $@ && find . -type f | grep -v "/test/" | grep -v "./references" | zip -9 -@ ../build/$(NAME)-$(VERSION).zip > /dev/null @find ./Utils -type f | grep -v "/test/" | zip -9 -@ build/$(NAME)-$(VERSION).zip > /dev/null endef $(EXTENSIONS): $(make-extension-zip) @cd Common/ && echo ./waagentloader.py | zip -9 -@ ../build/$(NAME)-$(VERSION).zip > /dev/null @cd Common/WALinuxAgent-2.0.16 && echo ./waagent | zip -9 -@ ../../build/$(NAME)-$(VERSION).zip > /dev/null buildVMAccess: $(eval NAME = $(shell grep -Pom1 "(?<=)[^<]+" VMAccess/manifest.xml)) $(eval VERSION = $(shell grep -Pom1 "(?<=)[^<]+" VMAccess/manifest.xml)) @echo "Building '$(NAME)-$(VERSION).zip' ..." @cd VMAccess && find . 
-type f | grep -v "/test/" | grep -v "./references" | zip -9 -@ ../build/$(NAME)-$(VERSION).zip > /dev/null @zip -9 build/$(NAME)-$(VERSION).zip ./Utils/__init__.py ./Utils/constants.py ./Utils/distroutils.py\ ./Utils/extensionutils.py ./Utils/handlerutil2.py ./Utils/logger.py ./Utils/ovfutils.py > /dev/null .PHONY: clean build $(EXTENSIONS) buildVMAccess ================================================ FILE: OSPatching/HandlerManifest.json ================================================ [ { "version": 1.0, "handlerManifest": { "disableCommand": "./handler.py -disable", "enableCommand": "./handler.py -enable", "installCommand": "./handler.py -install", "uninstallCommand": "./handler.py -uninstall", "updateCommand": "./handler.py -update", "rebootAfterInstall": false, "reportHeartbeat": false } } ] ================================================ FILE: OSPatching/README.md ================================================ # :warning: IMPORTANT :warning: **The OSPatching extension for Linux is deprecated.** OSPatchingForLinux is deprecated and will be retired February 2018. Your Linux distro has well supported and maintained ways to enable automatic updates for your VMs to include VMs you use in Production environments. It is recommended that you consult your distro's best practices for automatic updates. ## Linux Distributions - Ubuntu - See the [unattended-upgrades](https://help.ubuntu.com/lts/serverguide/automatic-updates.html) package documentation - CentOS and RHEL - See the manpage of `yum-cron` for the auto-update mechanism documentation # OSPatching Extension Allows the owner of the Azure VM to configure a Linux VM patching schedule cycle or perform OS patching on-demand as a one-time task. The actual patching operation is scheduled as a cron job. Lastest version is 2.3. 
You can read the User Guide, [Automate Linux VM OS Updates Using OSPatching Extension (outdated, needs to update)](http://azure.microsoft.com/blog/2014/10/23/automate-linux-vm-os-updates-using-ospatching-extension/). OSPatching Extension can: * Patch the OS automatically as a scheduled task * Patch the OS as a one-time task * The patching can be stopped before the actual patching operation begins * The status of VM can be checked by user-defined scripts stored locally, in GitHub, or in Azure Storage # User Guide ## 1. Configuration schema All settings are set in the protected configuration. No settings are available in the public configuration and it can be omitted. ### 1.1. Protected configuration Schema for the protected configuration file. | Name | Description | Value Type | Default Value | |:---|:---|:---|:---| | disabled | Flag to disable this extension | required, boolean | false | | stop | Flag to cancel the OS update process | required, boolean | false | | rebootAfterPatch | The reboot behavior after patching | optional, string | RebootIfNeed | | category | Type of patches to install | optional, string | Important | | installDuration | The allowed total time for installation | optional, string | 01:00 | | oneoff | Patch the OS immediately | optional, boolean | false | | intervalOfWeeks | The update frequency (in weeks) | optional, string | 1 | | dayOfWeek | The patching date (of the week)You can specify multiple days in a week | optional, string | Everyday | | startTime | Start time of patching | optional, string | 03:00 | | distUpgradeList | Path to a repo list which for which a full upgrade (e.g. dist-upgrade in Ubuntu) will occur | optional, string | /etc/apt/sources.list.d/custom.list | | distUpgradeAll | Flag to enable full upgrade (e.g. dist-upgrade in Ubuntu) for all repos/packages. 
Disabled (False) by default | optional, bool | True | | vmStatusTest | Including `local`, `idleTestScript` and `healthyTestScript` | optional, object | | | local | Flag to assign the location of user-defined scripts | optional, boolean | false | | idleTestScript | If `local` is true, it is the contents of the idle test script. Otherwise, it is the uri of the idle test script. | optional, string | | | healthyTestScript | If `local` is true, it is the contents of the healthy test script. Otherwise, it is the uri of the healthy test script. | optional, string | | | storageAccountName | The name of the storage account | optional, string | | | storageAccountKey | The access key of the storage account | optional, string | | If the vmStatusTest scripts are stored in the private Azure Storage, you must provide `storageAccountName` and `storageAccountKey`. You can get these two values from Azure Portal. ```json { "disabled": false, "stop": false, "rebootAfterPatch": "RebootIfNeed|Required|NotRequired|Auto", "category": "Important|ImportantAndRecommended", "installDuration": "", "oneoff": false, "intervalOfWeeks": "", "dayOfWeek": "Sunday|Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Everyday", "startTime": "", "distUpgradeList": "", "vmStatusTest": { "local": false, "idleTestScript": "", "healthyTestScript": "" }, "storageAccountName": "", "storageAccountKey": "" } ``` ## 2. Deploying the Extension to a VM You can deploy it using Azure CLI, Azure Powershell and ARM template. > **NOTE:** Creating VM in Azure has two deployment model: Classic and [Resource Manager][arm-overview]. In diffrent models, the deploying commands have different syntaxes. Please select the right one in section 2.1 and 2.2 below. ### 2.1. Using [**Azure CLI**][azure-cli] Before deploying OSPatching Extension, you should configure your `protected.json` (in section 1.1 above). #### 2.1.1 Classic The Classic mode is also called Azure Service Management mode. 
You can change to it by running: ``` $ azure config mode asm ``` You can deploying OSPatching Extension by running: ``` $ azure vm extension set \ OSPatchingForLinux Microsoft.OSTCExtensions \ --private-config-path protected.json ``` In the command above, you can change version with `"*"` to use latest version available, or `"2.*"` to get newest version that does not introduce non- breaking schema changes. To find the latest version available, run: ``` $ azure vm extension list ``` #### 2.1.2 Resource Manager You can change to Azure Resource Manager mode by running: ``` $ azure config mode arm ``` You can deploy OSPatching Extension by running: ``` $ azure vm extension set \ OSPatchingForLinux Microsoft.OSTCExtensions \ --private-config-path protected.json ``` > **NOTE:** In ARM mode, `azure vm extension list` is not available for now. ### 2.2. Using [**Azure Powershell**][azure-powershell] #### 2.2.1 Classic You can login to your Azure account (Azure Service Management mode) by running: ```powershell Add-AzureAccount ``` You can deploying OSPatching Extension by running: ```powershell $VmName = '' $vm = Get-AzureVM -ServiceName $VmName -Name $VmName $ExtensionName = 'OSPatchingForLinux' $Publisher = 'Microsoft.OSTCExtensions' $Version = '' $idleTestScriptUri = '' $healthyTestScriptUri = '' $PrivateConfig = ConvertTo-Json -InputObject @{ "disabled" = $false; "stop" = $true|$false; "rebootAfterPatch" = "RebootIfNeed|Required|NotRequired|Auto"; "category" = "Important|ImportantAndRecommended"; "installDuration" = ""; "oneoff" = $true|$false; "intervalOfWeeks" = ""; "dayOfWeek" = "Sunday|Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Everyday"; "startTime" = ""; "vmStatusTest" = (@{ "local" = $false; "idleTestScript" = $idleTestScriptUri; "healthyTestScript" = $healthyTestScriptUri }); "storageAccountName" = ""; "storageAccountKey" = "" } Set-AzureVMExtension -ExtensionName $ExtensionName -VM $vm ` -Publisher $Publisher -Version $Version ` -PrivateConfiguration 
$PrivateConfig | Update-AzureVM ``` #### 2.2.2 Resource Manager You can log in to your Azure account (Azure Resource Manager mode) by running: ```powershell Login-AzureRmAccount ``` Click [**HERE**](https://azure.microsoft.com/en-us/documentation/articles/powershell-azure-resource-manager/) to learn more about how to use Azure PowerShell with Azure Resource Manager. You can deploy the OSPatching Extension by running: ```powershell $RGName = '' $VmName = '' $Location = '' $ExtensionName = 'OSPatchingForLinux' $Publisher = 'Microsoft.OSTCExtensions' $Version = '' $PrivateConf = ConvertTo-Json -InputObject @{ "disabled" = $false; "stop" = $true|$false; "rebootAfterPatch" = "RebootIfNeed|Required|NotRequired|Auto"; "category" = "Important|ImportantAndRecommended"; "installDuration" = ""; "oneoff" = $true|$false; "intervalOfWeeks" = ""; "dayOfWeek" = "Sunday|Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Everyday"; "startTime" = ""; "vmStatusTest" = (@{ "local" = $false; "idleTestScript" = $idleTestScriptUri; "healthyTestScript" = $healthyTestScriptUri }); "storageAccountName" = ""; "storageAccountKey" = "" } Set-AzureRmVMExtension -ResourceGroupName $RGName -VMName $VmName -Location $Location ` -Name $ExtensionName -Publisher $Publisher -ExtensionType $ExtensionName ` -TypeHandlerVersion $Version -ProtectedSettingString $PrivateConf ``` ### 2.3.
Using [**ARM Template**][arm-template] ```json { "type": "Microsoft.Compute/virtualMachines/extensions", "name": "", "apiVersion": "", "location": "", "dependsOn": [ "[concat('Microsoft.Compute/virtualMachines/', )]" ], "properties": { "publisher": "Microsoft.OSTCExtensions", "type": "OSPatchingForLinux", "typeHandlerVersion": "2.0", "protectedSettings": { "disabled": false, "stop": false, "rebootAfterPatch": "RebootIfNeed|Required|NotRequired|Auto", "category": "Important|ImportantAndRecommended", "installDuration": "", "oneoff": false, "intervalOfWeeks": "", "dayOfWeek": "Sunday|Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Everyday", "startTime": "", "vmStatusTest": { "local": false, "idleTestScript": "", "healthyTestScript": "" }, "storageAccountName": "", "storageAccountKey": "" } } } ``` The sample ARM template is [201-ospatching-extension-on-ubuntu](https://github.com/Azure/azure-quickstart-templates/tree/master/201-ospatching-extension-on-ubuntu). For more details about ARM template, please visit [Authoring Azure Resource Manager templates](https://azure.microsoft.com/en-us/documentation/articles/resource-group-authoring-templates/). ## 3. Scenarios ### 3.1 Setting up regularly scheduled patching **Protected Settings** ```json { "disabled": false, "stop": false, "rebootAfterPatch": "RebootIfNeed", "intervalOfWeeks": "1", "dayOfWeek": "Sunday|Wednesday", "startTime": "03:00", "category": "ImportantAndRecommended", "installDuration": "00:30" } ``` ### 3.2 Setting up one-off patching **Protected Settings** ```json { "disabled": false, "stop": false, "rebootAfterPatch": "RebootIfNeed", "oneoff": true, "category": "ImportantAndRecommended", "installDuration": "00:30" } ``` ### 3.3 Stop the running patching You can stop the OS updates to debug issues. Once the `stop` parameter is set to `true`, the OS update will stop after the current update is finished. 
**Protected Settings** ```json { "disabled": false, "stop": true } ``` ### 3.4 Test the idle before patching and the health after patching If the `vmStatusTest` scripts are stored in Azure Storage private containers, you have to provide the `storageAccountName` and `storageAccountKey`. **Protected Settings** ```json { "disabled": false, "stop": false, "rebootAfterPatch": "RebootIfNeed", "category": "ImportantAndRecommended", "installDuration": "00:30", "oneoff": false, "intervalOfWeeks": "1", "dayOfWeek": "Sunday|Wednesday", "startTime": "03:00", "vmStatusTest": { "local": false, "idleTestScript": "", "healthyTestScript": "" }, "storageAccountName": "MyAccount", "storageAccountKey": "Mykey" } ``` ### 3.5 Enable the extension repeatedly Enabling the OSPatching Extension with the exact same configuration is unsupported and will result in a no-op (nothing will happen). If you need to run scripts repeatedly, you can add a timestamp. ```json "timestamp": 123456789 ``` ### 3.6 Disable the extension If you want to switch to manual OS update temporarily, you can set the `disable` parameter to `true` instead of uninstalling the OSPatching Extension. ## Debugging * The operation log of the extension is `/var/log/azure///extension.log` file. * The installation status of the extension is reported back to Azure so that the user can see the status on Azure Portal. This does not mean the OSPatching Extension successfully applied the current configuration to the VM. * Attempting to enable the OSPatching Extension 2 or more times with the same configuration will result in nothing happening. See [Enable the extension repeatedly](#3.5 Enable the extension repeatedly) section above for more details. # Known Issues * If the scheduled task does not run on some RedHat distros, there may be a selinux-policy problem. 
Please refer to [https://bugzilla.redhat.com/show\_bug.cgi?id=657104](https://bugzilla.redhat.com/show_bug.cgi?id=657104) [azure-powershell]: https://azure.microsoft.com/en-us/documentation/articles/powershell-install-configure/ [azure-cli]: https://azure.microsoft.com/en-us/documentation/articles/xplat-cli/ [arm-template]: http://azure.microsoft.com/en-us/documentation/templates/ [arm-overview]: https://azure.microsoft.com/en-us/documentation/articles/resource-group-overview/ ================================================ FILE: OSPatching/azure/__init__.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- import ast import base64 import hashlib import hmac import sys import types import warnings import inspect if sys.version_info < (3,): from urllib2 import quote as url_quote from urllib2 import unquote as url_unquote _strtype = basestring else: from urllib.parse import quote as url_quote from urllib.parse import unquote as url_unquote _strtype = str from datetime import datetime from xml.dom import minidom from xml.sax.saxutils import escape as xml_escape #-------------------------------------------------------------------------- # constants __author__ = 'Microsoft Corp. 
__version__ = '0.8.4'

# Live ServiceClient URLs
# Host-name suffixes; the account name is prepended at request time.
BLOB_SERVICE_HOST_BASE = '.blob.core.windows.net'
QUEUE_SERVICE_HOST_BASE = '.queue.core.windows.net'
TABLE_SERVICE_HOST_BASE = '.table.core.windows.net'
SERVICE_BUS_HOST_BASE = '.servicebus.windows.net'
MANAGEMENT_HOST = 'management.core.windows.net'

# Development ServiceClient URLs (local storage emulator endpoints)
DEV_BLOB_HOST = '127.0.0.1:10000'
DEV_QUEUE_HOST = '127.0.0.1:10001'
DEV_TABLE_HOST = '127.0.0.1:10002'

# Default credentials for Development Storage Service
# NOTE: this is the well-known, publicly documented emulator account key,
# not a secret.
DEV_ACCOUNT_NAME = 'devstoreaccount1'
DEV_ACCOUNT_KEY = 'Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=='

# All of our error messages
_ERROR_CANNOT_FIND_PARTITION_KEY = 'Cannot find partition key in request.'
_ERROR_CANNOT_FIND_ROW_KEY = 'Cannot find row key in request.'
_ERROR_INCORRECT_TABLE_IN_BATCH = \
    'Table should be the same in a batch operations'
_ERROR_INCORRECT_PARTITION_KEY_IN_BATCH = \
    'Partition Key should be the same in a batch operations'
_ERROR_DUPLICATE_ROW_KEY_IN_BATCH = \
    'Row Keys should not be the same in a batch operations'
_ERROR_BATCH_COMMIT_FAIL = 'Batch Commit Fail'
_ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_DELETE = \
    'Message is not peek locked and cannot be deleted.'
_ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_UNLOCK = \
    'Message is not peek locked and cannot be unlocked.'
_ERROR_QUEUE_NOT_FOUND = 'Queue was not found'
_ERROR_TOPIC_NOT_FOUND = 'Topic was not found'
_ERROR_CONFLICT = 'Conflict ({0})'
_ERROR_NOT_FOUND = 'Not found ({0})'
_ERROR_UNKNOWN = 'Unknown error ({0})'
_ERROR_SERVICEBUS_MISSING_INFO = \
    'You need to provide servicebus namespace, access key and Issuer'
_ERROR_STORAGE_MISSING_INFO = \
    'You need to provide both account name and access key'
_ERROR_ACCESS_POLICY = \
    'share_access_policy must be either SignedIdentifier or AccessPolicy ' + \
    'instance'
_WARNING_VALUE_SHOULD_BE_BYTES = \
    'Warning: {0} must be bytes data type. It will be converted ' + \
    'automatically, with utf-8 text encoding.'
_ERROR_VALUE_SHOULD_BE_BYTES = '{0} should be of type bytes.'
_ERROR_VALUE_NONE = '{0} should not be None.'
_ERROR_VALUE_NEGATIVE = '{0} should not be negative.'
_ERROR_CANNOT_SERIALIZE_VALUE_TO_ENTITY = \
    'Cannot serialize the specified value ({0}) to an entity. Please use ' + \
    'an EntityProperty (which can specify custom types), int, str, bool, ' + \
    'or datetime.'
_ERROR_PAGE_BLOB_SIZE_ALIGNMENT = \
    'Invalid page blob size: {0}. ' + \
    'The size must be aligned to a 512-byte boundary.'

_USER_AGENT_STRING = 'pyazure/' + __version__

# XML namespace for WCF Data Services metadata; used to read the 'etag'
# attribute off Atom entries in _get_entry_properties_from_node.
METADATA_NS = 'http://schemas.microsoft.com/ado/2007/08/dataservices/metadata'


class WindowsAzureData(object):

    '''Base class of the generated data classes. It carries no behavior and
    exists only so code can check isinstance(obj, WindowsAzureData).'''

    pass


class WindowsAzureError(Exception):

    '''Base class for all Windows Azure exceptions.'''

    def __init__(self, message):
        super(WindowsAzureError, self).__init__(message)


class WindowsAzureConflictError(WindowsAzureError):

    '''Indicates that the resource could not be created because it already
    exists.'''

    def __init__(self, message):
        super(WindowsAzureConflictError, self).__init__(message)


class WindowsAzureMissingResourceError(WindowsAzureError):

    '''Indicates that a request for a resource (queue, table, container,
    etc...) failed because the specified resource does not exist.'''

    def __init__(self, message):
        super(WindowsAzureMissingResourceError, self).__init__(message)


class WindowsAzureBatchOperationError(WindowsAzureError):

    '''Indicates that a batch operation failed.'''

    def __init__(self, message, code):
        super(WindowsAzureBatchOperationError, self).__init__(message)
        # Error code reported for the failed sub-operation.
        self.code = code


class Feed(object):
    '''Empty base type used as the element marker for deserialized feed
    collections (see _convert_response_to_feeds).'''
    pass


class _Base64String(str):
    '''str subclass used purely as a marker type.
    NOTE(review): presumably tags values that are base64-encoded on the
    wire — confirm against the serialization code elsewhere in this file.'''
    pass


class HeaderDict(dict):

    '''dict whose reads lower-case the key first, since HTTP header names
    are case-insensitive. Note only __getitem__ is normalized; keys are
    expected to be stored lower-cased.'''

    def __getitem__(self, index):
        return super(HeaderDict, self).__getitem__(index.lower())


def _encode_base64(data):
    '''Base64-encode data (text is UTF-8 encoded first); returns text.'''
    if isinstance(data, _unicode_type):
        data = data.encode('utf-8')
    encoded = base64.b64encode(data)
    return encoded.decode('utf-8')


def _decode_base64_to_bytes(data):
    '''Base64-decode data (text is UTF-8 encoded first); returns bytes.'''
    if isinstance(data, _unicode_type):
        data = data.encode('utf-8')
    return base64.b64decode(data)


def _decode_base64_to_text(data):
    '''Base64-decode data and decode the result as UTF-8 text.'''
    decoded_bytes = _decode_base64_to_bytes(data)
    return decoded_bytes.decode('utf-8')


def _get_readable_id(id_name, id_prefix_to_skip):
    """simplified an id to be more friendly for us people"""
    # id_name is in the form 'https://namespace.host.suffix/name'
    # where name may contain a forward slash!
    pos = id_name.find('//')
    if pos != -1:
        pos += 2
        if id_prefix_to_skip:
            pos = id_name.find(id_prefix_to_skip, pos)
            if pos != -1:
                pos += len(id_prefix_to_skip)
        pos = id_name.find('/', pos)
        if pos != -1:
            # Return everything after the first '/' past the host (and
            # optional prefix); otherwise fall through to the full id.
            return id_name[pos + 1:]
    return id_name


def _get_entry_properties_from_node(entry, include_id, id_prefix_to_skip=None,
                                    use_title_as_id=False):
    '''Extracts etag/updated/author (and optionally name) properties from a
    single Atom <entry> DOM node into a plain dict.'''
    properties = {}

    # etag lives in the WCF Data Services metadata namespace.
    etag = entry.getAttributeNS(METADATA_NS, 'etag')
    if etag:
        properties['etag'] = etag
    for updated in _get_child_nodes(entry, 'updated'):
        properties['updated'] = updated.firstChild.nodeValue
    for name in _get_children_from_path(entry, 'author', 'name'):
        if name.firstChild is not None:
            properties['author'] = name.firstChild.nodeValue
    if include_id:
        if use_title_as_id:
            for title in _get_child_nodes(entry, 'title'):
                properties['name'] = title.firstChild.nodeValue
        else:
            for id in _get_child_nodes(entry, 'id'):
                properties['name'] = _get_readable_id(
                    id.firstChild.nodeValue, id_prefix_to_skip)
    return properties


def _get_entry_properties(xmlstr, include_id, id_prefix_to_skip=None):
    '''Parses an Atom XML string and merges the properties of every
    top-level <entry> into one dict.'''
    xmldoc = minidom.parseString(xmlstr)
    properties = {}
    for entry in _get_child_nodes(xmldoc, 'entry'):
        properties.update(_get_entry_properties_from_node(entry,
                                                          include_id,
                                                          id_prefix_to_skip))
    return properties


def _get_first_child_node_value(parent_node, node_name):
    '''Returns the text of the first matching direct child, or None
    (implicitly) when the child or its text node is missing.'''
    xml_attrs = _get_child_nodes(parent_node, node_name)
    if xml_attrs:
        xml_attr = xml_attrs[0]
        if xml_attr.firstChild:
            value = xml_attr.firstChild.nodeValue
            return value


def _get_child_nodes(node, tagName):
    '''Like getElementsByTagName but restricted to DIRECT children.'''
    return [childNode for childNode in node.getElementsByTagName(tagName)
            if childNode.parentNode == node]


def _get_children_from_path(node, *path):
    '''Descends through a hierarchy of nodes returning the list of children
    at the innermost level. Only returns children who share a common
    parent, not cousins. Each path element is either a tag name (str) or a
    (namespace, tag) tuple.'''
    cur = node
    for index, child in enumerate(path):
        # NOTE(review): 'next' shadows the builtin of the same name.
        if isinstance(child, _strtype):
            next = _get_child_nodes(cur, child)
        else:
            next = _get_child_nodesNS(cur, *child)
        if index == len(path) - 1:
            return next
        elif not next:
            break
        cur = next[0]
    return []


def _get_child_nodesNS(node, ns, tagName):
    '''Namespace-aware variant of _get_child_nodes (direct children only).'''
    return [childNode for childNode in node.getElementsByTagNameNS(ns, tagName)
            if childNode.parentNode == node]


def _create_entry(entry_body):
    '''Adds common part of entry to a given entry body and returns the
    whole xml.'''
    updated_str = datetime.utcnow().isoformat()
    if datetime.utcnow().utcoffset() is None:
        # utcnow() is naive, so mark the timestamp explicitly as UTC.
        updated_str += '+00:00'

    # NOTE(review): this template appears to lack the opening <entry ...>
    # element in this copy of the source — verify against the upstream
    # azure-sdk-for-python 0.8.x source before relying on it.
    entry_start = ''' <updated>{updated}</updated><author><name /></author><id /> <content type="application/xml"> {body}</content></entry>'''
    return entry_start.format(updated=updated_str, body=entry_body)


def _to_datetime(strtime):
    '''Parses an ISO-8601 timestamp with fractional seconds.'''
    return datetime.strptime(strtime, "%Y-%m-%dT%H:%M:%S.%f")

# Python attribute name -> wire element name mappings that the generic
# _get_serialization_name rules below cannot derive.
_KNOWN_SERIALIZATION_XFORMS = {
    'include_apis': 'IncludeAPIs',
    'message_id': 'MessageId',
    'content_md5': 'Content-MD5',
    'last_modified': 'Last-Modified',
    'cache_control': 'Cache-Control',
    'account_admin_live_email_id': 'AccountAdminLiveEmailId',
    'service_admin_live_email_id': 'ServiceAdminLiveEmailId',
    'subscription_id': 'SubscriptionID',
    'fqdn': 'FQDN',
    'private_id': 'PrivateID',
    'os_virtual_hard_disk': 'OSVirtualHardDisk',
    'logical_disk_size_in_gb': 'LogicalDiskSizeInGB',
    'logical_size_in_gb': 'LogicalSizeInGB',
    'os': 'OS',
    'persistent_vm_downtime_info': 'PersistentVMDowntimeInfo',
    'copy_id': 'CopyId',
}


def _get_serialization_name(element_name):
    """converts a Python name into a serializable name"""
    # Exact-match table first.
    known = _KNOWN_SERIALIZATION_XFORMS.get(element_name)
    if known is not None:
        return known

    # x_ms_* headers become x-ms-* verbatim.
    if element_name.startswith('x_ms_'):
        return element_name.replace('_', '-')
    if element_name.endswith('_id'):
        element_name = element_name.replace('_id', 'ID')
    for name in ['content_', 'last_modified', 'if_', 'cache_control']:
        if element_name.startswith(name):
            # '-_' keeps the dash while letting the final split('_')
            # re-capitalize each piece (e.g. content_type -> Content-Type).
            element_name = element_name.replace('_', '-_')

    return ''.join(name.capitalize() for name in element_name.split('_'))

# Python 2/3 compatibility shims: _str normalizes to the platform's native
# str type; _unicode_type is the text type.
if sys.version_info < (3,):
    _unicode_type = unicode

    def _str(value):
        if isinstance(value, unicode):
            return value.encode('utf-8')
        return str(value)
else:
    _str = str
    _unicode_type = str


def _str_or_none(value):
    '''None passes through; anything else is converted via _str.'''
    if value is None:
        return None

    return _str(value)


def _int_or_none(value):
    '''None passes through; anything else becomes the string of its int.'''
    if value is None:
        return None

    return str(int(value))


def _bool_or_none(value):
    '''None passes through; bools become 'true'/'false'; other values are
    stringified as-is.'''
    if value is None:
        return None

    if isinstance(value, bool):
        if value:
            return 'true'
        else:
            return 'false'

    return str(value)


def _convert_class_to_xml(source, xml_prefix=True):
    '''Recursively serializes a WindowsAzureData instance (or a list of
    them) to an XML string. Element names come from
    _get_serialization_name; None attributes are omitted.'''
    if source is None:
        return ''

    xmlstr = ''
    if xml_prefix:
        xmlstr = '<?xml version="1.0" encoding="utf-8"?>'

    if isinstance(source, list):
        for value in source:
            xmlstr += _convert_class_to_xml(value, False)
    elif isinstance(source, WindowsAzureData):
        class_name = source.__class__.__name__
        xmlstr += '<' + class_name + '>'
        for name, value in vars(source).items():
            if value is not None:
                if isinstance(value, list) or \
                    isinstance(value, WindowsAzureData):
                    xmlstr += _convert_class_to_xml(value, False)
                else:
                    xmlstr += ('<' + _get_serialization_name(name) + '>' +
                               xml_escape(str(value)) + '</' +
                               _get_serialization_name(name) + '>')
        xmlstr += '</' + class_name + '>'
    return xmlstr


def _find_namespaces_from_child(parent, child, namespaces):
    """Recursively searches from the parent to the child, gathering all the
    applicable namespaces along the way"""
    for cur_child in parent.childNodes:
        if cur_child is child:
            return True
        if _find_namespaces_from_child(cur_child, child, namespaces):
            # we are the parent node
            for key in cur_child.attributes.keys():
                if key.startswith('xmlns:') or key == 'xmlns':
                    namespaces[key] = cur_child.attributes[key]
            break
    return False


def _find_namespaces(parent, child):
    '''Collects the xmlns declarations in effect for child: the document
    element's plus those gathered on the path down to child.'''
    res = {}
    for key in parent.documentElement.attributes.keys():
        if key.startswith('xmlns:') or key == 'xmlns':
            res[key] = parent.documentElement.attributes[key]
    _find_namespaces_from_child(parent, child, res)
    return res


def _clone_node_with_namespaces(node_to_clone, original_doc):
    '''Deep-clones a node and re-attaches every xmlns declaration that was
    in scope in the original document, so the clone serializes standalone.'''
    clone = node_to_clone.cloneNode(True)

    for key, value in _find_namespaces(original_doc, node_to_clone).items():
        clone.attributes[key] = value

    return clone


def _convert_response_to_feeds(response, convert_callback):
    '''Converts an Atom feed HTTP response into a list of deserialized
    entries. convert_callback is either a WindowsAzureData subclass (entries
    are filled into instances) or a callable taking the entry's XML text.
    _list_of and _fill_data_to_return_object are defined elsewhere in this
    module.'''
    if response is None:
        return None

    feeds = _list_of(Feed)

    # Carry any x-ms-continuation-* paging headers over onto the result.
    x_ms_continuation = HeaderDict()
    for name, value in response.headers:
        if 'x-ms-continuation' in name:
            x_ms_continuation[name[len('x-ms-continuation') + 1:]] = value
    if x_ms_continuation:
        setattr(feeds, 'x_ms_continuation', x_ms_continuation)

    xmldoc = minidom.parseString(response.body)
    xml_entries = _get_children_from_path(xmldoc, 'feed', 'entry')
    if not xml_entries:
        # in some cases, response contains only entry but no feed
        xml_entries = _get_children_from_path(xmldoc, 'entry')
    if inspect.isclass(convert_callback) and issubclass(convert_callback,
                                                        WindowsAzureData):
        # Fill instances of the data class straight from each entry.
        for xml_entry in xml_entries:
            return_obj = convert_callback()
            for node in _get_children_from_path(xml_entry,
                                                'content',
                                                convert_callback.__name__):
                _fill_data_to_return_object(node, return_obj)
            for name, value in _get_entry_properties_from_node(
                    xml_entry,
                    include_id=True,
                    use_title_as_id=True).items():
                setattr(return_obj, name, value)
            feeds.append(return_obj)
    else:
        # Otherwise pass each entry's serialized XML to the callback.
        for xml_entry in xml_entries:
            new_node = _clone_node_with_namespaces(xml_entry, xmldoc)
            feeds.append(convert_callback(new_node.toxml('utf-8')))

    return feeds


def _validate_type_bytes(param_name, param):
    '''Raises TypeError unless param is bytes.'''
    if not isinstance(param, bytes):
        raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES.format(param_name))


def _validate_not_none(param_name, param):
    '''Raises TypeError when param is None.'''
    if param is None:
        raise TypeError(_ERROR_VALUE_NONE.format(param_name))


def _fill_list_of(xmldoc, element_type, xml_element_name):
    '''Deserializes every direct child named xml_element_name into an
    element_type instance. _parse_response_body_from_xml_node is defined
    elsewhere in this module.'''
    xmlelements = _get_child_nodes(xmldoc, xml_element_name)
    return [_parse_response_body_from_xml_node(xmlelement, element_type)
            for xmlelement in xmlelements]


def _fill_scalar_list_of(xmldoc, element_type, parent_xml_element_name,
                         xml_element_name):
    '''Converts an xml fragment into a list of scalar types. The parent xml
    element contains a flat list of xml elements which are converted into
    the specified scalar type and added to the list.
    Example:
    xmldoc=
<Endpoints>
    <Endpoint>http://{storage-service-name}.blob.core.windows.net/</Endpoint>
    <Endpoint>http://{storage-service-name}.queue.core.windows.net/</Endpoint>
    <Endpoint>http://{storage-service-name}.table.core.windows.net/</Endpoint>
</Endpoints>
    element_type=str
    parent_xml_element_name='Endpoints'
    xml_element_name='Endpoint'

    Returns None (implicitly) when the parent element is absent.
    '''
    xmlelements = _get_child_nodes(xmldoc, parent_xml_element_name)
    if xmlelements:
        xmlelements = _get_child_nodes(xmlelements[0], xml_element_name)
        return [_get_node_value(xmlelement, element_type)
                for xmlelement in xmlelements]


def _fill_dict(xmldoc, element_name):
    '''Converts the children of the named element into a dict of
    node name -> text value. Returns None (implicitly) when the element is
    absent; children without text are skipped.'''
    xmlelements = _get_child_nodes(xmldoc, element_name)
    if xmlelements:
        return_obj = {}
        for child in xmlelements[0].childNodes:
            if child.firstChild:
                return_obj[child.nodeName] = child.firstChild.nodeValue
        return return_obj


def _fill_dict_of(xmldoc, parent_xml_element_name, pair_xml_element_name,
                  key_xml_element_name, value_xml_element_name):
    '''Converts an xml fragment into a dictionary. The parent xml element
    contains a list of xml elements where each element has a child element
    for the key, and another for the value.
    Example:
    xmldoc=
<ExtendedProperties>
    <ExtendedProperty>
        <Name>Ext1</Name>
        <Value>Val1</Value>
    </ExtendedProperty>
    <ExtendedProperty>
        <Name>Ext2</Name>
        <Value>Val2</Value>
    </ExtendedProperty>
</ExtendedProperties>
    element_type=str
    parent_xml_element_name='ExtendedProperties'
    pair_xml_element_name='ExtendedProperty'
    key_xml_element_name='Name'
    value_xml_element_name='Value'
    '''
    return_obj = {}

    xmlelements = _get_child_nodes(xmldoc, parent_xml_element_name)
    if xmlelements:
        xmlelements = _get_child_nodes(xmlelements[0], pair_xml_element_name)
        for pair in xmlelements:
            keys = _get_child_nodes(pair, key_xml_element_name)
            values = _get_child_nodes(pair, value_xml_element_name)
            if keys and values:
                key = keys[0].firstChild.nodeValue
                value = values[0].firstChild.nodeValue
                return_obj[key] = value

    return return_obj


def _fill_instance_child(xmldoc, element_name, return_type):
    '''Converts a child of the current dom element to the specified type.
    Returns None when the child is absent. _fill_data_to_return_object is
    defined elsewhere in this module.'''
    xmlelements = _get_child_nodes(
        xmldoc, _get_serialization_name(element_name))

    if not xmlelements:
        return None

    return_obj = return_type()
    _fill_data_to_return_object(xmlelements[0], return_obj)

    return return_obj


def _fill_instance_element(element, return_type):
    """Converts a DOM element into the specified object"""
    return _parse_response_body_from_xml_node(element, return_type)


def _fill_data_minidom(xmldoc, element_name, data_member):
    '''Reads the named child element's text and coerces it to the type of
    data_member (None -> raw text; datetime -> parsed; bool -> != 'false';
    otherwise the member's own type constructor).'''
    xmlelements = _get_child_nodes(
        xmldoc, _get_serialization_name(element_name))

    if not xmlelements or not xmlelements[0].childNodes:
        return None

    value = xmlelements[0].firstChild.nodeValue

    if data_member is None:
        return value
    elif isinstance(data_member, datetime):
        return _to_datetime(value)
    elif type(data_member) is bool:
        return value.lower() != 'false'
    else:
        return type(data_member)(value)


def _get_node_value(xmlelement, data_type):
    '''Coerces a node's text to data_type (datetime parsed, bool compared
    against 'false', otherwise the type constructor).'''
    value = xmlelement.firstChild.nodeValue
    if data_type is datetime:
        return _to_datetime(value)
    elif data_type is bool:
        return value.lower() != 'false'
    else:
        # NOTE: this definition continues beyond the end of this excerpt.
        return
data_type(value)  # (completes the `return` begun on the previous line)


def _get_request_body_bytes_only(param_name, param_value):
    '''Validates the request body passed in and converts it to bytes
    if our policy allows it.'''
    if param_value is None:
        return b''

    if isinstance(param_value, bytes):
        return param_value

    # Previous versions of the SDK allowed data types other than bytes to be
    # passed in, and they would be auto-converted to bytes.  We preserve this
    # behavior when running under 2.7, but issue a warning.
    # Python 3 support is new, so we reject anything that's not bytes.
    if sys.version_info < (3,):
        warnings.warn(_WARNING_VALUE_SHOULD_BE_BYTES.format(param_name))
        return _get_request_body(param_value)

    raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES.format(param_name))


def _get_request_body(request_body):
    '''Converts an object into a request body.  If it's None
    we'll return an empty string, if it's one of our objects it'll
    convert it to XML and return it.  Otherwise we just use the object
    directly.'''
    if request_body is None:
        return b''

    if isinstance(request_body, WindowsAzureData):
        request_body = _convert_class_to_xml(request_body)

    if isinstance(request_body, bytes):
        return request_body

    if isinstance(request_body, _unicode_type):
        return request_body.encode('utf-8')

    # Fall back to str(); under Python 3 (_unicode_type is str) the check
    # below then encodes the result to utf-8 bytes.
    request_body = str(request_body)
    if isinstance(request_body, _unicode_type):
        return request_body.encode('utf-8')

    return request_body


def _parse_enum_results_list(response, return_type, resp_type, item_type):
    """resp_body is the XML we received
    resp_type is a string, such as Containers,
    return_type is the type we're constructing, such as ContainerEnumResults
    item_type is the type object of the item to be created, such as Container

    This function then returns a ContainerEnumResults object with the
    containers member populated with the results.
    """

    # parsing something like:
    # <EnumerationResults ...
> # <Queues> # <Queue> # <Something /> # <SomethingElse /> # </Queue> # </Queues> # </EnumerationResults> respbody = response.body return_obj = return_type() doc = minidom.parseString(respbody) items = [] for enum_results in _get_child_nodes(doc, 'EnumerationResults'): # path is something like Queues, Queue for child in _get_children_from_path(enum_results, resp_type, resp_type[:-1]): items.append(_fill_instance_element(child, item_type)) for name, value in vars(return_obj).items(): # queues, Queues, this is the list its self which we populated # above if name == resp_type.lower(): # the list its self. continue value = _fill_data_minidom(enum_results, name, value) if value is not None: setattr(return_obj, name, value) setattr(return_obj, resp_type.lower(), items) return return_obj def _parse_simple_list(response, type, item_type, list_name): respbody = response.body res = type() res_items = [] doc = minidom.parseString(respbody) type_name = type.__name__ item_name = item_type.__name__ for item in _get_children_from_path(doc, type_name, item_name): res_items.append(_fill_instance_element(item, item_type)) setattr(res, list_name, res_items) return res def _parse_response(response, return_type): ''' Parse the HTTPResponse's body and fill all the data into a class of return_type. ''' return _parse_response_body_from_xml_text(response.body, return_type) def _parse_service_resources_response(response, return_type): ''' Parse the HTTPResponse's body and fill all the data into a class of return_type. 
''' return _parse_response_body_from_service_resources_xml_text(response.body, return_type) def _fill_data_to_return_object(node, return_obj): members = dict(vars(return_obj)) for name, value in members.items(): if isinstance(value, _list_of): setattr(return_obj, name, _fill_list_of(node, value.list_type, value.xml_element_name)) elif isinstance(value, _scalar_list_of): setattr(return_obj, name, _fill_scalar_list_of(node, value.list_type, _get_serialization_name(name), value.xml_element_name)) elif isinstance(value, _dict_of): setattr(return_obj, name, _fill_dict_of(node, _get_serialization_name(name), value.pair_xml_element_name, value.key_xml_element_name, value.value_xml_element_name)) elif isinstance(value, _xml_attribute): real_value = None if node.hasAttribute(value.xml_element_name): real_value = node.getAttribute(value.xml_element_name) if real_value is not None: setattr(return_obj, name, real_value) elif isinstance(value, WindowsAzureData): setattr(return_obj, name, _fill_instance_child(node, name, value.__class__)) elif isinstance(value, dict): setattr(return_obj, name, _fill_dict(node, _get_serialization_name(name))) elif isinstance(value, _Base64String): value = _fill_data_minidom(node, name, '') if value is not None: value = _decode_base64_to_text(value) # always set the attribute, so we don't end up returning an object # with type _Base64String setattr(return_obj, name, value) else: value = _fill_data_minidom(node, name, value) if value is not None: setattr(return_obj, name, value) def _parse_response_body_from_xml_node(node, return_type): ''' parse the xml and fill all the data into a class of return_type ''' return_obj = return_type() _fill_data_to_return_object(node, return_obj) return return_obj def _parse_response_body_from_xml_text(respbody, return_type): ''' parse the xml and fill all the data into a class of return_type ''' doc = minidom.parseString(respbody) return_obj = return_type() xml_name = return_type._xml_name if hasattr(return_type, 
'_xml_name') else return_type.__name__  # (completes the conditional begun on the previous line)
    for node in _get_child_nodes(doc, xml_name):
        _fill_data_to_return_object(node, return_obj)

    return return_obj


def _parse_response_body_from_service_resources_xml_text(respbody, return_type):
    '''
    parse the xml and fill all the data into a class of return_type
    '''
    doc = minidom.parseString(respbody)
    return_obj = _list_of(return_type)
    # Each <ServiceResource> under <ServiceResources> becomes one instance.
    for node in _get_children_from_path(doc, "ServiceResources", "ServiceResource"):
        local_obj = return_type()
        _fill_data_to_return_object(node, local_obj)
        return_obj.append(local_obj)

    return return_obj


class _dict_of(dict):

    """a dict which carries with it the xml element names for key,val.
    Used for deserialization and construction of the lists"""

    def __init__(self, pair_xml_element_name, key_xml_element_name,
                 value_xml_element_name):
        # XML element names for each (key, value) pair and its two children.
        self.pair_xml_element_name = pair_xml_element_name
        self.key_xml_element_name = key_xml_element_name
        self.value_xml_element_name = value_xml_element_name
        super(_dict_of, self).__init__()


class _list_of(list):

    """a list which carries with it the type that's expected to go in it.
    Used for deserialization and construction of the lists"""

    def __init__(self, list_type, xml_element_name=None):
        self.list_type = list_type
        if xml_element_name is None:
            # Default to the element type's class name when no explicit
            # XML element name is given.
            self.xml_element_name = list_type.__name__
        else:
            self.xml_element_name = xml_element_name
        super(_list_of, self).__init__()


class _scalar_list_of(list):

    """a list of scalar types which carries with it the type that's
    expected to go in it along with its xml element name.
    Used for deserialization and construction of the lists"""

    def __init__(self, list_type, xml_element_name):
        self.list_type = list_type
        self.xml_element_name = xml_element_name
        super(_scalar_list_of, self).__init__()


class _xml_attribute:

    """a accessor to XML attributes
    expected to go in it along with its xml element name.
Used for deserialization and construction"""

    def __init__(self, xml_element_name):
        # Name of the XML attribute this descriptor maps to.
        self.xml_element_name = xml_element_name


def _update_request_uri_query_local_storage(request, use_local_storage):
    ''' create correct uri and query for the request '''
    uri, query = _update_request_uri_query(request)
    if use_local_storage:
        # The storage emulator expects the well-known development account
        # name as the first path segment.
        return '/' + DEV_ACCOUNT_NAME + uri, query
    return uri, query


def _update_request_uri_query(request):
    '''pulls the query string out of the URI and moves it into
    the query portion of the request object.  If there are already
    query parameters on the request the parameters in the URI will
    appear after the existing parameters'''

    if '?' in request.path:
        request.path, _, query_string = request.path.partition('?')
        if query_string:
            query_params = query_string.split('&')
            for query in query_params:
                if '=' in query:
                    name, _, value = query.partition('=')
                    request.query.append((name, value))

    request.path = url_quote(request.path, '/()$=\',')

    # add encoded queries to request.path.
    if request.query:
        request.path += '?'
        for name, value in request.query:
            if value is not None:
                request.path += name + '=' + url_quote(value, '/()$=\',') + '&'
        # strip the trailing '&' appended by the loop above
        request.path = request.path[:-1]

    return request.path, request.query


def _dont_fail_on_exist(error):
    ''' don't throw exception if the resource exists.
    This is called by create_* APIs with fail_on_exist=False'''
    if isinstance(error, WindowsAzureConflictError):
        return False
    else:
        raise error


def _dont_fail_not_exist(error):
    ''' don't throw exception if the resource doesn't exist.
    This is called by delete_* APIs with fail_not_exist=False'''
    if isinstance(error, WindowsAzureMissingResourceError):
        return False
    else:
        raise error


def _general_error_handler(http_error):
    ''' Simple error handler for azure.

    Maps HTTP 409 to WindowsAzureConflictError, 404 to
    WindowsAzureMissingResourceError, and everything else to
    WindowsAzureError (appending the response body when one is present).
    '''
    if http_error.status == 409:
        raise WindowsAzureConflictError(
            _ERROR_CONFLICT.format(str(http_error)))
    elif http_error.status == 404:
        raise WindowsAzureMissingResourceError(
            _ERROR_NOT_FOUND.format(str(http_error)))
    else:
        if http_error.respbody is not None:
            raise WindowsAzureError(
                _ERROR_UNKNOWN.format(str(http_error)) + '\n' + \
                http_error.respbody.decode('utf-8'))
        else:
            raise WindowsAzureError(_ERROR_UNKNOWN.format(str(http_error)))


def _parse_response_for_dict(response):
    ''' Extracts name-values from response header. Filter out the standard
    http headers.'''

    if response is None:
        return None
    # Standard transport headers that should not surface to the caller.
    http_headers = ['server', 'date', 'location', 'host',
                    'via', 'proxy-connection', 'connection']
    return_dict = HeaderDict()
    if response.headers:
        for name, value in response.headers:
            if not name.lower() in http_headers:
                return_dict[name] = value

    return return_dict


def _parse_response_for_dict_prefix(response, prefixes):
    ''' Extracts name-values for names starting with prefix from response
    header. Filter out the standard http headers.'''

    if response is None:
        return None
    return_dict = {}
    orig_dict = _parse_response_for_dict(response)
    if orig_dict:
        for name, value in orig_dict.items():
            for prefix_value in prefixes:
                if name.lower().startswith(prefix_value.lower()):
                    return_dict[name] = value
                    break
        return return_dict
    else:
        return None


def _parse_response_for_dict_filter(response, filter):
    ''' Extracts name-values for names in filter from response header.
Filter out the standard http headers.''' if response is None: return None return_dict = {} orig_dict = _parse_response_for_dict(response) if orig_dict: for name, value in orig_dict.items(): if name.lower() in filter: return_dict[name] = value return return_dict else: return None def _sign_string(key, string_to_sign, key_is_base64=True): if key_is_base64: key = _decode_base64_to_bytes(key) else: if isinstance(key, _unicode_type): key = key.encode('utf-8') if isinstance(string_to_sign, _unicode_type): string_to_sign = string_to_sign.encode('utf-8') signed_hmac_sha256 = hmac.HMAC(key, string_to_sign, hashlib.sha256) digest = signed_hmac_sha256.digest() encoded_digest = _encode_base64(digest) return encoded_digest ================================================ FILE: OSPatching/azure/azure.pyproj ================================================ <?xml version="1.0" encoding="utf-8"?> <Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003" ToolsVersion="4.0"> <PropertyGroup> <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration> <SchemaVersion>2.0</SchemaVersion> <ProjectGuid>{25b2c65a-0553-4452-8907-8b5b17544e68}</ProjectGuid> <ProjectHome> </ProjectHome> <StartupFile>storage\blobservice.py</StartupFile> <SearchPath>..</SearchPath> <WorkingDirectory>.</WorkingDirectory> <OutputPath>.</OutputPath> <Name>azure</Name> <RootNamespace>azure</RootNamespace> <IsWindowsApplication>False</IsWindowsApplication> <LaunchProvider>Standard Python launcher</LaunchProvider> <CommandLineArguments /> <InterpreterPath /> <InterpreterArguments /> <InterpreterId>{9a7a9026-48c1-4688-9d5d-e5699d47d074}</InterpreterId> <InterpreterVersion>3.4</InterpreterVersion> <SccProjectName>SAK</SccProjectName> <SccProvider>SAK</SccProvider> <SccAuxPath>SAK</SccAuxPath> <SccLocalPath>SAK</SccLocalPath> </PropertyGroup> <PropertyGroup Condition=" '$(Configuration)' == 'Debug' "> <DebugSymbols>true</DebugSymbols> 
<EnableUnmanagedDebugging>false</EnableUnmanagedDebugging> </PropertyGroup> <PropertyGroup Condition=" '$(Configuration)' == 'Release' "> <DebugSymbols>true</DebugSymbols> <EnableUnmanagedDebugging>false</EnableUnmanagedDebugging> </PropertyGroup> <ItemGroup> <Compile Include="http\batchclient.py" /> <Compile Include="http\httpclient.py" /> <Compile Include="http\winhttp.py" /> <Compile Include="http\__init__.py" /> <Compile Include="servicemanagement\servicebusmanagementservice.py" /> <Compile Include="servicemanagement\servicemanagementclient.py" /> <Compile Include="servicemanagement\servicemanagementservice.py" /> <Compile Include="servicemanagement\sqldatabasemanagementservice.py" /> <Compile Include="servicemanagement\websitemanagementservice.py" /> <Compile Include="servicemanagement\__init__.py" /> <Compile Include="servicebus\servicebusservice.py" /> <Compile Include="storage\blobservice.py" /> <Compile Include="storage\queueservice.py" /> <Compile Include="storage\cloudstorageaccount.py" /> <Compile Include="storage\tableservice.py" /> <Compile Include="storage\sharedaccesssignature.py" /> <Compile Include="__init__.py" /> <Compile Include="servicebus\__init__.py" /> <Compile Include="storage\storageclient.py" /> <Compile Include="storage\__init__.py" /> </ItemGroup> <ItemGroup> <Folder Include="http" /> <Folder Include="servicemanagement" /> <Folder Include="servicebus\" /> <Folder Include="storage" /> </ItemGroup> <ItemGroup> <InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\2.6" /> <InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\2.7" /> <InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\3.3" /> <InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\3.4" /> <InterpreterReference Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\2.7" /> <InterpreterReference Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\3.3" /> <InterpreterReference 
Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\3.4" /> </ItemGroup> <PropertyGroup> <VisualStudioVersion Condition="'$(VisualStudioVersion)' == ''">10.0</VisualStudioVersion> <VSToolsPath Condition="'$(VSToolsPath)' == ''">$(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)</VSToolsPath> <PtvsTargetsFile>$(VSToolsPath)\Python Tools\Microsoft.PythonTools.targets</PtvsTargetsFile> </PropertyGroup> <Import Condition="Exists($(PtvsTargetsFile))" Project="$(PtvsTargetsFile)" /> <Import Condition="!Exists($(PtvsTargetsFile))" Project="$(MSBuildToolsPath)\Microsoft.Common.targets" /> </Project> ================================================ FILE: OSPatching/azure/http/__init__.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- HTTP_RESPONSE_NO_CONTENT = 204 class HTTPError(Exception): ''' HTTP Exception when response status code >= 300 ''' def __init__(self, status, message, respheader, respbody): '''Creates a new HTTPError with the specified status, message, response headers and body''' self.status = status self.respheader = respheader self.respbody = respbody Exception.__init__(self, message) class HTTPResponse(object): """Represents a response from an HTTP request. 
An HTTPResponse has the following attributes: status: the status code of the response message: the message headers: the returned headers, as a list of (name, value) pairs body: the body of the response """ def __init__(self, status, message, headers, body): self.status = status self.message = message self.headers = headers self.body = body class HTTPRequest(object): '''Represents an HTTP Request. An HTTP Request consists of the following attributes: host: the host name to connect to method: the method to use to connect (string such as GET, POST, PUT, etc.) path: the uri fragment query: query parameters specified as a list of (name, value) pairs headers: header values specified as (name, value) pairs body: the body of the request. protocol_override: specify to use this protocol instead of the global one stored in _HTTPClient. ''' def __init__(self): self.host = '' self.method = '' self.path = '' self.query = [] # list of (name, value) self.headers = [] # list of (header name, header value) self.body = '' self.protocol_override = None ================================================ FILE: OSPatching/azure/http/batchclient.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#-------------------------------------------------------------------------- import sys import uuid from azure import ( _update_request_uri_query, WindowsAzureError, WindowsAzureBatchOperationError, _get_children_from_path, url_unquote, _ERROR_CANNOT_FIND_PARTITION_KEY, _ERROR_CANNOT_FIND_ROW_KEY, _ERROR_INCORRECT_TABLE_IN_BATCH, _ERROR_INCORRECT_PARTITION_KEY_IN_BATCH, _ERROR_DUPLICATE_ROW_KEY_IN_BATCH, _ERROR_BATCH_COMMIT_FAIL, ) from azure.http import HTTPError, HTTPRequest, HTTPResponse from azure.http.httpclient import _HTTPClient from azure.storage import ( _update_storage_table_header, METADATA_NS, _sign_storage_table_request, ) from xml.dom import minidom _DATASERVICES_NS = 'http://schemas.microsoft.com/ado/2007/08/dataservices' if sys.version_info < (3,): def _new_boundary(): return str(uuid.uuid1()) else: def _new_boundary(): return str(uuid.uuid1()).encode('utf-8') class _BatchClient(_HTTPClient): ''' This is the class that is used for batch operation for storage table service. It only supports one changeset. ''' def __init__(self, service_instance, account_key, account_name, protocol='http'): _HTTPClient.__init__(self, service_instance, account_name=account_name, account_key=account_key, protocol=protocol) self.is_batch = False self.batch_requests = [] self.batch_table = '' self.batch_partition_key = '' self.batch_row_keys = [] def get_request_table(self, request): ''' Extracts table name from request.uri. The request.uri has either "/mytable(...)" or "/mytable" format. request: the request to insert, update or delete entity ''' if '(' in request.path: pos = request.path.find('(') return request.path[1:pos] else: return request.path[1:] def get_request_partition_key(self, request): ''' Extracts PartitionKey from request.body if it is a POST request or from request.path if it is not a POST request. Only insert operation request is a POST request and the PartitionKey is in the request body. 
request: the request to insert, update or delete entity ''' if request.method == 'POST': doc = minidom.parseString(request.body) part_key = _get_children_from_path( doc, 'entry', 'content', (METADATA_NS, 'properties'), (_DATASERVICES_NS, 'PartitionKey')) if not part_key: raise WindowsAzureError(_ERROR_CANNOT_FIND_PARTITION_KEY) return part_key[0].firstChild.nodeValue else: uri = url_unquote(request.path) pos1 = uri.find('PartitionKey=\'') pos2 = uri.find('\',', pos1) if pos1 == -1 or pos2 == -1: raise WindowsAzureError(_ERROR_CANNOT_FIND_PARTITION_KEY) return uri[pos1 + len('PartitionKey=\''):pos2] def get_request_row_key(self, request): ''' Extracts RowKey from request.body if it is a POST request or from request.path if it is not a POST request. Only insert operation request is a POST request and the Rowkey is in the request body. request: the request to insert, update or delete entity ''' if request.method == 'POST': doc = minidom.parseString(request.body) row_key = _get_children_from_path( doc, 'entry', 'content', (METADATA_NS, 'properties'), (_DATASERVICES_NS, 'RowKey')) if not row_key: raise WindowsAzureError(_ERROR_CANNOT_FIND_ROW_KEY) return row_key[0].firstChild.nodeValue else: uri = url_unquote(request.path) pos1 = uri.find('RowKey=\'') pos2 = uri.find('\')', pos1) if pos1 == -1 or pos2 == -1: raise WindowsAzureError(_ERROR_CANNOT_FIND_ROW_KEY) row_key = uri[pos1 + len('RowKey=\''):pos2] return row_key def validate_request_table(self, request): ''' Validates that all requests have the same table name. Set the table name if it is the first request for the batch operation. request: the request to insert, update or delete entity ''' if self.batch_table: if self.get_request_table(request) != self.batch_table: raise WindowsAzureError(_ERROR_INCORRECT_TABLE_IN_BATCH) else: self.batch_table = self.get_request_table(request) def validate_request_partition_key(self, request): ''' Validates that all requests have the same PartitiionKey. 
Set the PartitionKey if it is the first request for the batch operation. request: the request to insert, update or delete entity ''' if self.batch_partition_key: if self.get_request_partition_key(request) != \ self.batch_partition_key: raise WindowsAzureError(_ERROR_INCORRECT_PARTITION_KEY_IN_BATCH) else: self.batch_partition_key = self.get_request_partition_key(request) def validate_request_row_key(self, request): ''' Validates that all requests have the different RowKey and adds RowKey to existing RowKey list. request: the request to insert, update or delete entity ''' if self.batch_row_keys: if self.get_request_row_key(request) in self.batch_row_keys: raise WindowsAzureError(_ERROR_DUPLICATE_ROW_KEY_IN_BATCH) else: self.batch_row_keys.append(self.get_request_row_key(request)) def begin_batch(self): ''' Starts the batch operation. Intializes the batch variables is_batch: batch operation flag. batch_table: the table name of the batch operation batch_partition_key: the PartitionKey of the batch requests. batch_row_keys: the RowKey list of adding requests. batch_requests: the list of the requests. ''' self.is_batch = True self.batch_table = '' self.batch_partition_key = '' self.batch_row_keys = [] self.batch_requests = [] def insert_request_to_batch(self, request): ''' Adds request to batch operation. request: the request to insert, update or delete entity ''' self.validate_request_table(request) self.validate_request_partition_key(request) self.validate_request_row_key(request) self.batch_requests.append(request) def commit_batch(self): ''' Resets batch flag and commits the batch requests. ''' if self.is_batch: self.is_batch = False self.commit_batch_requests() def commit_batch_requests(self): ''' Commits the batch requests. ''' batch_boundary = b'batch_' + _new_boundary() changeset_boundary = b'changeset_' + _new_boundary() # Commits batch only the requests list is not empty. 
if self.batch_requests: request = HTTPRequest() request.method = 'POST' request.host = self.batch_requests[0].host request.path = '/$batch' request.headers = [ ('Content-Type', 'multipart/mixed; boundary=' + \ batch_boundary.decode('utf-8')), ('Accept', 'application/atom+xml,application/xml'), ('Accept-Charset', 'UTF-8')] request.body = b'--' + batch_boundary + b'\n' request.body += b'Content-Type: multipart/mixed; boundary=' request.body += changeset_boundary + b'\n\n' content_id = 1 # Adds each request body to the POST data. for batch_request in self.batch_requests: request.body += b'--' + changeset_boundary + b'\n' request.body += b'Content-Type: application/http\n' request.body += b'Content-Transfer-Encoding: binary\n\n' request.body += batch_request.method.encode('utf-8') request.body += b' http://' request.body += batch_request.host.encode('utf-8') request.body += batch_request.path.encode('utf-8') request.body += b' HTTP/1.1\n' request.body += b'Content-ID: ' request.body += str(content_id).encode('utf-8') + b'\n' content_id += 1 # Add different headers for different type requests. if not batch_request.method == 'DELETE': request.body += \ b'Content-Type: application/atom+xml;type=entry\n' for name, value in batch_request.headers: if name == 'If-Match': request.body += name.encode('utf-8') + b': ' request.body += value.encode('utf-8') + b'\n' break request.body += b'Content-Length: ' request.body += str(len(batch_request.body)).encode('utf-8') request.body += b'\n\n' request.body += batch_request.body + b'\n' else: for name, value in batch_request.headers: # If-Match should be already included in # batch_request.headers, but in case it is missing, # just add it. 
if name == 'If-Match': request.body += name.encode('utf-8') + b': ' request.body += value.encode('utf-8') + b'\n\n' break else: request.body += b'If-Match: *\n\n' request.body += b'--' + changeset_boundary + b'--' + b'\n' request.body += b'--' + batch_boundary + b'--' request.path, request.query = _update_request_uri_query(request) request.headers = _update_storage_table_header(request) auth = _sign_storage_table_request(request, self.account_name, self.account_key) request.headers.append(('Authorization', auth)) # Submit the whole request as batch request. response = self.perform_request(request) if response.status >= 300: raise HTTPError(response.status, _ERROR_BATCH_COMMIT_FAIL, self.respheader, response.body) # http://www.odata.org/documentation/odata-version-2-0/batch-processing/ # The body of a ChangeSet response is either a response for all the # successfully processed change request within the ChangeSet, # formatted exactly as it would have appeared outside of a batch, # or a single response indicating a failure of the entire ChangeSet. responses = self._parse_batch_response(response.body) if responses and responses[0].status >= 300: self._report_batch_error(responses[0]) def cancel_batch(self): ''' Resets the batch flag. 
        '''
        self.is_batch = False

    def _parse_batch_response(self, body):
        # Splits the multipart batch response on the changeset-response
        # boundary and parses every part that contains an HTTP status line.
        parts = body.split(b'--changesetresponse_')
        responses = []
        for part in parts:
            httpLocation = part.find(b'HTTP/')
            if httpLocation > 0:
                response = self._parse_batch_response_part(part[httpLocation:])
                responses.append(response)
        return responses

    def _parse_batch_response_part(self, part):
        # Parses one embedded "HTTP/..." response fragment into an
        # HTTPResponse(status, reason, headers, body).
        lines = part.splitlines();

        # First line is the HTTP status/reason
        status, _, reason = lines[0].partition(b' ')[2].partition(b' ')

        # Followed by headers and body
        headers = []
        body = b''
        isBody = False
        for line in lines[1:]:
            if line == b'' and not isBody:
                # blank line separates headers from body
                isBody = True
            elif isBody:
                # NOTE(review): body lines are concatenated without the
                # newlines splitlines() stripped — adequate for the compact
                # XML error payloads parsed here, but lossy in general.
                body += line
            else:
                headerName, _, headerVal = line.partition(b':')
                headers.append((headerName.lower(), headerVal))

        return HTTPResponse(int(status), reason.strip(), headers, body)

    def _report_batch_error(self, response):
        # Extracts the OData error code/message from a failed changeset
        # response and raises WindowsAzureBatchOperationError.
        xml = response.body.decode('utf-8')
        doc = minidom.parseString(xml)

        n = _get_children_from_path(doc, (METADATA_NS, 'error'), 'code')
        code = n[0].firstChild.nodeValue if n and n[0].firstChild else ''

        n = _get_children_from_path(doc, (METADATA_NS, 'error'), 'message')
        message = n[0].firstChild.nodeValue if n and n[0].firstChild else xml

        raise WindowsAzureBatchOperationError(message, code)



================================================
FILE: OSPatching/azure/http/httpclient.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- import base64 import os import sys if sys.version_info < (3,): from httplib import ( HTTPSConnection, HTTPConnection, HTTP_PORT, HTTPS_PORT, ) from urlparse import urlparse else: from http.client import ( HTTPSConnection, HTTPConnection, HTTP_PORT, HTTPS_PORT, ) from urllib.parse import urlparse from azure.http import HTTPError, HTTPResponse from azure import _USER_AGENT_STRING, _update_request_uri_query class _HTTPClient(object): ''' Takes the request and sends it to cloud service and returns the response. ''' def __init__(self, service_instance, cert_file=None, account_name=None, account_key=None, protocol='https'): ''' service_instance: service client instance. cert_file: certificate file name/location. This is only used in hosted service management. account_name: the storage account. account_key: the storage account access key. ''' self.service_instance = service_instance self.status = None self.respheader = None self.message = None self.cert_file = cert_file self.account_name = account_name self.account_key = account_key self.protocol = protocol self.proxy_host = None self.proxy_port = None self.proxy_user = None self.proxy_password = None self.use_httplib = self.should_use_httplib() def should_use_httplib(self): if sys.platform.lower().startswith('win') and self.cert_file: # On Windows, auto-detect between Windows Store Certificate # (winhttp) and OpenSSL .pem certificate file (httplib). # # We used to only support certificates installed in the Windows # Certificate Store. # cert_file example: CURRENT_USER\my\CertificateName # # We now support using an OpenSSL .pem certificate file, # for a consistent experience across all platforms. # cert_file example: account\certificate.pem # # When using OpenSSL .pem certificate file on Windows, make sure # you are on CPython 2.7.4 or later. 
# If it's not an existing file on disk, then treat it as a path in # the Windows Certificate Store, which means we can't use httplib. if not os.path.isfile(self.cert_file): return False return True def set_proxy(self, host, port, user, password): ''' Sets the proxy server host and port for the HTTP CONNECT Tunnelling. host: Address of the proxy. Ex: '192.168.0.100' port: Port of the proxy. Ex: 6000 user: User for proxy authorization. password: Password for proxy authorization. ''' self.proxy_host = host self.proxy_port = port self.proxy_user = user self.proxy_password = password def get_uri(self, request): ''' Return the target uri for the request.''' protocol = request.protocol_override \ if request.protocol_override else self.protocol port = HTTP_PORT if protocol == 'http' else HTTPS_PORT return protocol + '://' + request.host + ':' + str(port) + request.path def get_connection(self, request): ''' Create connection for the request. ''' protocol = request.protocol_override \ if request.protocol_override else self.protocol target_host = request.host target_port = HTTP_PORT if protocol == 'http' else HTTPS_PORT if not self.use_httplib: import azure.http.winhttp connection = azure.http.winhttp._HTTPConnection( target_host, cert_file=self.cert_file, protocol=protocol) proxy_host = self.proxy_host proxy_port = self.proxy_port else: if ':' in target_host: target_host, _, target_port = target_host.rpartition(':') if self.proxy_host: proxy_host = target_host proxy_port = target_port host = self.proxy_host port = self.proxy_port else: host = target_host port = target_port if protocol == 'http': connection = HTTPConnection(host, int(port)) else: connection = HTTPSConnection( host, int(port), cert_file=self.cert_file) if self.proxy_host: headers = None if self.proxy_user and self.proxy_password: auth = base64.encodestring( "{0}:{1}".format(self.proxy_user, self.proxy_password)) headers = {'Proxy-Authorization': 'Basic {0}'.format(auth)} connection.set_tunnel(proxy_host, 
                                  int(proxy_port), headers)

        return connection

    def send_request_headers(self, connection, request_headers):
        ''' Writes the request headers (plus User-Agent) to the connection
        and finishes the header section. '''
        if self.use_httplib:
            if self.proxy_host:
                # When tunnelling through a proxy, replace the Host header
                # httplib buffered for the proxy with the tunnel target.
                for i in connection._buffer:
                    if i.startswith("Host: "):
                        connection._buffer.remove(i)
                connection.putheader(
                    'Host',
                    "{0}:{1}".format(connection._tunnel_host,
                                     connection._tunnel_port))

        for name, value in request_headers:
            if value:
                connection.putheader(name, value)
        connection.putheader('User-Agent', _USER_AGENT_STRING)
        connection.endheaders()

    def send_request_body(self, connection, request_body):
        ''' Sends the request body; the winhttp transport (not an
        httplib connection) also needs an explicit empty send. '''
        if request_body:
            assert isinstance(request_body, bytes)
            connection.send(request_body)
        elif (not isinstance(connection, HTTPSConnection) and
              not isinstance(connection, HTTPConnection)):
            connection.send(None)

    def perform_request(self, request):
        ''' Sends request to cloud service server and return the response.

        Follows 307 redirects (re-issuing the request against the new
        location) and raises HTTPError for any status >= 300. '''
        connection = self.get_connection(request)
        try:
            connection.putrequest(request.method, request.path)

            if not self.use_httplib:
                if self.proxy_host and self.proxy_user:
                    connection.set_proxy_credentials(
                        self.proxy_user, self.proxy_password)

            self.send_request_headers(connection, request.headers)
            self.send_request_body(connection, request.body)

            resp = connection.getresponse()
            self.status = int(resp.status)
            self.message = resp.reason
            self.respheader = headers = resp.getheaders()

            # for consistency across platforms, make header names lowercase
            for i, value in enumerate(headers):
                headers[i] = (value[0].lower(), value[1])

            respbody = None
            if resp.length is None:
                respbody = resp.read()
            elif resp.length > 0:
                respbody = resp.read(resp.length)

            response = HTTPResponse(
                int(resp.status), resp.reason, headers, respbody)
            if self.status == 307:
                # Temporary redirect: retarget the request and retry.
                new_url = urlparse(dict(headers)['location'])
                request.host = new_url.hostname
                request.path = new_url.path
                request.path, request.query = \
                    _update_request_uri_query(request)
                return self.perform_request(request)
            if self.status >= 300:
                raise HTTPError(self.status, self.message,
                                self.respheader, respbody)

            return \
                response
        finally:
            connection.close()


================================================
FILE: OSPatching/azure/http/winhttp.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from ctypes import (
    c_void_p,
    c_long,
    c_ulong,
    c_longlong,
    c_ulonglong,
    c_short,
    c_ushort,
    c_wchar_p,
    c_byte,
    byref,
    Structure,
    Union,
    POINTER,
    WINFUNCTYPE,
    HRESULT,
    oledll,
    WinDLL,
    )
import ctypes
import sys

# On Python 3 there is no builtin unicode(); str is already unicode.
if sys.version_info >= (3,):
    def unicode(text):
        return text

#------------------------------------------------------------------------------
# Constants that are used in COM operations (VARIANT type tags).
VT_EMPTY = 0
VT_NULL = 1
VT_I2 = 2
VT_I4 = 3
VT_BSTR = 8
VT_BOOL = 11
VT_I1 = 16
VT_UI1 = 17
VT_UI2 = 18
VT_UI4 = 19
VT_I8 = 20
VT_UI8 = 21
VT_ARRAY = 8192

# WinHttpRequest option values for proxy configuration.
HTTPREQUEST_PROXYSETTING_PROXY = 2
HTTPREQUEST_SETCREDENTIALS_FOR_PROXY = 1

HTTPREQUEST_PROXY_SETTING = c_long
HTTPREQUEST_SETCREDENTIALS_FLAGS = c_long

#------------------------------------------------------------------------------
# Com related APIs that are used.
_ole32 = oledll.ole32
_oleaut32 = WinDLL('oleaut32')

# ole32 entry points used for COM initialization and object creation.
_CLSIDFromString = _ole32.CLSIDFromString
_CoInitialize = _ole32.CoInitialize
_CoInitialize.argtypes = [c_void_p]
_CoCreateInstance = _ole32.CoCreateInstance

# oleaut32 BSTR allocation/release helpers.
_SysAllocString = _oleaut32.SysAllocString
_SysAllocString.restype = c_void_p
_SysAllocString.argtypes = [c_wchar_p]

_SysFreeString = _oleaut32.SysFreeString
_SysFreeString.argtypes = [c_void_p]

# SAFEARRAY*
# SafeArrayCreateVector(_In_ VARTYPE vt,_In_ LONG lLbound,_In_ ULONG
# cElements);
_SafeArrayCreateVector = _oleaut32.SafeArrayCreateVector
_SafeArrayCreateVector.restype = c_void_p
_SafeArrayCreateVector.argtypes = [c_ushort, c_long, c_ulong]

# HRESULT
# SafeArrayAccessData(_In_ SAFEARRAY *psa, _Out_ void **ppvData);
_SafeArrayAccessData = _oleaut32.SafeArrayAccessData
_SafeArrayAccessData.argtypes = [c_void_p, POINTER(c_void_p)]

# HRESULT
# SafeArrayUnaccessData(_In_ SAFEARRAY *psa);
_SafeArrayUnaccessData = _oleaut32.SafeArrayUnaccessData
_SafeArrayUnaccessData.argtypes = [c_void_p]

# HRESULT
# SafeArrayGetUBound(_In_ SAFEARRAY *psa, _In_ UINT nDim, _Out_ LONG
# *plUbound);
_SafeArrayGetUBound = _oleaut32.SafeArrayGetUBound
_SafeArrayGetUBound.argtypes = [c_void_p, c_ulong, POINTER(c_long)]


#------------------------------------------------------------------------------

class BSTR(c_wchar_p):
    ''' BSTR class in python. Wraps SysAllocString/SysFreeString so the
    OLE string is released when the Python object is collected. '''
    def __init__(self, value):
        super(BSTR, self).__init__(_SysAllocString(value))

    def __del__(self):
        _SysFreeString(self)


class VARIANT(Structure):
    '''
    VARIANT structure in python. Does not match the definition in
    MSDN exactly & it is only mapping the used fields.  Field names are also
    slighty different.
    '''

    class _tagData(Union):

        class _tagRecord(Structure):
            _fields_ = [('pvoid', c_void_p), ('precord', c_void_p)]

        _fields_ = [('llval', c_longlong),
                    ('ullval', c_ulonglong),
                    ('lval', c_long),
                    ('ulval', c_ulong),
                    ('ival', c_short),
                    ('boolval', c_ushort),
                    ('bstrval', BSTR),
                    ('parray', c_void_p),
                    ('record', _tagRecord)]

    _fields_ = [('vt', c_ushort),
                ('wReserved1', c_ushort),
                ('wReserved2', c_ushort),
                ('wReserved3', c_ushort),
                ('vdata', _tagData)]

    @staticmethod
    def create_empty():
        ''' Returns a VT_EMPTY variant (used for optional COM arguments). '''
        variant = VARIANT()
        variant.vt = VT_EMPTY
        variant.vdata.llval = 0
        return variant

    @staticmethod
    def create_safearray_from_str(text):
        ''' Returns a VT_ARRAY|VT_UI1 variant whose SAFEARRAY holds a copy
        of the given bytes. '''
        variant = VARIANT()
        variant.vt = VT_ARRAY | VT_UI1

        length = len(text)
        variant.vdata.parray = _SafeArrayCreateVector(VT_UI1, 0, length)
        pvdata = c_void_p()
        _SafeArrayAccessData(variant.vdata.parray, byref(pvdata))
        ctypes.memmove(pvdata, text, length)
        _SafeArrayUnaccessData(variant.vdata.parray)

        return variant

    @staticmethod
    def create_bstr_from_str(text):
        ''' Returns a VT_BSTR variant wrapping the given string. '''
        variant = VARIANT()
        variant.vt = VT_BSTR
        variant.vdata.bstrval = BSTR(text)

        return variant

    @staticmethod
    def create_bool_false():
        ''' Returns a VT_BOOL variant set to false. '''
        variant = VARIANT()
        variant.vt = VT_BOOL
        variant.vdata.boolval = 0

        return variant

    def is_safearray_of_bytes(self):
        return self.vt == VT_ARRAY | VT_UI1

    def str_from_safearray(self):
        ''' Copies the SAFEARRAY contents out as bytes. '''
        assert self.vt == VT_ARRAY | VT_UI1
        pvdata = c_void_p()
        count = c_long()
        # Upper bound is inclusive, hence +1 for the element count.
        _SafeArrayGetUBound(self.vdata.parray, 1, byref(count))
        count = c_long(count.value + 1)
        _SafeArrayAccessData(self.vdata.parray, byref(pvdata))
        text = ctypes.string_at(pvdata, count)
        _SafeArrayUnaccessData(self.vdata.parray)

        return text

    def __del__(self):
        _VariantClear(self)

# HRESULT VariantClear(_Inout_ VARIANTARG *pvarg);
_VariantClear = _oleaut32.VariantClear
_VariantClear.argtypes = [POINTER(VARIANT)]


class GUID(Structure):
    ''' GUID structure in python.
''' _fields_ = [("data1", c_ulong), ("data2", c_ushort), ("data3", c_ushort), ("data4", c_byte * 8)] def __init__(self, name=None): if name is not None: _CLSIDFromString(unicode(name), byref(self)) class _WinHttpRequest(c_void_p): ''' Maps the Com API to Python class functions. Not all methods in IWinHttpWebRequest are mapped - only the methods we use. ''' _AddRef = WINFUNCTYPE(c_long) \ (1, 'AddRef') _Release = WINFUNCTYPE(c_long) \ (2, 'Release') _SetProxy = WINFUNCTYPE(HRESULT, HTTPREQUEST_PROXY_SETTING, VARIANT, VARIANT) \ (7, 'SetProxy') _SetCredentials = WINFUNCTYPE(HRESULT, BSTR, BSTR, HTTPREQUEST_SETCREDENTIALS_FLAGS) \ (8, 'SetCredentials') _Open = WINFUNCTYPE(HRESULT, BSTR, BSTR, VARIANT) \ (9, 'Open') _SetRequestHeader = WINFUNCTYPE(HRESULT, BSTR, BSTR) \ (10, 'SetRequestHeader') _GetResponseHeader = WINFUNCTYPE(HRESULT, BSTR, POINTER(c_void_p)) \ (11, 'GetResponseHeader') _GetAllResponseHeaders = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \ (12, 'GetAllResponseHeaders') _Send = WINFUNCTYPE(HRESULT, VARIANT) \ (13, 'Send') _Status = WINFUNCTYPE(HRESULT, POINTER(c_long)) \ (14, 'Status') _StatusText = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \ (15, 'StatusText') _ResponseText = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \ (16, 'ResponseText') _ResponseBody = WINFUNCTYPE(HRESULT, POINTER(VARIANT)) \ (17, 'ResponseBody') _ResponseStream = WINFUNCTYPE(HRESULT, POINTER(VARIANT)) \ (18, 'ResponseStream') _WaitForResponse = WINFUNCTYPE(HRESULT, VARIANT, POINTER(c_ushort)) \ (21, 'WaitForResponse') _Abort = WINFUNCTYPE(HRESULT) \ (22, 'Abort') _SetTimeouts = WINFUNCTYPE(HRESULT, c_long, c_long, c_long, c_long) \ (23, 'SetTimeouts') _SetClientCertificate = WINFUNCTYPE(HRESULT, BSTR) \ (24, 'SetClientCertificate') def open(self, method, url): ''' Opens the request. method: the request VERB 'GET', 'POST', etc. 
        url: the url to connect
        '''
        # resolve/connect/send/receive timeouts in ms; 0 = no resolve timeout.
        _WinHttpRequest._SetTimeouts(self, 0, 65000, 65000, 65000)

        flag = VARIANT.create_bool_false()
        _method = BSTR(method)
        _url = BSTR(url)
        _WinHttpRequest._Open(self, _method, _url, flag)

    def set_request_header(self, name, value):
        ''' Sets the request header. '''
        _name = BSTR(name)
        _value = BSTR(value)
        _WinHttpRequest._SetRequestHeader(self, _name, _value)

    def get_all_response_headers(self):
        ''' Gets back all response headers as a single string. '''
        bstr_headers = c_void_p()
        _WinHttpRequest._GetAllResponseHeaders(self, byref(bstr_headers))
        bstr_headers = ctypes.cast(bstr_headers, c_wchar_p)
        headers = bstr_headers.value
        _SysFreeString(bstr_headers)

        return headers

    def send(self, request=None):
        ''' Sends the request body. '''

        # Sends VT_EMPTY if it is GET, HEAD request.
        if request is None:
            var_empty = VARIANT.create_empty()
            _WinHttpRequest._Send(self, var_empty)
        else:  # Sends request body as SAFEArray.
            _request = VARIANT.create_safearray_from_str(request)
            _WinHttpRequest._Send(self, _request)

    def status(self):
        ''' Gets status of response as an int. '''
        status = c_long()
        _WinHttpRequest._Status(self, byref(status))

        return int(status.value)

    def status_text(self):
        ''' Gets status text of response. '''
        bstr_status_text = c_void_p()
        _WinHttpRequest._StatusText(self, byref(bstr_status_text))
        bstr_status_text = ctypes.cast(bstr_status_text, c_wchar_p)
        status_text = bstr_status_text.value
        _SysFreeString(bstr_status_text)

        return status_text

    def response_body(self):
        '''
        Gets response body as a SAFEARRAY and converts the SAFEARRAY to str.
        If it is an xml file, it always contains 3 characters before <?xml,
        so we remove them.
        '''
        var_respbody = VARIANT()
        _WinHttpRequest._ResponseBody(self, byref(var_respbody))
        if var_respbody.is_safearray_of_bytes():
            respbody = var_respbody.str_from_safearray()
            # Strip the UTF-8 BOM that precedes XML payloads.
            if respbody[3:].startswith(b'<?xml') and\
               respbody.startswith(b'\xef\xbb\xbf'):
                respbody = respbody[3:]

            return respbody
        else:
            return ''

    def set_client_certificate(self, certificate):
        '''Sets client certificate for the request. '''
        _certificate = BSTR(certificate)
        _WinHttpRequest._SetClientCertificate(self, _certificate)

    def set_tunnel(self, host, port):
        ''' Sets up the host and the port for the HTTP CONNECT Tunnelling.'''
        url = host
        if port:
            url = url + u':' + port

        var_host = VARIANT.create_bstr_from_str(url)
        var_empty = VARIANT.create_empty()

        _WinHttpRequest._SetProxy(
            self, HTTPREQUEST_PROXYSETTING_PROXY, var_host, var_empty)

    def set_proxy_credentials(self, user, password):
        ''' Supplies user/password for proxy authorization. '''
        _WinHttpRequest._SetCredentials(
            self, BSTR(user), BSTR(password),
            HTTPREQUEST_SETCREDENTIALS_FOR_PROXY)

    def __del__(self):
        # Release the underlying COM object when the wrapper is collected.
        if self.value is not None:
            _WinHttpRequest._Release(self)


class _Response(object):

    ''' Response class corresponding to the response returned from httplib
    HTTPConnection. '''

    def __init__(self, _status, _status_text, _length, _headers, _respbody):
        self.status = _status
        self.reason = _status_text
        self.length = _length
        self.headers = _headers
        self.respbody = _respbody

    def getheaders(self):
        '''Returns response headers.'''
        return self.headers

    def read(self, _length):
        '''Returns resonse body. '''
        return self.respbody[:_length]


class _HTTPConnection(object):

    ''' Class corresponding to httplib HTTPConnection class.
    '''

    def __init__(self, host, cert_file=None, key_file=None, protocol='http'):
        ''' initialize the IWinHttpWebRequest Com Object.'''
        self.host = unicode(host)
        self.cert_file = cert_file
        self._httprequest = _WinHttpRequest()
        self.protocol = protocol
        # CLSID/IID of the WinHttpRequest COM class and interface.
        clsid = GUID('{2087C2F4-2CEF-4953-A8AB-66779B670495}')
        iid = GUID('{016FE2EC-B2C8-45F8-B23B-39E53A75396B}')
        _CoInitialize(None)
        _CoCreateInstance(byref(clsid), 0, 1, byref(iid),
                          byref(self._httprequest))

    def close(self):
        # No explicit close; the COM object is released in __del__.
        pass

    def set_tunnel(self, host, port=None, headers=None):
        ''' Sets up the host and the port for the HTTP CONNECT Tunnelling.

        NOTE(review): the headers argument is accepted for interface
        compatibility with httplib but is not forwarded to winhttp. '''
        self._httprequest.set_tunnel(unicode(host), unicode(str(port)))

    def set_proxy_credentials(self, user, password):
        self._httprequest.set_proxy_credentials(
            unicode(user), unicode(password))

    def putrequest(self, method, uri):
        ''' Connects to host and sends the request. '''
        protocol = unicode(self.protocol + '://')
        url = protocol + self.host + unicode(uri)
        self._httprequest.open(unicode(method), url)

        # sets certificate for the connection if cert_file is set.
        if self.cert_file is not None:
            self._httprequest.set_client_certificate(unicode(self.cert_file))

    def putheader(self, name, value):
        ''' Sends the headers of request. '''
        if sys.version_info < (3,):
            name = str(name).decode('utf-8')
            value = str(value).decode('utf-8')
        self._httprequest.set_request_header(name, value)

    def endheaders(self):
        ''' No operation. Exists only to provide the same interface of httplib
        HTTPConnection.'''
        pass

    def send(self, request_body):
        ''' Sends request body.
        '''
        if not request_body:
            self._httprequest.send()
        else:
            self._httprequest.send(request_body)

    def getresponse(self):
        ''' Gets the response and generates the _Response object'''
        status = self._httprequest.status()
        status_text = self._httprequest.status_text()

        resp_headers = self._httprequest.get_all_response_headers()
        fixed_headers = []
        for resp_header in resp_headers.split('\n'):
            # A line starting with whitespace is a continuation of the
            # previous header (RFC 2616 folded header).
            if (resp_header.startswith('\t') or\
                resp_header.startswith(' ')) and fixed_headers:
                # append to previous header
                fixed_headers[-1] += resp_header
            else:
                fixed_headers.append(resp_header)

        headers = []
        for resp_header in fixed_headers:
            if ':' in resp_header:
                pos = resp_header.find(':')
                headers.append(
                    (resp_header[:pos].lower(),
                     resp_header[pos + 1:].strip()))

        body = self._httprequest.response_body()
        length = len(body)

        return _Response(status, status_text, length, headers, body)


================================================
FILE: OSPatching/azure/servicebus/__init__.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-------------------------------------------------------------------------- import ast import json import sys from datetime import datetime from xml.dom import minidom from azure import ( WindowsAzureData, WindowsAzureError, xml_escape, _create_entry, _general_error_handler, _get_entry_properties, _get_child_nodes, _get_children_from_path, _get_first_child_node_value, _ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_DELETE, _ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_UNLOCK, _ERROR_QUEUE_NOT_FOUND, _ERROR_TOPIC_NOT_FOUND, ) from azure.http import HTTPError # default rule name for subscription DEFAULT_RULE_NAME = '$Default' #----------------------------------------------------------------------------- # Constants for Azure app environment settings. AZURE_SERVICEBUS_NAMESPACE = 'AZURE_SERVICEBUS_NAMESPACE' AZURE_SERVICEBUS_ACCESS_KEY = 'AZURE_SERVICEBUS_ACCESS_KEY' AZURE_SERVICEBUS_ISSUER = 'AZURE_SERVICEBUS_ISSUER' # namespace used for converting rules to objects XML_SCHEMA_NAMESPACE = 'http://www.w3.org/2001/XMLSchema-instance' class Queue(WindowsAzureData): ''' Queue class corresponding to Queue Description: http://msdn.microsoft.com/en-us/library/windowsazure/hh780773''' def __init__(self, lock_duration=None, max_size_in_megabytes=None, requires_duplicate_detection=None, requires_session=None, default_message_time_to_live=None, dead_lettering_on_message_expiration=None, duplicate_detection_history_time_window=None, max_delivery_count=None, enable_batched_operations=None, size_in_bytes=None, message_count=None): self.lock_duration = lock_duration self.max_size_in_megabytes = max_size_in_megabytes self.requires_duplicate_detection = requires_duplicate_detection self.requires_session = requires_session self.default_message_time_to_live = default_message_time_to_live self.dead_lettering_on_message_expiration = \ dead_lettering_on_message_expiration self.duplicate_detection_history_time_window = \ duplicate_detection_history_time_window self.max_delivery_count = max_delivery_count 
        self.enable_batched_operations = enable_batched_operations
        self.size_in_bytes = size_in_bytes
        self.message_count = message_count


class Topic(WindowsAzureData):

    ''' Topic class corresponding to Topic Description:
    http://msdn.microsoft.com/en-us/library/windowsazure/hh780749. '''

    def __init__(self, default_message_time_to_live=None,
                 max_size_in_megabytes=None,
                 requires_duplicate_detection=None,
                 duplicate_detection_history_time_window=None,
                 enable_batched_operations=None, size_in_bytes=None):

        self.default_message_time_to_live = default_message_time_to_live
        self.max_size_in_megabytes = max_size_in_megabytes
        self.requires_duplicate_detection = requires_duplicate_detection
        self.duplicate_detection_history_time_window = \
            duplicate_detection_history_time_window
        self.enable_batched_operations = enable_batched_operations
        self.size_in_bytes = size_in_bytes

    # Deprecated alias kept for backward compatibility with older callers.
    @property
    def max_size_in_mega_bytes(self):
        import warnings
        warnings.warn(
            'This attribute has been changed to max_size_in_megabytes.')
        return self.max_size_in_megabytes

    @max_size_in_mega_bytes.setter
    def max_size_in_mega_bytes(self, value):
        self.max_size_in_megabytes = value


class Subscription(WindowsAzureData):

    ''' Subscription class corresponding to Subscription Description:
    http://msdn.microsoft.com/en-us/library/windowsazure/hh780763.
''' def __init__(self, lock_duration=None, requires_session=None, default_message_time_to_live=None, dead_lettering_on_message_expiration=None, dead_lettering_on_filter_evaluation_exceptions=None, enable_batched_operations=None, max_delivery_count=None, message_count=None): self.lock_duration = lock_duration self.requires_session = requires_session self.default_message_time_to_live = default_message_time_to_live self.dead_lettering_on_message_expiration = \ dead_lettering_on_message_expiration self.dead_lettering_on_filter_evaluation_exceptions = \ dead_lettering_on_filter_evaluation_exceptions self.enable_batched_operations = enable_batched_operations self.max_delivery_count = max_delivery_count self.message_count = message_count class Rule(WindowsAzureData): ''' Rule class corresponding to Rule Description: http://msdn.microsoft.com/en-us/library/windowsazure/hh780753. ''' def __init__(self, filter_type=None, filter_expression=None, action_type=None, action_expression=None): self.filter_type = filter_type self.filter_expression = filter_expression self.action_type = action_type self.action_expression = action_type class Message(WindowsAzureData): ''' Message class that used in send message/get mesage apis. ''' def __init__(self, body=None, service_bus_service=None, location=None, custom_properties=None, type='application/atom+xml;type=entry;charset=utf-8', broker_properties=None): self.body = body self.location = location self.broker_properties = broker_properties self.custom_properties = custom_properties self.type = type self.service_bus_service = service_bus_service self._topic_name = None self._subscription_name = None self._queue_name = None if not service_bus_service: return # if location is set, then extracts the queue name for queue message and # extracts the topic and subscriptions name if it is topic message. 
        if location:
            if '/subscriptions/' in location:
                # location looks like .../<topic>/subscriptions/<sub>/...
                pos = location.find('/subscriptions/')
                pos1 = location.rfind('/', 0, pos - 1)
                self._topic_name = location[pos1 + 1:pos]
                pos += len('/subscriptions/')
                pos1 = location.find('/', pos)
                self._subscription_name = location[pos:pos1]
            elif '/messages/' in location:
                # location looks like .../<queue>/messages/...
                pos = location.find('/messages/')
                pos1 = location.rfind('/', 0, pos - 1)
                self._queue_name = location[pos1 + 1:pos]

    def delete(self):
        ''' Deletes itself if find queue name or topic name and subscription
        name. '''
        if self._queue_name:
            self.service_bus_service.delete_queue_message(
                self._queue_name,
                self.broker_properties['SequenceNumber'],
                self.broker_properties['LockToken'])
        elif self._topic_name and self._subscription_name:
            self.service_bus_service.delete_subscription_message(
                self._topic_name,
                self._subscription_name,
                self.broker_properties['SequenceNumber'],
                self.broker_properties['LockToken'])
        else:
            raise WindowsAzureError(_ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_DELETE)

    def unlock(self):
        ''' Unlocks itself if find queue name or topic name and subscription
        name.
        '''
        if self._queue_name:
            self.service_bus_service.unlock_queue_message(
                self._queue_name,
                self.broker_properties['SequenceNumber'],
                self.broker_properties['LockToken'])
        elif self._topic_name and self._subscription_name:
            self.service_bus_service.unlock_subscription_message(
                self._topic_name,
                self._subscription_name,
                self.broker_properties['SequenceNumber'],
                self.broker_properties['LockToken'])
        else:
            raise WindowsAzureError(_ERROR_MESSAGE_NOT_PEEK_LOCKED_ON_UNLOCK)

    def add_headers(self, request):
        ''' add addtional headers to request for message request.'''

        # Adds custom properties; strings and datetimes are quoted, other
        # values (int/float/bool) are sent unquoted and lowercased.
        if self.custom_properties:
            for name, value in self.custom_properties.items():
                if sys.version_info < (3,) and isinstance(value, unicode):
                    request.headers.append(
                        (name, '"' + value.encode('utf-8') + '"'))
                elif isinstance(value, str):
                    request.headers.append((name, '"' + str(value) + '"'))
                elif isinstance(value, datetime):
                    request.headers.append(
                        (name, '"' + value.strftime(
                            '%a, %d %b %Y %H:%M:%S GMT') + '"'))
                else:
                    request.headers.append((name, str(value).lower()))

        # Adds content-type
        request.headers.append(('Content-Type', self.type))

        # Adds BrokerProperties
        if self.broker_properties:
            request.headers.append(
                ('BrokerProperties', str(self.broker_properties)))

        return request.headers


def _create_message(response, service_instance):
    ''' Create message from response.

    response: response from service bus cloud server.
    service_instance: the service bus client.
    '''
    respbody = response.body
    custom_properties = {}
    broker_properties = None
    message_type = None
    message_location = None

    # gets all information from respheaders.
for name, value in response.headers: if name.lower() == 'brokerproperties': broker_properties = json.loads(value) elif name.lower() == 'content-type': message_type = value elif name.lower() == 'location': message_location = value elif name.lower() not in ['content-type', 'brokerproperties', 'transfer-encoding', 'server', 'location', 'date']: if '"' in value: value = value[1:-1] try: custom_properties[name] = datetime.strptime( value, '%a, %d %b %Y %H:%M:%S GMT') except ValueError: custom_properties[name] = value else: # only int, float or boolean if value.lower() == 'true': custom_properties[name] = True elif value.lower() == 'false': custom_properties[name] = False # int('3.1') doesn't work so need to get float('3.14') first elif str(int(float(value))) == value: custom_properties[name] = int(value) else: custom_properties[name] = float(value) if message_type == None: message = Message( respbody, service_instance, message_location, custom_properties, 'application/atom+xml;type=entry;charset=utf-8', broker_properties) else: message = Message(respbody, service_instance, message_location, custom_properties, message_type, broker_properties) return message # convert functions def _convert_response_to_rule(response): return _convert_xml_to_rule(response.body) def _convert_xml_to_rule(xmlstr): ''' Converts response xml to rule object. 
The format of xml for rule: <entry xmlns='http://www.w3.org/2005/Atom'> <content type='application/xml'> <RuleDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect"> <Filter i:type="SqlFilterExpression"> <SqlExpression>MyProperty='XYZ'</SqlExpression> </Filter> <Action i:type="SqlFilterAction"> <SqlExpression>set MyProperty2 = 'ABC'</SqlExpression> </Action> </RuleDescription> </content> </entry> ''' xmldoc = minidom.parseString(xmlstr) rule = Rule() for rule_desc in _get_children_from_path(xmldoc, 'entry', 'content', 'RuleDescription'): for xml_filter in _get_child_nodes(rule_desc, 'Filter'): filter_type = xml_filter.getAttributeNS( XML_SCHEMA_NAMESPACE, 'type') setattr(rule, 'filter_type', str(filter_type)) if xml_filter.childNodes: for expr in _get_child_nodes(xml_filter, 'SqlExpression'): setattr(rule, 'filter_expression', expr.firstChild.nodeValue) for xml_action in _get_child_nodes(rule_desc, 'Action'): action_type = xml_action.getAttributeNS( XML_SCHEMA_NAMESPACE, 'type') setattr(rule, 'action_type', str(action_type)) if xml_action.childNodes: action_expression = xml_action.childNodes[0].firstChild if action_expression: setattr(rule, 'action_expression', action_expression.nodeValue) # extract id, updated and name value from feed entry and set them of rule. for name, value in _get_entry_properties(xmlstr, True, '/rules').items(): setattr(rule, name, value) return rule def _convert_response_to_queue(response): return _convert_xml_to_queue(response.body) def _parse_bool(value): if value.lower() == 'true': return True return False def _convert_xml_to_queue(xmlstr): ''' Converts xml response to queue object. 
    The format of xml response for queue:
<QueueDescription xmlns=\"http://schemas.microsoft.com/netservices/2010/10/servicebus/connect\">
<MaxSizeInBytes>10000</MaxSizeInBytes>
<DefaultMessageTimeToLive>PT5M</DefaultMessageTimeToLive>
<LockDuration>PT2M</LockDuration>
<RequiresGroupedReceives>False</RequiresGroupedReceives>
<SupportsDuplicateDetection>False</SupportsDuplicateDetection>
    ...
</QueueDescription>
    '''
    xmldoc = minidom.parseString(xmlstr)
    queue = Queue()

    invalid_queue = True
    # get node for each attribute in Queue class, if nothing found then the
    # response is not valid xml for Queue.
    for desc in _get_children_from_path(xmldoc,
                                        'entry',
                                        'content',
                                        'QueueDescription'):
        node_value = _get_first_child_node_value(desc, 'LockDuration')
        if node_value is not None:
            queue.lock_duration = node_value
            invalid_queue = False
        node_value = _get_first_child_node_value(desc, 'MaxSizeInMegabytes')
        if node_value is not None:
            queue.max_size_in_megabytes = int(node_value)
            invalid_queue = False
        node_value = _get_first_child_node_value(
            desc, 'RequiresDuplicateDetection')
        if node_value is not None:
            queue.requires_duplicate_detection = _parse_bool(node_value)
            invalid_queue = False
        node_value = _get_first_child_node_value(desc, 'RequiresSession')
        if node_value is not None:
            queue.requires_session = _parse_bool(node_value)
            invalid_queue = False
        node_value = _get_first_child_node_value(
            desc, 'DefaultMessageTimeToLive')
        if node_value is not None:
            queue.default_message_time_to_live = node_value
            invalid_queue = False
        node_value = _get_first_child_node_value(
            desc, 'DeadLetteringOnMessageExpiration')
        if node_value is not None:
            queue.dead_lettering_on_message_expiration = \
                _parse_bool(node_value)
            invalid_queue = False
        node_value = _get_first_child_node_value(
            desc, 'DuplicateDetectionHistoryTimeWindow')
        if node_value is not None:
            queue.duplicate_detection_history_time_window = node_value
            invalid_queue = False
        node_value = _get_first_child_node_value(
            desc, 'EnableBatchedOperations')
        if node_value is not None:
            queue.enable_batched_operations = _parse_bool(node_value)
            invalid_queue = False
        node_value = _get_first_child_node_value(desc, 'MaxDeliveryCount')
        if node_value is not None:
            queue.max_delivery_count = int(node_value)
            invalid_queue = False
        node_value = _get_first_child_node_value(desc, 'MessageCount')
        if node_value is not None:
            queue.message_count = int(node_value)
            invalid_queue = False
        node_value = _get_first_child_node_value(desc, 'SizeInBytes')
        if node_value is not None:
            queue.size_in_bytes = int(node_value)
            invalid_queue = False

    if invalid_queue:
        raise WindowsAzureError(_ERROR_QUEUE_NOT_FOUND)

    # extract id, updated and name value from feed entry and set them of queue.
    for name, value in _get_entry_properties(xmlstr, True).items():
        setattr(queue, name, value)

    return queue


def _convert_response_to_topic(response):
    ''' Deserializes the response body into a Topic. '''
    return _convert_xml_to_topic(response.body)


def _convert_xml_to_topic(xmlstr):
    '''Converts xml response to topic

    The xml format for topic:
<entry xmlns='http://www.w3.org/2005/Atom'>
    <content type='application/xml'>
    <TopicDescription
        xmlns:i="http://www.w3.org/2001/XMLSchema-instance"
        xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">
        <DefaultMessageTimeToLive>P10675199DT2H48M5.4775807S</DefaultMessageTimeToLive>
        <MaxSizeInMegabytes>1024</MaxSizeInMegabytes>
        <RequiresDuplicateDetection>false</RequiresDuplicateDetection>
        <DuplicateDetectionHistoryTimeWindow>P7D</DuplicateDetectionHistoryTimeWindow>
        <DeadLetteringOnFilterEvaluationExceptions>true</DeadLetteringOnFilterEvaluationExceptions>
    </TopicDescription>
    </content>
</entry>
    '''
    xmldoc = minidom.parseString(xmlstr)
    topic = Topic()

    invalid_topic = True
    # get node for each attribute in Topic class, if nothing found then the
    # response is not valid xml for Topic.
    for desc in _get_children_from_path(xmldoc,
                                        'entry',
                                        'content',
                                        'TopicDescription'):
        # NOTE(review): this reset is redundant — invalid_topic is already
        # True before the loop; kept as-is to preserve behavior.
        invalid_topic = True
        node_value = _get_first_child_node_value(
            desc, 'DefaultMessageTimeToLive')
        if node_value is not None:
            topic.default_message_time_to_live = node_value
            invalid_topic = False
        node_value = _get_first_child_node_value(desc, 'MaxSizeInMegabytes')
        if node_value is not None:
            topic.max_size_in_megabytes = int(node_value)
            invalid_topic = False
        node_value = _get_first_child_node_value(
            desc, 'RequiresDuplicateDetection')
        if node_value is not None:
            topic.requires_duplicate_detection = _parse_bool(node_value)
            invalid_topic = False
        node_value = _get_first_child_node_value(
            desc, 'DuplicateDetectionHistoryTimeWindow')
        if node_value is not None:
            topic.duplicate_detection_history_time_window = node_value
            invalid_topic = False
        node_value = _get_first_child_node_value(
            desc, 'EnableBatchedOperations')
        if node_value is not None:
            topic.enable_batched_operations = _parse_bool(node_value)
            invalid_topic = False
        node_value = _get_first_child_node_value(desc, 'SizeInBytes')
        if node_value is not None:
            topic.size_in_bytes = int(node_value)
            invalid_topic = False

    # No recognized TopicDescription element was found: not a topic response.
    if invalid_topic:
        raise WindowsAzureError(_ERROR_TOPIC_NOT_FOUND)

    # extract id, updated and name value from feed entry and set them of topic.
    for name, value in _get_entry_properties(xmlstr, True).items():
        setattr(topic, name, value)
    return topic


def _convert_response_to_subscription(response):
    # Thin adapter: parse the HTTP response body as subscription XML.
    return _convert_xml_to_subscription(response.body)


def _convert_xml_to_subscription(xmlstr):
    '''Converts xml response to subscription

    The xml format for subscription:
    <entry xmlns='http://www.w3.org/2005/Atom'>
    <content type='application/xml'>
    <SubscriptionDescription
        xmlns:i="http://www.w3.org/2001/XMLSchema-instance"
        xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">
        <LockDuration>PT5M</LockDuration>
        <RequiresSession>false</RequiresSession>
        <DefaultMessageTimeToLive>P10675199DT2H48M5.4775807S</DefaultMessageTimeToLive>
        <DeadLetteringOnMessageExpiration>false</DeadLetteringOnMessageExpiration>
        <DeadLetteringOnFilterEvaluationExceptions>true</DeadLetteringOnFilterEvaluationExceptions>
    </SubscriptionDescription>
    </content>
    </entry>
    '''
    xmldoc = minidom.parseString(xmlstr)
    subscription = Subscription()

    for desc in _get_children_from_path(xmldoc,
                                        'entry',
                                        'content',
                                        'SubscriptionDescription'):
        node_value = _get_first_child_node_value(desc, 'LockDuration')
        if node_value is not None:
            subscription.lock_duration = node_value
        node_value = _get_first_child_node_value(
            desc, 'RequiresSession')
        if node_value is not None:
            subscription.requires_session = _parse_bool(node_value)
        node_value = _get_first_child_node_value(
            desc, 'DefaultMessageTimeToLive')
        if node_value is not None:
            subscription.default_message_time_to_live = node_value
        node_value = _get_first_child_node_value(
            desc, 'DeadLetteringOnFilterEvaluationExceptions')
        if node_value is not None:
            subscription.dead_lettering_on_filter_evaluation_exceptions = \
                _parse_bool(node_value)
        node_value = _get_first_child_node_value(
            desc, 'DeadLetteringOnMessageExpiration')
        if node_value is not None:
            subscription.dead_lettering_on_message_expiration = \
                _parse_bool(node_value)
        node_value = _get_first_child_node_value(
            desc, 'EnableBatchedOperations')
        if node_value is not None:
            subscription.enable_batched_operations = _parse_bool(node_value)
        node_value = _get_first_child_node_value(
            desc, 'MaxDeliveryCount')
        if node_value is not None:
            subscription.max_delivery_count = int(node_value)
        node_value = _get_first_child_node_value(
            desc, 'MessageCount')
        if node_value is not None:
            subscription.message_count = int(node_value)

    # extract id, updated and name value from the feed entry and set them
    # on the subscription object.
    for name, value in _get_entry_properties(xmlstr,
                                             True,
                                             '/subscriptions').items():
        setattr(subscription, name, value)

    return subscription


def _convert_subscription_to_xml(subscription):
    '''
    Converts a subscription object to xml to send.  The order of each field of
    subscription in xml is very important so we can't simply call
    convert_class_to_xml.

    subscription: the subsciption object to be converted.
    '''
    subscription_body = '<SubscriptionDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">'
    if subscription:
        # Emit each optional element only when set, in the service-required
        # element order.
        if subscription.lock_duration is not None:
            subscription_body += ''.join(
                ['<LockDuration>',
                 str(subscription.lock_duration),
                 '</LockDuration>'])

        if subscription.requires_session is not None:
            subscription_body += ''.join(
                ['<RequiresSession>',
                 str(subscription.requires_session).lower(),
                 '</RequiresSession>'])

        if subscription.default_message_time_to_live is not None:
            subscription_body += ''.join(
                ['<DefaultMessageTimeToLive>',
                 str(subscription.default_message_time_to_live),
                 '</DefaultMessageTimeToLive>'])

        if subscription.dead_lettering_on_message_expiration is not None:
            subscription_body += ''.join(
                ['<DeadLetteringOnMessageExpiration>',
                 str(subscription.dead_lettering_on_message_expiration).lower(),
                 '</DeadLetteringOnMessageExpiration>'])

        if subscription.dead_lettering_on_filter_evaluation_exceptions is not None:
            subscription_body += ''.join(
                ['<DeadLetteringOnFilterEvaluationExceptions>',
                 str(subscription.dead_lettering_on_filter_evaluation_exceptions).lower(),
                 '</DeadLetteringOnFilterEvaluationExceptions>'])

        if subscription.enable_batched_operations is not None:
            subscription_body += ''.join(
                ['<EnableBatchedOperations>',
                 str(subscription.enable_batched_operations).lower(),
                 '</EnableBatchedOperations>'])

        if subscription.max_delivery_count is not None:
            subscription_body += ''.join(
                ['<MaxDeliveryCount>',
                 str(subscription.max_delivery_count),
                 '</MaxDeliveryCount>'])

        if subscription.message_count is not None:
            subscription_body += ''.join(
                ['<MessageCount>',
                 str(subscription.message_count),
                 '</MessageCount>'])

    subscription_body += '</SubscriptionDescription>'
    return _create_entry(subscription_body)


def _convert_rule_to_xml(rule):
    '''
    Converts a rule object to xml to send.  The order of each field of rule
    in xml is very important so we can't simply call convert_class_to_xml.

    rule: the rule object to be converted.
    '''
    rule_body = '<RuleDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">'
    if rule:
        if rule.filter_type:
            rule_body += ''.join(
                ['<Filter i:type="',
                 xml_escape(rule.filter_type),
                 '">'])
            if rule.filter_type == 'CorrelationFilter':
                rule_body += ''.join(
                    ['<CorrelationId>',
                     xml_escape(rule.filter_expression),
                     '</CorrelationId>'])
            else:
                # SQL-based filters carry an expression plus the protocol's
                # fixed compatibility level.
                rule_body += ''.join(
                    ['<SqlExpression>',
                     xml_escape(rule.filter_expression),
                     '</SqlExpression>'])
                rule_body += '<CompatibilityLevel>20</CompatibilityLevel>'
            rule_body += '</Filter>'
        if rule.action_type:
            rule_body += ''.join(
                ['<Action i:type="',
                 xml_escape(rule.action_type),
                 '">'])
            if rule.action_type == 'SqlRuleAction':
                rule_body += ''.join(
                    ['<SqlExpression>',
                     xml_escape(rule.action_expression),
                     '</SqlExpression>'])
                rule_body += '<CompatibilityLevel>20</CompatibilityLevel>'
            rule_body += '</Action>'
    rule_body += '</RuleDescription>'
    return _create_entry(rule_body)


def _convert_topic_to_xml(topic):
    '''
    Converts a topic object to xml to send.  The order of each field of topic
    in xml is very important so we can't simply call convert_class_to_xml.
    topic: the topic object to be converted.
    '''
    topic_body = '<TopicDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">'
    if topic:
        # Emit each optional element only when set, in the service-required
        # element order.
        if topic.default_message_time_to_live is not None:
            topic_body += ''.join(
                ['<DefaultMessageTimeToLive>',
                 str(topic.default_message_time_to_live),
                 '</DefaultMessageTimeToLive>'])
        if topic.max_size_in_megabytes is not None:
            topic_body += ''.join(
                ['<MaxSizeInMegabytes>',
                 str(topic.max_size_in_megabytes),
                 '</MaxSizeInMegabytes>'])
        if topic.requires_duplicate_detection is not None:
            topic_body += ''.join(
                ['<RequiresDuplicateDetection>',
                 str(topic.requires_duplicate_detection).lower(),
                 '</RequiresDuplicateDetection>'])
        if topic.duplicate_detection_history_time_window is not None:
            topic_body += ''.join(
                ['<DuplicateDetectionHistoryTimeWindow>',
                 str(topic.duplicate_detection_history_time_window),
                 '</DuplicateDetectionHistoryTimeWindow>'])
        if topic.enable_batched_operations is not None:
            topic_body += ''.join(
                ['<EnableBatchedOperations>',
                 str(topic.enable_batched_operations).lower(),
                 '</EnableBatchedOperations>'])
        if topic.size_in_bytes is not None:
            topic_body += ''.join(
                ['<SizeInBytes>',
                 str(topic.size_in_bytes),
                 '</SizeInBytes>'])
    topic_body += '</TopicDescription>'
    return _create_entry(topic_body)


def _convert_queue_to_xml(queue):
    '''
    Converts a queue object to xml to send.  The order of each field of queue
    in xml is very important so we can't simply call convert_class_to_xml.

    queue: the queue object to be converted.
    '''
    queue_body = '<QueueDescription xmlns:i="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">'
    if queue:
        # NOTE(review): lock_duration uses a truthiness test while the
        # other fields test "is not None"; preserved as-is.
        if queue.lock_duration:
            queue_body += ''.join(
                ['<LockDuration>',
                 str(queue.lock_duration),
                 '</LockDuration>'])
        if queue.max_size_in_megabytes is not None:
            queue_body += ''.join(
                ['<MaxSizeInMegabytes>',
                 str(queue.max_size_in_megabytes),
                 '</MaxSizeInMegabytes>'])
        if queue.requires_duplicate_detection is not None:
            queue_body += ''.join(
                ['<RequiresDuplicateDetection>',
                 str(queue.requires_duplicate_detection).lower(),
                 '</RequiresDuplicateDetection>'])
        if queue.requires_session is not None:
            queue_body += ''.join(
                ['<RequiresSession>',
                 str(queue.requires_session).lower(),
                 '</RequiresSession>'])
        if queue.default_message_time_to_live is not None:
            queue_body += ''.join(
                ['<DefaultMessageTimeToLive>',
                 str(queue.default_message_time_to_live),
                 '</DefaultMessageTimeToLive>'])
        if queue.dead_lettering_on_message_expiration is not None:
            queue_body += ''.join(
                ['<DeadLetteringOnMessageExpiration>',
                 str(queue.dead_lettering_on_message_expiration).lower(),
                 '</DeadLetteringOnMessageExpiration>'])
        if queue.duplicate_detection_history_time_window is not None:
            queue_body += ''.join(
                ['<DuplicateDetectionHistoryTimeWindow>',
                 str(queue.duplicate_detection_history_time_window),
                 '</DuplicateDetectionHistoryTimeWindow>'])
        if queue.max_delivery_count is not None:
            queue_body += ''.join(
                ['<MaxDeliveryCount>',
                 str(queue.max_delivery_count),
                 '</MaxDeliveryCount>'])
        if queue.enable_batched_operations is not None:
            queue_body += ''.join(
                ['<EnableBatchedOperations>',
                 str(queue.enable_batched_operations).lower(),
                 '</EnableBatchedOperations>'])
        if queue.size_in_bytes is not None:
            queue_body += ''.join(
                ['<SizeInBytes>',
                 str(queue.size_in_bytes),
                 '</SizeInBytes>'])
        if queue.message_count is not None:
            queue_body += ''.join(
                ['<MessageCount>',
                 str(queue.message_count),
                 '</MessageCount>'])

    queue_body += '</QueueDescription>'
    return _create_entry(queue_body)


def _service_bus_error_handler(http_error):
    ''' Simple error handler for service bus service. '''
    return _general_error_handler(http_error)

from azure.servicebus.servicebusservice import ServiceBusService


================================================
FILE: OSPatching/azure/servicebus/servicebusservice.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
import datetime
import os
import time

from azure import (
    WindowsAzureError,
    SERVICE_BUS_HOST_BASE,
    _convert_response_to_feeds,
    _dont_fail_not_exist,
    _dont_fail_on_exist,
    _encode_base64,
    _get_request_body,
    _get_request_body_bytes_only,
    _int_or_none,
    _sign_string,
    _str,
    _unicode_type,
    _update_request_uri_query,
    url_quote,
    url_unquote,
    _validate_not_none,
    )
from azure.http import (
    HTTPError,
    HTTPRequest,
    )
from azure.http.httpclient import _HTTPClient
from azure.servicebus import (
    AZURE_SERVICEBUS_NAMESPACE,
    AZURE_SERVICEBUS_ACCESS_KEY,
    AZURE_SERVICEBUS_ISSUER,
    _convert_topic_to_xml,
    _convert_response_to_topic,
    _convert_queue_to_xml,
    _convert_response_to_queue,
    _convert_subscription_to_xml,
    _convert_response_to_subscription,
    _convert_rule_to_xml,
    _convert_response_to_rule,
    _convert_xml_to_queue,
    _convert_xml_to_topic,
    _convert_xml_to_subscription,
    _convert_xml_to_rule,
    _create_message,
    _service_bus_error_handler,
    )


class ServiceBusService(object):

    def __init__(self, service_namespace=None, account_key=None, issuer=None,
                 x_ms_version='2011-06-01', host_base=SERVICE_BUS_HOST_BASE,
                 shared_access_key_name=None, shared_access_key_value=None,
                 authentication=None):
        '''
        Initializes the service bus service for a namespace with the specified
        authentication settings (SAS or ACS).

        service_namespace: Service bus namespace, required for all operations.
            If None, the value is set to the AZURE_SERVICEBUS_NAMESPACE env
            variable.
        account_key: ACS authentication account key. If None, the value is set
            to the AZURE_SERVICEBUS_ACCESS_KEY env variable.
            Note that if both SAS and ACS settings are specified, SAS is used.
        issuer: ACS authentication issuer. If None, the value is set to the
            AZURE_SERVICEBUS_ISSUER env variable.
            Note that if both SAS and ACS settings are specified, SAS is used.
        x_ms_version: Unused. Kept for backwards compatibility.
        host_base: Optional. Live host base url. Defaults to Azure url.
            Override this for on-premise.
        shared_access_key_name: SAS authentication key name.
            Note that if both SAS and ACS settings are specified, SAS is used.
        shared_access_key_value: SAS authentication key value.
            Note that if both SAS and ACS settings are specified, SAS is used.
        authentication: Instance of authentication class. If this is
            specified, then ACS and SAS parameters are ignored.
        '''
        self.requestid = None
        self.service_namespace = service_namespace
        self.host_base = host_base

        # Fall back to the environment when the namespace is not supplied.
        if not self.service_namespace:
            self.service_namespace = os.environ.get(AZURE_SERVICEBUS_NAMESPACE)

        if not self.service_namespace:
            raise WindowsAzureError('You need to provide servicebus namespace')

        if authentication:
            self.authentication = authentication
        else:
            if not account_key:
                account_key = os.environ.get(AZURE_SERVICEBUS_ACCESS_KEY)

            if not issuer:
                issuer = os.environ.get(AZURE_SERVICEBUS_ISSUER)

            # SAS takes precedence over ACS when both are available.
            # NOTE(review): ServiceBusSASAuthentication and
            # ServiceBusWrapTokenAuthentication are presumably defined later
            # in this module — confirm, they are not in the import list above.
            if shared_access_key_name and shared_access_key_value:
                self.authentication = ServiceBusSASAuthentication(
                    shared_access_key_name,
                    shared_access_key_value)
            elif account_key and issuer:
                self.authentication = ServiceBusWrapTokenAuthentication(
                    account_key,
                    issuer)
            else:
                raise WindowsAzureError(
                    'You need to provide servicebus access key and Issuer OR shared access key and value')

        self._httpclient = _HTTPClient(service_instance=self)
        self._filter = self._httpclient.perform_request

    # Backwards compatibility:
    # account_key and issuer used to be stored on the service class, they are
    # now stored on the authentication class.
    @property
    def account_key(self):
        return self.authentication.account_key

    @account_key.setter
    def account_key(self, value):
        self.authentication.account_key = value

    @property
    def issuer(self):
        return self.authentication.issuer

    @issuer.setter
    def issuer(self, value):
        self.authentication.issuer = value

    def with_filter(self, filter):
        '''
        Returns a new service which will process requests with the specified
        filter.  Filtering operations can include logging, automatic retrying,
        etc...
        The filter is a lambda which receives the HTTPRequest and another
        lambda.  The filter can perform any pre-processing on the request,
        pass it off to the next lambda, and then perform any post-processing
        on the response.
        '''
        # Clone the service, then chain the new filter in front of the
        # current one.
        res = ServiceBusService(
            service_namespace=self.service_namespace,
            authentication=self.authentication)

        old_filter = self._filter

        def new_filter(request):
            return filter(request, old_filter)

        res._filter = new_filter
        return res

    def set_proxy(self, host, port, user=None, password=None):
        '''
        Sets the proxy server host and port for the HTTP CONNECT Tunnelling.

        host: Address of the proxy. Ex: '192.168.0.100'
        port: Port of the proxy. Ex: 6000
        user: User for proxy authorization.
        password: Password for proxy authorization.
        '''
        self._httpclient.set_proxy(host, port, user, password)

    def create_queue(self, queue_name, queue=None, fail_on_exist=False):
        '''
        Creates a new queue. Once created, this queue's resource manifest is
        immutable.

        queue_name: Name of the queue to create.
        queue: Queue object to create.
        fail_on_exist:
            Specify whether to throw an exception when the queue exists.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + ''
        request.body = _get_request_body(_convert_queue_to_xml(queue))
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_on_exist:
            # Best-effort create: swallow "already exists" errors.
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_on_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def delete_queue(self, queue_name, fail_not_exist=False):
        '''
        Deletes an existing queue. This operation will also remove all
        associated state including messages in the queue.

        queue_name: Name of the queue to delete.
        fail_not_exist:
            Specify whether to throw an exception if the queue doesn't exist.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_not_exist:
            # Best-effort delete: swallow "not found" errors.
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_not_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def get_queue(self, queue_name):
        '''
        Retrieves an existing queue.

        queue_name: Name of the queue.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_queue(response)

    def list_queues(self):
        '''
        Enumerates the queues in the service namespace.
        '''
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/$Resources/Queues'
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_feeds(response, _convert_xml_to_queue)

    def create_topic(self, topic_name, topic=None, fail_on_exist=False):
        '''
        Creates a new topic. Once created, this topic resource manifest is
        immutable.

        topic_name: Name of the topic to create.
        topic: Topic object to create.
        fail_on_exist:
            Specify whether to throw an exception when the topic exists.
        '''
        _validate_not_none('topic_name', topic_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + ''
        request.body = _get_request_body(_convert_topic_to_xml(topic))
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_on_exist:
            # Best-effort create: swallow "already exists" errors.
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_on_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def delete_topic(self, topic_name, fail_not_exist=False):
        '''
        Deletes an existing topic. This operation will also remove all
        associated state including associated subscriptions.

        topic_name: Name of the topic to delete.
        fail_not_exist:
            Specify whether throw exception when topic doesn't exist.
        '''
        _validate_not_none('topic_name', topic_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_not_exist:
            # Best-effort delete: swallow "not found" errors.
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_not_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def get_topic(self, topic_name):
        '''
        Retrieves the description for the specified topic.

        topic_name: Name of the topic.
        '''
        _validate_not_none('topic_name', topic_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_topic(response)

    def list_topics(self):
        '''
        Retrieves the topics in the service namespace.
        '''
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/$Resources/Topics'
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_feeds(response, _convert_xml_to_topic)

    def create_rule(self, topic_name, subscription_name, rule_name, rule=None,
                    fail_on_exist=False):
        '''
        Creates a new rule. Once created, this rule's resource manifest is
        immutable.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        rule_name: Name of the rule.
        fail_on_exist:
            Specify whether to throw an exception when the rule exists.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        _validate_not_none('rule_name', rule_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + '/subscriptions/' + \
            _str(subscription_name) + \
            '/rules/' + _str(rule_name) + ''
        request.body = _get_request_body(_convert_rule_to_xml(rule))
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_on_exist:
            # Best-effort create: swallow "already exists" errors.
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_on_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def delete_rule(self, topic_name, subscription_name, rule_name,
                    fail_not_exist=False):
        '''
        Deletes an existing rule.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        rule_name:
            Name of the rule to delete.  DEFAULT_RULE_NAME=$Default.
            Use DEFAULT_RULE_NAME to delete default rule for the subscription.
        fail_not_exist:
            Specify whether throw exception when rule doesn't exist.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        _validate_not_none('rule_name', rule_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + '/subscriptions/' + \
            _str(subscription_name) + \
            '/rules/' + _str(rule_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_not_exist:
            # Best-effort delete: swallow "not found" errors.
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_not_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def get_rule(self, topic_name, subscription_name, rule_name):
        '''
        Retrieves the description for the specified rule.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        rule_name: Name of the rule.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        _validate_not_none('rule_name', rule_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + '/subscriptions/' + \
            _str(subscription_name) + \
            '/rules/' + _str(rule_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_rule(response)

    def list_rules(self, topic_name, subscription_name):
        '''
        Retrieves the rules that exist under the specified subscription.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(topic_name) + '/subscriptions/' + \
            _str(subscription_name) + '/rules/'
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_feeds(response, _convert_xml_to_rule)

    def create_subscription(self, topic_name, subscription_name,
                            subscription=None, fail_on_exist=False):
        '''
        Creates a new subscription. Once created, this subscription resource
        manifest is immutable.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        fail_on_exist:
            Specify whether throw exception when subscription exists.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(topic_name) + '/subscriptions/' + _str(subscription_name) + ''
        request.body = _get_request_body(
            _convert_subscription_to_xml(subscription))
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_on_exist:
            # Best-effort create: swallow "already exists" errors.
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_on_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def delete_subscription(self, topic_name, subscription_name,
                            fail_not_exist=False):
        '''
        Deletes an existing subscription.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription to delete.
        fail_not_exist:
            Specify whether to throw an exception when the subscription
            doesn't exist.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + \
            _str(topic_name) + '/subscriptions/' + _str(subscription_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        if not fail_not_exist:
            # Best-effort delete: swallow "not found" errors.
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_not_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def get_subscription(self, topic_name, subscription_name):
        '''
        Gets an existing subscription.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('subscription_name', subscription_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(topic_name) + '/subscriptions/' + _str(subscription_name) + ''
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_subscription(response)

    def list_subscriptions(self, topic_name):
        '''
        Retrieves the subscriptions in the specified topic.

        topic_name: Name of the topic.
        '''
        _validate_not_none('topic_name', topic_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + '/subscriptions/'
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        response = self._perform_request(request)

        return _convert_response_to_feeds(response,
                                          _convert_xml_to_subscription)

    def send_topic_message(self, topic_name, message=None):
        '''
        Enqueues a message into the specified topic. The limit to the number
        of messages which may be present in the topic is governed by the
        message size in MaxTopicSizeInBytes. If this message causes the topic
        to exceed its quota, a quota exceeded error is returned and the
        message will be rejected.

        topic_name: Name of the topic.
        message: Message object containing message body and properties.
        '''
        _validate_not_none('topic_name', topic_name)
        _validate_not_none('message', message)
        request = HTTPRequest()
        request.method = 'POST'
        request.host = self._get_host()
        request.path = '/' + _str(topic_name) + '/messages'
        # Broker/custom message properties travel as HTTP headers.
        request.headers = message.add_headers(request)
        request.body = _get_request_body_bytes_only('message.body',
                                                    message.body)
        request.path, request.query = _update_request_uri_query(request)
        request.headers = self._update_service_bus_header(request)
        self._perform_request(request)

    def peek_lock_subscription_message(self, topic_name, subscription_name,
                                       timeout='60'):
        '''
        This operation is used to atomically retrieve and lock a message for
        processing. The message is guaranteed not to be delivered to other
        receivers during the lock duration period specified in buffer
        description. Once the lock expires, the message will be available to
        other receivers (on the same subscription only) during the lock
        duration period specified in the topic description. Once the lock
        expires, the message will be available to other receivers. In order
        to complete processing of the message, the receiver should issue a
        delete command with the lock ID received from this operation. To
        abandon processing of the message and unlock it for other receivers,
        an Unlock Message command should be issued, or the lock duration
        period can expire.

        topic_name: Name of the topic.
        subscription_name: Name of the subscription.
        timeout: Optional. The timeout parameter is expressed in seconds.
''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/' + \ _str(topic_name) + '/subscriptions/' + \ _str(subscription_name) + '/messages/head' request.query = [('timeout', _int_or_none(timeout))] request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _create_message(response, self) def unlock_subscription_message(self, topic_name, subscription_name, sequence_number, lock_token): ''' Unlock a message for processing by other receivers on a given subscription. This operation deletes the lock object, causing the message to be unlocked. A message must have first been locked by a receiver before this operation is called. topic_name: Name of the topic. subscription_name: Name of the subscription. sequence_number: The sequence number of the message to be unlocked as returned in BrokerProperties['SequenceNumber'] by the Peek Message operation. lock_token: The ID of the lock as returned by the Peek Message operation in BrokerProperties['LockToken'] ''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) _validate_not_none('sequence_number', sequence_number) _validate_not_none('lock_token', lock_token) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(topic_name) + \ '/subscriptions/' + str(subscription_name) + \ '/messages/' + _str(sequence_number) + \ '/' + _str(lock_token) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) self._perform_request(request) def read_delete_subscription_message(self, topic_name, subscription_name, timeout='60'): ''' Read and delete a message from a subscription as an atomic operation. 
This operation should be used when a best-effort guarantee is sufficient for an application; that is, using this operation it is possible for messages to be lost if processing fails. topic_name: Name of the topic. subscription_name: Name of the subscription. timeout: Optional. The timeout parameter is expressed in seconds. ''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(topic_name) + \ '/subscriptions/' + _str(subscription_name) + \ '/messages/head' request.query = [('timeout', _int_or_none(timeout))] request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _create_message(response, self) def delete_subscription_message(self, topic_name, subscription_name, sequence_number, lock_token): ''' Completes processing on a locked message and delete it from the subscription. This operation should only be called after processing a previously locked message is successful to maintain At-Least-Once delivery assurances. topic_name: Name of the topic. subscription_name: Name of the subscription. sequence_number: The sequence number of the message to be deleted as returned in BrokerProperties['SequenceNumber'] by the Peek Message operation. 
lock_token: The ID of the lock as returned by the Peek Message operation in BrokerProperties['LockToken'] ''' _validate_not_none('topic_name', topic_name) _validate_not_none('subscription_name', subscription_name) _validate_not_none('sequence_number', sequence_number) _validate_not_none('lock_token', lock_token) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(topic_name) + \ '/subscriptions/' + _str(subscription_name) + \ '/messages/' + _str(sequence_number) + \ '/' + _str(lock_token) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) self._perform_request(request) def send_queue_message(self, queue_name, message=None): ''' Sends a message into the specified queue. The limit to the number of messages which may be present in the topic is governed by the message size the MaxTopicSizeInMegaBytes. If this message will cause the queue to exceed its quota, a quota exceeded error is returned and the message will be rejected. queue_name: Name of the queue. message: Message object containing message body and properties. ''' _validate_not_none('queue_name', queue_name) _validate_not_none('message', message) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/' + _str(queue_name) + '/messages' request.headers = message.add_headers(request) request.body = _get_request_body_bytes_only('message.body', message.body) request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) self._perform_request(request) def peek_lock_queue_message(self, queue_name, timeout='60'): ''' Automically retrieves and locks a message from a queue for processing. The message is guaranteed not to be delivered to other receivers (on the same subscription only) during the lock duration period specified in the queue description. 
Once the lock expires, the message will be available to other receivers. In order to complete processing of the message, the receiver should issue a delete command with the lock ID received from this operation. To abandon processing of the message and unlock it for other receivers, an Unlock Message command should be issued, or the lock duration period can expire. queue_name: Name of the queue. timeout: Optional. The timeout parameter is expressed in seconds. ''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/' + _str(queue_name) + '/messages/head' request.query = [('timeout', _int_or_none(timeout))] request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _create_message(response, self) def unlock_queue_message(self, queue_name, sequence_number, lock_token): ''' Unlocks a message for processing by other receivers on a given subscription. This operation deletes the lock object, causing the message to be unlocked. A message must have first been locked by a receiver before this operation is called. queue_name: Name of the queue. sequence_number: The sequence number of the message to be unlocked as returned in BrokerProperties['SequenceNumber'] by the Peek Message operation. 
lock_token: The ID of the lock as returned by the Peek Message operation in BrokerProperties['LockToken'] ''' _validate_not_none('queue_name', queue_name) _validate_not_none('sequence_number', sequence_number) _validate_not_none('lock_token', lock_token) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(queue_name) + \ '/messages/' + _str(sequence_number) + \ '/' + _str(lock_token) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) self._perform_request(request) def read_delete_queue_message(self, queue_name, timeout='60'): ''' Reads and deletes a message from a queue as an atomic operation. This operation should be used when a best-effort guarantee is sufficient for an application; that is, using this operation it is possible for messages to be lost if processing fails. queue_name: Name of the queue. timeout: Optional. The timeout parameter is expressed in seconds. ''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(queue_name) + '/messages/head' request.query = [('timeout', _int_or_none(timeout))] request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) response = self._perform_request(request) return _create_message(response, self) def delete_queue_message(self, queue_name, sequence_number, lock_token): ''' Completes processing on a locked message and delete it from the queue. This operation should only be called after processing a previously locked message is successful to maintain At-Least-Once delivery assurances. queue_name: Name of the queue. sequence_number: The sequence number of the message to be deleted as returned in BrokerProperties['SequenceNumber'] by the Peek Message operation. 
lock_token: The ID of the lock as returned by the Peek Message operation in BrokerProperties['LockToken'] ''' _validate_not_none('queue_name', queue_name) _validate_not_none('sequence_number', sequence_number) _validate_not_none('lock_token', lock_token) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(queue_name) + \ '/messages/' + _str(sequence_number) + \ '/' + _str(lock_token) + '' request.path, request.query = _update_request_uri_query(request) request.headers = self._update_service_bus_header(request) self._perform_request(request) def receive_queue_message(self, queue_name, peek_lock=True, timeout=60): ''' Receive a message from a queue for processing. queue_name: Name of the queue. peek_lock: Optional. True to retrieve and lock the message. False to read and delete the message. Default is True (lock). timeout: Optional. The timeout parameter is expressed in seconds. ''' if peek_lock: return self.peek_lock_queue_message(queue_name, timeout) else: return self.read_delete_queue_message(queue_name, timeout) def receive_subscription_message(self, topic_name, subscription_name, peek_lock=True, timeout=60): ''' Receive a message from a subscription for processing. topic_name: Name of the topic. subscription_name: Name of the subscription. peek_lock: Optional. True to retrieve and lock the message. False to read and delete the message. Default is True (lock). timeout: Optional. The timeout parameter is expressed in seconds. 
''' if peek_lock: return self.peek_lock_subscription_message(topic_name, subscription_name, timeout) else: return self.read_delete_subscription_message(topic_name, subscription_name, timeout) def _get_host(self): return self.service_namespace + self.host_base def _perform_request(self, request): try: resp = self._filter(request) except HTTPError as ex: return _service_bus_error_handler(ex) return resp def _update_service_bus_header(self, request): ''' Add additional headers for service bus. ''' if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']: request.headers.append(('Content-Length', str(len(request.body)))) # if it is not GET or HEAD request, must set content-type. if not request.method in ['GET', 'HEAD']: for name, _ in request.headers: if 'content-type' == name.lower(): break else: request.headers.append( ('Content-Type', 'application/atom+xml;type=entry;charset=utf-8')) # Adds authorization header for authentication. self.authentication.sign_request(request, self._httpclient) return request.headers # Token cache for Authentication # Shared by the different instances of ServiceBusWrapTokenAuthentication _tokens = {} class ServiceBusWrapTokenAuthentication: def __init__(self, account_key, issuer): self.account_key = account_key self.issuer = issuer def sign_request(self, request, httpclient): request.headers.append( ('Authorization', self._get_authorization(request, httpclient))) def _get_authorization(self, request, httpclient): ''' return the signed string with token. ''' return 'WRAP access_token="' + \ self._get_token(request.host, request.path, httpclient) + '"' def _token_is_expired(self, token): ''' Check if token expires or not. ''' time_pos_begin = token.find('ExpiresOn=') + len('ExpiresOn=') time_pos_end = token.find('&', time_pos_begin) token_expire_time = int(token[time_pos_begin:time_pos_end]) time_now = time.mktime(time.localtime()) # Adding 30 seconds so the token wouldn't be expired when we send the # token to server. 
return (token_expire_time - time_now) < 30 def _get_token(self, host, path, httpclient): ''' Returns token for the request. host: the service bus service request. path: the service bus service request. ''' wrap_scope = 'http://' + host + path + self.issuer + self.account_key # Check whether has unexpired cache, return cached token if it is still # usable. if wrap_scope in _tokens: token = _tokens[wrap_scope] if not self._token_is_expired(token): return token # get token from accessconstrol server request = HTTPRequest() request.protocol_override = 'https' request.host = host.replace('.servicebus.', '-sb.accesscontrol.') request.method = 'POST' request.path = '/WRAPv0.9' request.body = ('wrap_name=' + url_quote(self.issuer) + '&wrap_password=' + url_quote(self.account_key) + '&wrap_scope=' + url_quote('http://' + host + path)).encode('utf-8') request.headers.append(('Content-Length', str(len(request.body)))) resp = httpclient.perform_request(request) token = resp.body.decode('utf-8') token = url_unquote(token[token.find('=') + 1:token.rfind('&')]) _tokens[wrap_scope] = token return token class ServiceBusSASAuthentication: def __init__(self, key_name, key_value): self.key_name = key_name self.key_value = key_value def sign_request(self, request, httpclient): request.headers.append( ('Authorization', self._get_authorization(request, httpclient))) def _get_authorization(self, request, httpclient): uri = httpclient.get_uri(request) uri = url_quote(uri, '').lower() expiry = str(self._get_expiry()) to_sign = uri + '\n' + expiry signature = url_quote(_sign_string(self.key_value, to_sign, False), '') auth_format = 'SharedAccessSignature sig={0}&se={1}&skn={2}&sr={3}' auth = auth_format.format(signature, expiry, self.key_name, uri) return auth def _get_expiry(self): '''Returns the UTC datetime, in seconds since Epoch, when this signed request expires (5 minutes from now).''' return int(round(time.time() + 300)) ================================================ FILE: 
OSPatching/azure/servicemanagement/__init__.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- from xml.dom import minidom from azure import ( WindowsAzureData, _Base64String, _create_entry, _dict_of, _encode_base64, _general_error_handler, _get_children_from_path, _get_first_child_node_value, _list_of, _scalar_list_of, _str, _xml_attribute, ) #----------------------------------------------------------------------------- # Constants for Azure app environment settings. AZURE_MANAGEMENT_CERTFILE = 'AZURE_MANAGEMENT_CERTFILE' AZURE_MANAGEMENT_SUBSCRIPTIONID = 'AZURE_MANAGEMENT_SUBSCRIPTIONID' # x-ms-version for service management. 
X_MS_VERSION = '2013-06-01'

#-----------------------------------------------------------------------------
# Data classes
#
# NOTE(review): the plural classes below (StorageServices, Locations, ...)
# are thin typed collections: each wraps a _list_of(<element class>) and
# forwards iteration, len() and indexing, so the XML deserializer knows
# which element type to construct for each child node.


class StorageServices(WindowsAzureData):

    def __init__(self):
        self.storage_services = _list_of(StorageService)

    def __iter__(self):
        return iter(self.storage_services)

    def __len__(self):
        return len(self.storage_services)

    def __getitem__(self, index):
        return self.storage_services[index]


class StorageService(WindowsAzureData):

    def __init__(self):
        self.url = ''
        self.service_name = ''
        self.storage_service_properties = StorageAccountProperties()
        self.storage_service_keys = StorageServiceKeys()
        self.extended_properties = _dict_of(
            'ExtendedProperty', 'Name', 'Value')
        self.capabilities = _scalar_list_of(str, 'Capability')


class StorageAccountProperties(WindowsAzureData):

    def __init__(self):
        self.description = u''
        self.affinity_group = u''
        self.location = u''
        self.label = _Base64String()
        self.status = u''
        self.endpoints = _scalar_list_of(str, 'Endpoint')
        self.geo_replication_enabled = False
        self.geo_primary_region = u''
        self.status_of_primary = u''
        self.geo_secondary_region = u''
        self.status_of_secondary = u''
        self.last_geo_failover_time = u''
        self.creation_time = u''


class StorageServiceKeys(WindowsAzureData):

    def __init__(self):
        self.primary = u''
        self.secondary = u''


class Locations(WindowsAzureData):

    def __init__(self):
        self.locations = _list_of(Location)

    def __iter__(self):
        return iter(self.locations)

    def __len__(self):
        return len(self.locations)

    def __getitem__(self, index):
        return self.locations[index]


class Location(WindowsAzureData):

    def __init__(self):
        self.name = u''
        self.display_name = u''
        self.available_services = _scalar_list_of(str, 'AvailableService')


class AffinityGroup(WindowsAzureData):

    def __init__(self):
        self.name = ''
        self.label = _Base64String()
        self.description = u''
        self.location = u''
        self.hosted_services = HostedServices()
        self.storage_services = StorageServices()
        self.capabilities = _scalar_list_of(str, 'Capability')


class AffinityGroups(WindowsAzureData):

    def __init__(self):
        self.affinity_groups = _list_of(AffinityGroup)

    def __iter__(self):
        return iter(self.affinity_groups)

    def __len__(self):
        return len(self.affinity_groups)

    def __getitem__(self, index):
        return self.affinity_groups[index]


class HostedServices(WindowsAzureData):

    def __init__(self):
        self.hosted_services = _list_of(HostedService)

    def __iter__(self):
        return iter(self.hosted_services)

    def __len__(self):
        return len(self.hosted_services)

    def __getitem__(self, index):
        return self.hosted_services[index]


class HostedService(WindowsAzureData):

    def __init__(self):
        self.url = u''
        self.service_name = u''
        self.hosted_service_properties = HostedServiceProperties()
        self.deployments = Deployments()


class HostedServiceProperties(WindowsAzureData):

    def __init__(self):
        self.description = u''
        self.location = u''
        self.affinity_group = u''
        self.label = _Base64String()
        self.status = u''
        self.date_created = u''
        self.date_last_modified = u''
        self.extended_properties = _dict_of(
            'ExtendedProperty', 'Name', 'Value')


class VirtualNetworkSites(WindowsAzureData):

    def __init__(self):
        self.virtual_network_sites = _list_of(VirtualNetworkSite)

    def __iter__(self):
        return iter(self.virtual_network_sites)

    def __len__(self):
        return len(self.virtual_network_sites)

    def __getitem__(self, index):
        return self.virtual_network_sites[index]


class VirtualNetworkSite(WindowsAzureData):

    def __init__(self):
        self.name = u''
        self.id = u''
        self.affinity_group = u''
        self.subnets = Subnets()


class Subnets(WindowsAzureData):

    def __init__(self):
        self.subnets = _list_of(Subnet)

    def __iter__(self):
        return iter(self.subnets)

    def __len__(self):
        return len(self.subnets)

    def __getitem__(self, index):
        return self.subnets[index]


class Subnet(WindowsAzureData):

    def __init__(self):
        self.name = u''
        self.address_prefix = u''


class Deployments(WindowsAzureData):

    def __init__(self):
        self.deployments = _list_of(Deployment)

    def __iter__(self):
        return iter(self.deployments)

    def __len__(self):
        return len(self.deployments)

    def __getitem__(self, index):
        return self.deployments[index]


class Deployment(WindowsAzureData):

    def __init__(self):
        self.name = u''
        self.deployment_slot = u''
        self.private_id = u''
        self.status = u''
        self.label = _Base64String()
        self.url = u''
        self.configuration = _Base64String()
        self.role_instance_list = RoleInstanceList()
        self.upgrade_status = UpgradeStatus()
        self.upgrade_domain_count = u''
        self.role_list = RoleList()
        self.sdk_version = u''
        self.input_endpoint_list = InputEndpoints()
        self.locked = False
        self.rollback_allowed = False
        self.persistent_vm_downtime_info = PersistentVMDowntimeInfo()
        self.created_time = u''
        self.virtual_network_name = u''
        self.last_modified_time = u''
        self.extended_properties = _dict_of(
            'ExtendedProperty', 'Name', 'Value')


class RoleInstanceList(WindowsAzureData):

    def __init__(self):
        self.role_instances = _list_of(RoleInstance)

    def __iter__(self):
        return iter(self.role_instances)

    def __len__(self):
        return len(self.role_instances)

    def __getitem__(self, index):
        return self.role_instances[index]


class RoleInstance(WindowsAzureData):

    def __init__(self):
        self.role_name = u''
        self.instance_name = u''
        self.instance_status = u''
        self.instance_upgrade_domain = 0
        self.instance_fault_domain = 0
        self.instance_size = u''
        self.instance_state_details = u''
        self.instance_error_code = u''
        self.ip_address = u''
        self.instance_endpoints = InstanceEndpoints()
        self.power_state = u''
        self.fqdn = u''
        self.host_name = u''


class InstanceEndpoints(WindowsAzureData):

    def __init__(self):
        self.instance_endpoints = _list_of(InstanceEndpoint)

    def __iter__(self):
        return iter(self.instance_endpoints)

    def __len__(self):
        return len(self.instance_endpoints)

    def __getitem__(self, index):
        return self.instance_endpoints[index]


class InstanceEndpoint(WindowsAzureData):

    def __init__(self):
        self.name = u''
        self.vip = u''
        self.public_port = u''
        self.local_port = u''
        self.protocol = u''


class UpgradeStatus(WindowsAzureData):

    def __init__(self):
        self.upgrade_type = u''
        self.current_upgrade_domain_state = u''
        self.current_upgrade_domain = u''


class InputEndpoints(WindowsAzureData):

    def __init__(self):
        self.input_endpoints = _list_of(InputEndpoint)

    def __iter__(self):
        return iter(self.input_endpoints)

    def __len__(self):
        return len(self.input_endpoints)

    def __getitem__(self, index):
        return self.input_endpoints[index]


class InputEndpoint(WindowsAzureData):

    def __init__(self):
        self.role_name = u''
        self.vip = u''
        self.port = u''


class RoleList(WindowsAzureData):

    def __init__(self):
        self.roles = _list_of(Role)

    def __iter__(self):
        return iter(self.roles)

    def __len__(self):
        return len(self.roles)

    def __getitem__(self, index):
        return self.roles[index]


class Role(WindowsAzureData):

    def __init__(self):
        self.role_name = u''
        self.role_type = u''
        self.os_version = u''
        self.configuration_sets = ConfigurationSets()
        self.availability_set_name = u''
        self.data_virtual_hard_disks = DataVirtualHardDisks()
        self.os_virtual_hard_disk = OSVirtualHardDisk()
        self.role_size = u''
        self.default_win_rm_certificate_thumbprint = u''


class PersistentVMDowntimeInfo(WindowsAzureData):

    def __init__(self):
        self.start_time = u''
        self.end_time = u''
        self.status = u''


class Certificates(WindowsAzureData):

    def __init__(self):
        self.certificates = _list_of(Certificate)

    def __iter__(self):
        return iter(self.certificates)

    def __len__(self):
        return len(self.certificates)

    def __getitem__(self, index):
        return self.certificates[index]


class Certificate(WindowsAzureData):

    def __init__(self):
        self.certificate_url = u''
        self.thumbprint = u''
        self.thumbprint_algorithm = u''
        self.data = u''


class OperationError(WindowsAzureData):

    def __init__(self):
        self.code = u''
        self.message = u''


class Operation(WindowsAzureData):

    def __init__(self):
        self.id = u''
        self.status = u''
        self.http_status_code = u''
        self.error = OperationError()


class OperatingSystem(WindowsAzureData):

    def __init__(self):
        self.version = u''
        self.label = _Base64String()
        self.is_default = True
        self.is_active = True
        self.family = 0
        self.family_label = _Base64String()


class OperatingSystems(WindowsAzureData):

    def __init__(self):
        self.operating_systems = _list_of(OperatingSystem)

    def __iter__(self):
        return iter(self.operating_systems)

    def __len__(self):
        return len(self.operating_systems)

    def __getitem__(self, index):
        return self.operating_systems[index]


class OperatingSystemFamily(WindowsAzureData):

    def __init__(self):
        self.name = u''
        self.label = _Base64String()
        self.operating_systems = OperatingSystems()


class OperatingSystemFamilies(WindowsAzureData):

    def __init__(self):
        self.operating_system_families = _list_of(OperatingSystemFamily)

    def __iter__(self):
        return iter(self.operating_system_families)

    def __len__(self):
        return len(self.operating_system_families)

    def __getitem__(self, index):
        return self.operating_system_families[index]


class Subscription(WindowsAzureData):

    def __init__(self):
        self.subscription_id = u''
        self.subscription_name = u''
        self.subscription_status = u''
        self.account_admin_live_email_id = u''
        self.service_admin_live_email_id = u''
        self.max_core_count = 0
        self.max_storage_accounts = 0
        self.max_hosted_services = 0
        self.current_core_count = 0
        self.current_hosted_services = 0
        self.current_storage_accounts = 0
        self.max_virtual_network_sites = 0
        self.max_local_network_sites = 0
        self.max_dns_servers = 0


class AvailabilityResponse(WindowsAzureData):

    def __init__(self):
        self.result = False


class SubscriptionCertificates(WindowsAzureData):

    def __init__(self):
        self.subscription_certificates = _list_of(SubscriptionCertificate)

    def __iter__(self):
        return iter(self.subscription_certificates)

    def __len__(self):
        return len(self.subscription_certificates)

    def __getitem__(self, index):
        return self.subscription_certificates[index]


class SubscriptionCertificate(WindowsAzureData):

    def __init__(self):
        self.subscription_certificate_public_key = u''
        self.subscription_certificate_thumbprint = u''
        self.subscription_certificate_data = u''
        self.created = u''


class Images(WindowsAzureData):

    def __init__(self):
        self.images = _list_of(OSImage)

    def __iter__(self):
        return iter(self.images)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.images[index]


class OSImage(WindowsAzureData):

    def __init__(self):
        self.affinity_group = u''
        self.category = u''
        self.location = u''
        self.logical_size_in_gb = 0
        self.label = u''
        self.media_link = u''
        self.name = u''
        self.os = u''
        self.eula = u''
        self.description = u''


class Disks(WindowsAzureData):

    def __init__(self):
        self.disks = _list_of(Disk)

    def __iter__(self):
        return iter(self.disks)

    def __len__(self):
        return len(self.disks)

    def __getitem__(self, index):
        return self.disks[index]


class Disk(WindowsAzureData):

    def __init__(self):
        self.affinity_group = u''
        self.attached_to = AttachedTo()
        self.has_operating_system = u''
        self.is_corrupted = u''
        self.location = u''
        self.logical_disk_size_in_gb = 0
        self.label = u''
        self.media_link = u''
        self.name = u''
        self.os = u''
        self.source_image_name = u''


class AttachedTo(WindowsAzureData):

    def __init__(self):
        self.hosted_service_name = u''
        self.deployment_name = u''
        self.role_name = u''


class PersistentVMRole(WindowsAzureData):

    def __init__(self):
        self.role_name = u''
        self.role_type = u''
        self.os_version = u''  # undocumented
        self.configuration_sets = ConfigurationSets()
        self.availability_set_name = u''
        self.data_virtual_hard_disks = DataVirtualHardDisks()
        self.os_virtual_hard_disk = OSVirtualHardDisk()
        self.role_size = u''
        self.default_win_rm_certificate_thumbprint = u''


class ConfigurationSets(WindowsAzureData):

    def __init__(self):
        self.configuration_sets = _list_of(ConfigurationSet)

    def __iter__(self):
        return iter(self.configuration_sets)

    def __len__(self):
        return len(self.configuration_sets)

    def __getitem__(self, index):
        return self.configuration_sets[index]


class ConfigurationSet(WindowsAzureData):

    def __init__(self):
        self.configuration_set_type = u'NetworkConfiguration'
        self.role_type = u''
        self.input_endpoints = ConfigurationSetInputEndpoints()
        self.subnet_names = _scalar_list_of(str, 'SubnetName')


class ConfigurationSetInputEndpoints(WindowsAzureData):

    def __init__(self):
        self.input_endpoints = _list_of(
            ConfigurationSetInputEndpoint, 'InputEndpoint')

    def __iter__(self):
        return iter(self.input_endpoints)

    def __len__(self):
        return len(self.input_endpoints)

    def __getitem__(self, index):
        return self.input_endpoints[index]


class ConfigurationSetInputEndpoint(WindowsAzureData):

    '''
    Initializes a network configuration input endpoint.

    name: Specifies the name for the external endpoint.
    protocol:
        Specifies the protocol to use to inspect the virtual machine
        availability status. Possible values are: HTTP, TCP.
    port: Specifies the external port to use for the endpoint.
    local_port:
        Specifies the internal port on which the virtual machine is listening
        to serve the endpoint.
    load_balanced_endpoint_set_name:
        Specifies a name for a set of load-balanced endpoints. Specifying
        this element for a given endpoint adds it to the set. If you are
        setting an endpoint to use to connect to the virtual machine via the
        Remote Desktop, do not set this property.
    enable_direct_server_return:
        Specifies whether direct server return load balancing is enabled.
    '''

    def __init__(self, name=u'', protocol=u'', port=u'', local_port=u'',
                 load_balanced_endpoint_set_name=u'',
                 enable_direct_server_return=False):
        self.enable_direct_server_return = enable_direct_server_return
        self.load_balanced_endpoint_set_name = load_balanced_endpoint_set_name
        self.local_port = local_port
        self.name = name
        self.port = port
        self.load_balancer_probe = LoadBalancerProbe()
        self.protocol = protocol


class WindowsConfigurationSet(WindowsAzureData):

    def __init__(self, computer_name=None, admin_password=None,
                 reset_password_on_first_logon=None,
                 enable_automatic_updates=None, time_zone=None,
                 admin_username=None):
        self.configuration_set_type = u'WindowsProvisioningConfiguration'
        self.computer_name = computer_name
        self.admin_password = admin_password
        self.admin_username = admin_username
        self.reset_password_on_first_logon = reset_password_on_first_logon
        self.enable_automatic_updates = enable_automatic_updates
        self.time_zone = time_zone
        self.domain_join = DomainJoin()
        self.stored_certificate_settings = StoredCertificateSettings()
        self.win_rm = WinRM()


class DomainJoin(WindowsAzureData):

    def __init__(self):
        self.credentials = Credentials()
        self.join_domain = u''
        self.machine_object_ou = u''


class Credentials(WindowsAzureData):

    def __init__(self):
        self.domain = u''
        self.username = u''
        self.password = u''


class StoredCertificateSettings(WindowsAzureData):

    def __init__(self):
        self.stored_certificate_settings = _list_of(CertificateSetting)

    def __iter__(self):
        return iter(self.stored_certificate_settings)

    def __len__(self):
        return len(self.stored_certificate_settings)

    def __getitem__(self, index):
        return self.stored_certificate_settings[index]


class CertificateSetting(WindowsAzureData):

    '''
    Initializes a certificate setting.

    thumbprint:
        Specifies the thumbprint of the certificate to be provisioned. The
        thumbprint must specify an existing service certificate.
    store_name:
        Specifies the name of the certificate store from which retrieve
        certificate.
    store_location:
        Specifies the target certificate store location on the virtual
        machine. The only supported value is LocalMachine.
    '''

    def __init__(self, thumbprint=u'', store_name=u'', store_location=u''):
        self.thumbprint = thumbprint
        self.store_name = store_name
        self.store_location = store_location


class WinRM(WindowsAzureData):

    '''
    Contains configuration settings for the Windows Remote Management service
    on the Virtual Machine.
    '''

    def __init__(self):
        self.listeners = Listeners()


class Listeners(WindowsAzureData):

    def __init__(self):
        self.listeners = _list_of(Listener)

    def __iter__(self):
        return iter(self.listeners)

    def __len__(self):
        return len(self.listeners)

    def __getitem__(self, index):
        return self.listeners[index]


class Listener(WindowsAzureData):

    '''
    Specifies the protocol and certificate information for the listener.

    protocol:
        Specifies the protocol of listener. Possible values are: Http, Https.
        The value is case sensitive.
    certificate_thumbprint:
        Optional. Specifies the certificate thumbprint for the secure
        connection. If this value is not specified, a self-signed certificate
        is generated and used for the Virtual Machine.
    '''

    def __init__(self, protocol=u'', certificate_thumbprint=u''):
        self.protocol = protocol
        self.certificate_thumbprint = certificate_thumbprint


class LinuxConfigurationSet(WindowsAzureData):

    def __init__(self, host_name=None, user_name=None, user_password=None,
                 disable_ssh_password_authentication=None):
        self.configuration_set_type = u'LinuxProvisioningConfiguration'
        self.host_name = host_name
        self.user_name = user_name
        self.user_password = user_password
        self.disable_ssh_password_authentication =\
            disable_ssh_password_authentication
        self.ssh = SSH()


class SSH(WindowsAzureData):

    def __init__(self):
        self.public_keys = PublicKeys()
        self.key_pairs = KeyPairs()


class PublicKeys(WindowsAzureData):

    def __init__(self):
        self.public_keys = _list_of(PublicKey)

    def __iter__(self):
        return iter(self.public_keys)

    def __len__(self):
        return len(self.public_keys)

    def __getitem__(self, index):
        return self.public_keys[index]


class PublicKey(WindowsAzureData):

    def __init__(self, fingerprint=u'', path=u''):
        self.fingerprint = fingerprint
        self.path = path


class KeyPairs(WindowsAzureData):

    def __init__(self):
        self.key_pairs = _list_of(KeyPair)

    def __iter__(self):
        return iter(self.key_pairs)

    def __len__(self):
        return len(self.key_pairs)

    def __getitem__(self, index):
        return self.key_pairs[index]


class KeyPair(WindowsAzureData):

    def __init__(self, fingerprint=u'', path=u''):
        self.fingerprint = fingerprint
        self.path = path


class LoadBalancerProbe(WindowsAzureData):

    def __init__(self):
        self.path = u''
        self.port = u''
        self.protocol = u''


class DataVirtualHardDisks(WindowsAzureData):

    def __init__(self):
        self.data_virtual_hard_disks = _list_of(DataVirtualHardDisk)

    def __iter__(self):
        return iter(self.data_virtual_hard_disks)

    def __len__(self):
        return len(self.data_virtual_hard_disks)

    def __getitem__(self, index):
        return self.data_virtual_hard_disks[index]


class DataVirtualHardDisk(WindowsAzureData):

    def __init__(self):
        self.host_caching = u''
        self.disk_label = u''
        self.disk_name = u''
        self.lun = 0
        self.logical_disk_size_in_gb = 0
        self.media_link = u''


class OSVirtualHardDisk(WindowsAzureData):

    def __init__(self, source_image_name=None, media_link=None,
                 host_caching=None, disk_label=None, disk_name=None):
        self.source_image_name = source_image_name
        self.media_link = media_link
        self.host_caching = host_caching
        self.disk_label = disk_label
        self.disk_name = disk_name
        self.os = u''  # undocumented, not used when adding a role


class AsynchronousOperationResult(WindowsAzureData):

    def __init__(self, request_id=None):
        self.request_id = request_id


class ServiceBusRegion(WindowsAzureData):

    def __init__(self):
        self.code = u''
        self.fullname = u''


class ServiceBusNamespace(WindowsAzureData):

    def __init__(self):
        self.name = u''
        self.region = u''
        self.default_key = u''
        self.status = u''
        self.created_at = u''
        self.acs_management_endpoint = u''
        self.servicebus_endpoint = u''
        self.connection_string = u''
        self.subscription_id = u''
        self.enabled = False


class WebSpaces(WindowsAzureData):

    def __init__(self):
        self.web_space = _list_of(WebSpace)

    def __iter__(self):
        return iter(self.web_space)

    def __len__(self):
        return len(self.web_space)

    def __getitem__(self, index):
        return self.web_space[index]


class WebSpace(WindowsAzureData):

    def __init__(self):
        self.availability_state = u''
        self.geo_location = u''
        self.geo_region = u''
        self.name = u''
        self.plan = u''
        self.status = u''
        self.subscription = u''


class Sites(WindowsAzureData):

    def __init__(self):
        self.site = _list_of(Site)

    def __iter__(self):
        return iter(self.site)

    def __len__(self):
        return len(self.site)

    def __getitem__(self, index):
        return self.site[index]


class Site(WindowsAzureData):

    def __init__(self):
        self.admin_enabled = False
        self.availability_state = ''
        self.compute_mode = ''
        self.enabled = False
        self.enabled_host_names = _scalar_list_of(str, 'a:string')
        self.host_name_ssl_states = HostNameSslStates()
        self.host_names = _scalar_list_of(str, 'a:string')
        self.last_modified_time_utc = ''
        self.name = ''
        # NOTE(review): this class continues past the end of this chunk; the
        # right-hand side of the next assignment lives on the following line.
        self.repository_site_name = \
'' self.self_link = '' self.server_farm = '' self.site_mode = '' self.state = '' self.storage_recovery_default_state = '' self.usage_state = '' self.web_space = '' class HostNameSslStates(WindowsAzureData): def __init__(self): self.host_name_ssl_state = _list_of(HostNameSslState) def __iter__(self): return iter(self.host_name_ssl_state) def __len__(self): return len(self.host_name_ssl_state) def __getitem__(self, index): return self.host_name_ssl_state[index] class HostNameSslState(WindowsAzureData): def __init__(self): self.name = u'' self.ssl_state = u'' class PublishData(WindowsAzureData): _xml_name = 'publishData' def __init__(self): self.publish_profiles = _list_of(PublishProfile, 'publishProfile') class PublishProfile(WindowsAzureData): def __init__(self): self.profile_name = _xml_attribute('profileName') self.publish_method = _xml_attribute('publishMethod') self.publish_url = _xml_attribute('publishUrl') self.msdeploysite = _xml_attribute('msdeploySite') self.user_name = _xml_attribute('userName') self.user_pwd = _xml_attribute('userPWD') self.destination_app_url = _xml_attribute('destinationAppUrl') self.sql_server_db_connection_string = _xml_attribute('SQLServerDBConnectionString') self.my_sqldb_connection_string = _xml_attribute('mySQLDBConnectionString') self.hosting_provider_forum_link = _xml_attribute('hostingProviderForumLink') self.control_panel_link = _xml_attribute('controlPanelLink') class QueueDescription(WindowsAzureData): def __init__(self): self.lock_duration = u'' self.max_size_in_megabytes = 0 self.requires_duplicate_detection = False self.requires_session = False self.default_message_time_to_live = u'' self.dead_lettering_on_message_expiration = False self.duplicate_detection_history_time_window = u'' self.max_delivery_count = 0 self.enable_batched_operations = False self.size_in_bytes = 0 self.message_count = 0 self.is_anonymous_accessible = False self.authorization_rules = AuthorizationRules() self.status = u'' self.created_at = u'' 
self.updated_at = u'' self.accessed_at = u'' self.support_ordering = False self.auto_delete_on_idle = u'' self.count_details = CountDetails() self.entity_availability_status = u'' class TopicDescription(WindowsAzureData): def __init__(self): self.default_message_time_to_live = u'' self.max_size_in_megabytes = 0 self.requires_duplicate_detection = False self.duplicate_detection_history_time_window = u'' self.enable_batched_operations = False self.size_in_bytes = 0 self.filtering_messages_before_publishing = False self.is_anonymous_accessible = False self.authorization_rules = AuthorizationRules() self.status = u'' self.created_at = u'' self.updated_at = u'' self.accessed_at = u'' self.support_ordering = False self.count_details = CountDetails() self.subscription_count = 0 class CountDetails(WindowsAzureData): def __init__(self): self.active_message_count = 0 self.dead_letter_message_count = 0 self.scheduled_message_count = 0 self.transfer_message_count = 0 self.transfer_dead_letter_message_count = 0 class NotificationHubDescription(WindowsAzureData): def __init__(self): self.registration_ttl = u'' self.authorization_rules = AuthorizationRules() class AuthorizationRules(WindowsAzureData): def __init__(self): self.authorization_rule = _list_of(AuthorizationRule) def __iter__(self): return iter(self.authorization_rule) def __len__(self): return len(self.authorization_rule) def __getitem__(self, index): return self.authorization_rule[index] class AuthorizationRule(WindowsAzureData): def __init__(self): self.claim_type = u'' self.claim_value = u'' self.rights = _scalar_list_of(str, 'AccessRights') self.created_time = u'' self.modified_time = u'' self.key_name = u'' self.primary_key = u'' self.secondary_keu = u'' class RelayDescription(WindowsAzureData): def __init__(self): self.path = u'' self.listener_type = u'' self.listener_count = 0 self.created_at = u'' self.updated_at = u'' class MetricResponses(WindowsAzureData): def __init__(self): self.metric_response = 
_list_of(MetricResponse) def __iter__(self): return iter(self.metric_response) def __len__(self): return len(self.metric_response) def __getitem__(self, index): return self.metric_response[index] class MetricResponse(WindowsAzureData): def __init__(self): self.code = u'' self.data = Data() self.message = u'' class Data(WindowsAzureData): def __init__(self): self.display_name = u'' self.end_time = u'' self.name = u'' self.primary_aggregation_type = u'' self.start_time = u'' self.time_grain = u'' self.unit = u'' self.values = Values() class Values(WindowsAzureData): def __init__(self): self.metric_sample = _list_of(MetricSample) def __iter__(self): return iter(self.metric_sample) def __len__(self): return len(self.metric_sample) def __getitem__(self, index): return self.metric_sample[index] class MetricSample(WindowsAzureData): def __init__(self): self.count = 0 self.time_created = u'' self.total = 0 class MetricDefinitions(WindowsAzureData): def __init__(self): self.metric_definition = _list_of(MetricDefinition) def __iter__(self): return iter(self.metric_definition) def __len__(self): return len(self.metric_definition) def __getitem__(self, index): return self.metric_definition[index] class MetricDefinition(WindowsAzureData): def __init__(self): self.display_name = u'' self.metric_availabilities = MetricAvailabilities() self.name = u'' self.primary_aggregation_type = u'' self.unit = u'' class MetricAvailabilities(WindowsAzureData): def __init__(self): self.metric_availability = _list_of(MetricAvailability, 'MetricAvailabilily') def __iter__(self): return iter(self.metric_availability) def __len__(self): return len(self.metric_availability) def __getitem__(self, index): return self.metric_availability[index] class MetricAvailability(WindowsAzureData): def __init__(self): self.retention = u'' self.time_grain = u'' class Servers(WindowsAzureData): def __init__(self): self.server = _list_of(Server) def __iter__(self): return iter(self.server) def __len__(self): return 
len(self.server) def __getitem__(self, index): return self.server[index] class Server(WindowsAzureData): def __init__(self): self.name = u'' self.administrator_login = u'' self.location = u'' self.fully_qualified_domain_name = u'' self.version = u'' class Database(WindowsAzureData): def __init__(self): self.name = u'' self.type = u'' self.state = u'' self.self_link = u'' self.parent_link = u'' self.id = 0 self.edition = u'' self.collation_name = u'' self.creation_date = u'' self.is_federation_root = False self.is_system_object = False self.max_size_bytes = 0 def _update_management_header(request): ''' Add additional headers for management. ''' if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']: request.headers.append(('Content-Length', str(len(request.body)))) # append additional headers base on the service request.headers.append(('x-ms-version', X_MS_VERSION)) # if it is not GET or HEAD request, must set content-type. if not request.method in ['GET', 'HEAD']: for name, _ in request.headers: if 'content-type' == name.lower(): break else: request.headers.append( ('Content-Type', 'application/atom+xml;type=entry;charset=utf-8')) return request.headers def _parse_response_for_async_op(response): ''' Extracts request id from response header. ''' if response is None: return None result = AsynchronousOperationResult() if response.headers: for name, value in response.headers: if name.lower() == 'x-ms-request-id': result.request_id = value return result def _management_error_handler(http_error): ''' Simple error handler for management service. 
    '''
    return _general_error_handler(http_error)


def _lower(text):
    # Conversion helper used by _XmlSerializer.data_to_xml for boolean
    # values: the service expects lowercase 'true'/'false'.
    return text.lower()


class _XmlSerializer(object):
    '''Builds the XML request bodies for the service management API.

    Element names and nesting are part of the wire protocol; do not alter
    any literal tag names below.'''

    @staticmethod
    def create_storage_service_input_to_xml(service_name, description, label,
                                            affinity_group, location,
                                            geo_replication_enabled,
                                            extended_properties):
        return _XmlSerializer.doc_from_data(
            'CreateStorageServiceInput',
            [('ServiceName', service_name),
             ('Description', description),
             ('Label', label, _encode_base64),
             ('AffinityGroup', affinity_group),
             ('Location', location),
             ('GeoReplicationEnabled', geo_replication_enabled, _lower)],
            extended_properties)

    @staticmethod
    def update_storage_service_input_to_xml(description, label,
                                            geo_replication_enabled,
                                            extended_properties):
        return _XmlSerializer.doc_from_data(
            'UpdateStorageServiceInput',
            [('Description', description),
             ('Label', label, _encode_base64),
             ('GeoReplicationEnabled', geo_replication_enabled, _lower)],
            extended_properties)

    @staticmethod
    def regenerate_keys_to_xml(key_type):
        return _XmlSerializer.doc_from_data('RegenerateKeys',
                                            [('KeyType', key_type)])

    @staticmethod
    def update_hosted_service_to_xml(label, description, extended_properties):
        return _XmlSerializer.doc_from_data('UpdateHostedService',
                                            [('Label', label,
                                              _encode_base64),
                                             ('Description', description)],
                                            extended_properties)

    @staticmethod
    def create_hosted_service_to_xml(service_name, label, description,
                                     location, affinity_group,
                                     extended_properties):
        return _XmlSerializer.doc_from_data(
            'CreateHostedService',
            [('ServiceName', service_name),
             ('Label', label, _encode_base64),
             ('Description', description),
             ('Location', location),
             ('AffinityGroup', affinity_group)],
            extended_properties)

    @staticmethod
    def create_deployment_to_xml(name, package_url, label, configuration,
                                 start_deployment, treat_warnings_as_error,
                                 extended_properties):
        return _XmlSerializer.doc_from_data(
            'CreateDeployment',
            [('Name', name),
             ('PackageUrl', package_url),
             ('Label', label, _encode_base64),
             ('Configuration', configuration),
             ('StartDeployment', start_deployment, _lower),
             ('TreatWarningsAsError', treat_warnings_as_error, _lower)],
            extended_properties)

    @staticmethod
    def swap_deployment_to_xml(production, source_deployment):
        return _XmlSerializer.doc_from_data(
            'Swap',
            [('Production', production),
             ('SourceDeployment', source_deployment)])

    @staticmethod
    def update_deployment_status_to_xml(status):
        return _XmlSerializer.doc_from_data(
            'UpdateDeploymentStatus',
            [('Status', status)])

    @staticmethod
    def change_deployment_to_xml(configuration, treat_warnings_as_error, mode,
                                 extended_properties):
        return _XmlSerializer.doc_from_data(
            'ChangeConfiguration',
            [('Configuration', configuration),
             ('TreatWarningsAsError', treat_warnings_as_error, _lower),
             ('Mode', mode)],
            extended_properties)

    @staticmethod
    def upgrade_deployment_to_xml(mode, package_url, configuration, label,
                                  role_to_upgrade, force,
                                  extended_properties):
        return _XmlSerializer.doc_from_data(
            'UpgradeDeployment',
            [('Mode', mode),
             ('PackageUrl', package_url),
             ('Configuration', configuration),
             ('Label', label, _encode_base64),
             ('RoleToUpgrade', role_to_upgrade),
             ('Force', force, _lower)],
            extended_properties)

    @staticmethod
    def rollback_upgrade_to_xml(mode, force):
        return _XmlSerializer.doc_from_data(
            'RollbackUpdateOrUpgrade',
            [('Mode', mode),
             ('Force', force, _lower)])

    @staticmethod
    def walk_upgrade_domain_to_xml(upgrade_domain):
        return _XmlSerializer.doc_from_data(
            'WalkUpgradeDomain',
            [('UpgradeDomain', upgrade_domain)])

    @staticmethod
    def certificate_file_to_xml(data, certificate_format, password):
        return _XmlSerializer.doc_from_data(
            'CertificateFile',
            [('Data', data),
             ('CertificateFormat', certificate_format),
             ('Password', password)])

    @staticmethod
    def create_affinity_group_to_xml(name, label, description, location):
        return _XmlSerializer.doc_from_data(
            'CreateAffinityGroup',
            [('Name', name),
             ('Label', label, _encode_base64),
             ('Description', description),
             ('Location', location)])

    @staticmethod
    def update_affinity_group_to_xml(label, description):
        return _XmlSerializer.doc_from_data(
            'UpdateAffinityGroup',
            [('Label', label, _encode_base64),
             ('Description', description)])

    @staticmethod
    def subscription_certificate_to_xml(public_key, thumbprint, data):
        return _XmlSerializer.doc_from_data(
            'SubscriptionCertificate',
            [('SubscriptionCertificatePublicKey', public_key),
             ('SubscriptionCertificateThumbprint', thumbprint),
             ('SubscriptionCertificateData', data)])

    @staticmethod
    def os_image_to_xml(label, media_link, name, os):
        return _XmlSerializer.doc_from_data(
            'OSImage',
            [('Label', label),
             ('MediaLink', media_link),
             ('Name', name),
             ('OS', os)])

    @staticmethod
    def data_virtual_hard_disk_to_xml(host_caching, disk_label, disk_name,
                                      lun, logical_disk_size_in_gb,
                                      media_link, source_media_link):
        return _XmlSerializer.doc_from_data(
            'DataVirtualHardDisk',
            [('HostCaching', host_caching),
             ('DiskLabel', disk_label),
             ('DiskName', disk_name),
             ('Lun', lun),
             ('LogicalDiskSizeInGB', logical_disk_size_in_gb),
             ('MediaLink', media_link),
             ('SourceMediaLink', source_media_link)])

    @staticmethod
    def disk_to_xml(has_operating_system, label, media_link, name, os):
        return _XmlSerializer.doc_from_data(
            'Disk',
            [('HasOperatingSystem', has_operating_system, _lower),
             ('Label', label),
             ('MediaLink', media_link),
             ('Name', name),
             ('OS', os)])

    @staticmethod
    def restart_role_operation_to_xml():
        return _XmlSerializer.doc_from_xml(
            'RestartRoleOperation',
            '<OperationType>RestartRoleOperation</OperationType>')

    @staticmethod
    def shutdown_role_operation_to_xml(post_shutdown_action):
        xml = _XmlSerializer.data_to_xml(
            [('OperationType', 'ShutdownRoleOperation'),
             ('PostShutdownAction', post_shutdown_action)])
        return _XmlSerializer.doc_from_xml('ShutdownRoleOperation', xml)

    @staticmethod
    def shutdown_roles_operation_to_xml(role_names, post_shutdown_action):
        xml = _XmlSerializer.data_to_xml(
            [('OperationType', 'ShutdownRolesOperation')])
        xml += '<Roles>'
        for role_name in role_names:
            xml += _XmlSerializer.data_to_xml([('Name', role_name)])
        xml += '</Roles>'
        xml += _XmlSerializer.data_to_xml(
            [('PostShutdownAction', post_shutdown_action)])
        return _XmlSerializer.doc_from_xml('ShutdownRolesOperation', xml)

    @staticmethod
    def start_role_operation_to_xml():
        return _XmlSerializer.doc_from_xml(
            'StartRoleOperation',
            '<OperationType>StartRoleOperation</OperationType>')

    @staticmethod
    def start_roles_operation_to_xml(role_names):
        xml = _XmlSerializer.data_to_xml(
            [('OperationType', 'StartRolesOperation')])
        xml += '<Roles>'
        for role_name in role_names:
            xml += _XmlSerializer.data_to_xml([('Name', role_name)])
        xml += '</Roles>'
        return _XmlSerializer.doc_from_xml('StartRolesOperation', xml)

    @staticmethod
    def windows_configuration_to_xml(configuration):
        '''Serializes a WindowsConfigurationSet; optional sub-sections
        (DomainJoin, StoredCertificateSettings, WinRM) are emitted only
        when set on the configuration.'''
        xml = _XmlSerializer.data_to_xml(
            [('ConfigurationSetType', configuration.configuration_set_type),
             ('ComputerName', configuration.computer_name),
             ('AdminPassword', configuration.admin_password),
             ('ResetPasswordOnFirstLogon',
              configuration.reset_password_on_first_logon,
              _lower),
             ('EnableAutomaticUpdates',
              configuration.enable_automatic_updates,
              _lower),
             ('TimeZone', configuration.time_zone)])

        if configuration.domain_join is not None:
            xml += '<DomainJoin>'
            xml += '<Credentials>'
            xml += _XmlSerializer.data_to_xml(
                [('Domain', configuration.domain_join.credentials.domain),
                 ('Username',
                  configuration.domain_join.credentials.username),
                 ('Password',
                  configuration.domain_join.credentials.password)])
            xml += '</Credentials>'
            xml += _XmlSerializer.data_to_xml(
                [('JoinDomain', configuration.domain_join.join_domain),
                 ('MachineObjectOU',
                  configuration.domain_join.machine_object_ou)])
            xml += '</DomainJoin>'

        if configuration.stored_certificate_settings is not None:
            xml += '<StoredCertificateSettings>'
            for cert in configuration.stored_certificate_settings:
                xml += '<CertificateSetting>'
                xml += _XmlSerializer.data_to_xml(
                    [('StoreLocation', cert.store_location),
                     ('StoreName', cert.store_name),
                     ('Thumbprint', cert.thumbprint)])
                xml += '</CertificateSetting>'
            xml += '</StoredCertificateSettings>'

        if configuration.win_rm is not None:
            xml += '<WinRM><Listeners>'
            for listener in configuration.win_rm.listeners:
                xml += '<Listener>'
                xml += _XmlSerializer.data_to_xml(
                    [('Protocol', listener.protocol),
                     ('CertificateThumbprint',
                      listener.certificate_thumbprint)])
                xml += '</Listener>'
            xml += '</Listeners></WinRM>'

        # AdminUsername is emitted last, after the optional sections.
        xml += _XmlSerializer.data_to_xml(
            [('AdminUsername', configuration.admin_username)])
        return xml

    @staticmethod
    def linux_configuration_to_xml(configuration):
        '''Serializes a LinuxConfigurationSet, including the optional SSH
        public keys and key pairs.'''
        xml = _XmlSerializer.data_to_xml(
            [('ConfigurationSetType', configuration.configuration_set_type),
             ('HostName', configuration.host_name),
             ('UserName', configuration.user_name),
             ('UserPassword', configuration.user_password),
             ('DisableSshPasswordAuthentication',
              configuration.disable_ssh_password_authentication,
              _lower)])

        if configuration.ssh is not None:
            xml += '<SSH>'
            xml += '<PublicKeys>'
            for key in configuration.ssh.public_keys:
                xml += '<PublicKey>'
                xml += _XmlSerializer.data_to_xml(
                    [('Fingerprint', key.fingerprint),
                     ('Path', key.path)])
                xml += '</PublicKey>'
            xml += '</PublicKeys>'
            xml += '<KeyPairs>'
            for key in configuration.ssh.key_pairs:
                xml += '<KeyPair>'
                xml += _XmlSerializer.data_to_xml(
                    [('Fingerprint', key.fingerprint),
                     ('Path', key.path)])
                xml += '</KeyPair>'
            xml += '</KeyPairs>'
            xml += '</SSH>'
        return xml

    @staticmethod
    def network_configuration_to_xml(configuration):
        '''Serializes a network configuration set: input endpoints (with an
        optional load balancer probe each) and subnet names.'''
        xml = _XmlSerializer.data_to_xml(
            [('ConfigurationSetType', configuration.configuration_set_type)])
        xml += '<InputEndpoints>'
        for endpoint in configuration.input_endpoints:
            xml += '<InputEndpoint>'
            xml += _XmlSerializer.data_to_xml(
                [('LoadBalancedEndpointSetName',
                  endpoint.load_balanced_endpoint_set_name),
                 ('LocalPort', endpoint.local_port),
                 ('Name', endpoint.name),
                 ('Port', endpoint.port)])

            # Only emit the probe element when at least one field is set.
            if endpoint.load_balancer_probe.path or\
                endpoint.load_balancer_probe.port or\
                endpoint.load_balancer_probe.protocol:
                xml += '<LoadBalancerProbe>'
                xml += _XmlSerializer.data_to_xml(
                    [('Path', endpoint.load_balancer_probe.path),
                     ('Port', endpoint.load_balancer_probe.port),
                     ('Protocol', endpoint.load_balancer_probe.protocol)])
                xml += '</LoadBalancerProbe>'

            xml += _XmlSerializer.data_to_xml(
                [('Protocol', endpoint.protocol),
                 ('EnableDirectServerReturn',
                  endpoint.enable_direct_server_return,
                  _lower)])

            xml += '</InputEndpoint>'
        xml += '</InputEndpoints>'
        xml += '<SubnetNames>'
        for name in configuration.subnet_names:
            xml += _XmlSerializer.data_to_xml([('SubnetName', name)])
        xml += '</SubnetNames>'
        return xml

    @staticmethod
    def role_to_xml(availability_set_name, data_virtual_hard_disks,
                    network_configuration_set, os_virtual_hard_disk,
                    role_name, role_size, role_type,
                    system_configuration_set):
        '''Serializes a role; the system configuration set is dispatched on
        its type (Windows vs Linux). Shared by add/update role and by
        virtual_machine_deployment_to_xml.'''
        xml = _XmlSerializer.data_to_xml([('RoleName', role_name),
                                          ('RoleType', role_type)])

        xml += '<ConfigurationSets>'

        if system_configuration_set is not None:
            xml += '<ConfigurationSet>'
            if isinstance(system_configuration_set, WindowsConfigurationSet):
                xml += _XmlSerializer.windows_configuration_to_xml(
                    system_configuration_set)
            elif isinstance(system_configuration_set, LinuxConfigurationSet):
                xml += _XmlSerializer.linux_configuration_to_xml(
                    system_configuration_set)
            xml += '</ConfigurationSet>'

        if network_configuration_set is not None:
            xml += '<ConfigurationSet>'
            xml += _XmlSerializer.network_configuration_to_xml(
                network_configuration_set)
            xml += '</ConfigurationSet>'

        xml += '</ConfigurationSets>'

        if availability_set_name is not None:
            xml += _XmlSerializer.data_to_xml(
                [('AvailabilitySetName', availability_set_name)])

        if data_virtual_hard_disks is not None:
            xml += '<DataVirtualHardDisks>'
            for hd in data_virtual_hard_disks:
                xml += '<DataVirtualHardDisk>'
                xml += _XmlSerializer.data_to_xml(
                    [('HostCaching', hd.host_caching),
                     ('DiskLabel', hd.disk_label),
                     ('DiskName', hd.disk_name),
                     ('Lun', hd.lun),
                     ('LogicalDiskSizeInGB', hd.logical_disk_size_in_gb),
                     ('MediaLink', hd.media_link)])
                xml += '</DataVirtualHardDisk>'
            xml += '</DataVirtualHardDisks>'

        if os_virtual_hard_disk is not None:
            xml += '<OSVirtualHardDisk>'
            xml += _XmlSerializer.data_to_xml(
                [('HostCaching', os_virtual_hard_disk.host_caching),
                 ('DiskLabel', os_virtual_hard_disk.disk_label),
                 ('DiskName', os_virtual_hard_disk.disk_name),
                 ('MediaLink', os_virtual_hard_disk.media_link),
                 ('SourceImageName',
                  os_virtual_hard_disk.source_image_name)])
            xml += '</OSVirtualHardDisk>'

        if role_size is not None:
            xml += _XmlSerializer.data_to_xml([('RoleSize', role_size)])

        return xml

    @staticmethod
    def add_role_to_xml(role_name, system_configuration_set,
                        os_virtual_hard_disk, role_type,
                        network_configuration_set, availability_set_name,
                        data_virtual_hard_disks, role_size):
        xml = _XmlSerializer.role_to_xml(
            availability_set_name,
            data_virtual_hard_disks,
            network_configuration_set,
            os_virtual_hard_disk,
            role_name,
            role_size,
            role_type,
            system_configuration_set)
        return _XmlSerializer.doc_from_xml('PersistentVMRole', xml)

    @staticmethod
    def update_role_to_xml(role_name, os_virtual_hard_disk, role_type,
                           network_configuration_set, availability_set_name,
                           data_virtual_hard_disks, role_size):
        # Update never carries a system configuration set (hence None).
        xml = _XmlSerializer.role_to_xml(
            availability_set_name,
            data_virtual_hard_disks,
            network_configuration_set,
            os_virtual_hard_disk,
            role_name,
            role_size,
            role_type,
            None)
        return _XmlSerializer.doc_from_xml('PersistentVMRole', xml)

    @staticmethod
    def capture_role_to_xml(post_capture_action, target_image_name,
                            target_image_label, provisioning_configuration):
        xml = _XmlSerializer.data_to_xml(
            [('OperationType', 'CaptureRoleOperation'),
             ('PostCaptureAction', post_capture_action)])

        if provisioning_configuration is not None:
            xml += '<ProvisioningConfiguration>'
            if isinstance(provisioning_configuration, WindowsConfigurationSet):
                xml += _XmlSerializer.windows_configuration_to_xml(
                    provisioning_configuration)
            elif isinstance(provisioning_configuration,
                            LinuxConfigurationSet):
                xml += _XmlSerializer.linux_configuration_to_xml(
                    provisioning_configuration)
            xml += '</ProvisioningConfiguration>'

        xml += _XmlSerializer.data_to_xml(
            [('TargetImageLabel', target_image_label),
             ('TargetImageName', target_image_name)])
        return _XmlSerializer.doc_from_xml('CaptureRoleOperation', xml)

    @staticmethod
    def virtual_machine_deployment_to_xml(deployment_name, deployment_slot,
                                          label, role_name,
                                          system_configuration_set,
                                          os_virtual_hard_disk, role_type,
                                          network_configuration_set,
                                          availability_set_name,
                                          data_virtual_hard_disks, role_size,
                                          virtual_network_name):
        xml = _XmlSerializer.data_to_xml([('Name', deployment_name),
                                          ('DeploymentSlot', deployment_slot),
                                          ('Label', label)])
        xml += '<RoleList>'
        xml += '<Role>'
        xml += _XmlSerializer.role_to_xml(
            availability_set_name,
            data_virtual_hard_disks,
            network_configuration_set,
            os_virtual_hard_disk,
            role_name,
            role_size,
            role_type,
            system_configuration_set)
        xml += '</Role>'
        xml += '</RoleList>'

        if virtual_network_name is not None:
            xml += _XmlSerializer.data_to_xml(
                [('VirtualNetworkName', virtual_network_name)])

        return _XmlSerializer.doc_from_xml('Deployment', xml)

    @staticmethod
    def create_website_to_xml(webspace_name, website_name, geo_region, plan,
                              host_names, compute_mode, server_farm,
                              site_mode):
        xml = '<HostNames xmlns:a="http://schemas.microsoft.com/2003/10/Serialization/Arrays">'
        for host_name in host_names:
            xml += '<a:string>{0}</a:string>'.format(host_name)
        xml += '</HostNames>'
        xml += _XmlSerializer.data_to_xml(
            [('Name', website_name),
             ('ComputeMode', compute_mode),
             ('ServerFarm', server_farm),
             ('SiteMode', site_mode)])
        xml += '<WebSpaceToCreate>'
        xml += _XmlSerializer.data_to_xml(
            [('GeoRegion', geo_region),
             ('Name', webspace_name),
             ('Plan', plan)])
        xml += '</WebSpaceToCreate>'
        return _XmlSerializer.doc_from_xml('Site', xml)

    @staticmethod
    def data_to_xml(data):
        '''Creates an xml fragment from the specified data.
           data: Array of tuples, where first: xml element name
                                        second: xml element text
                                        third: conversion function
        '''
        xml = ''
        for element in data:
            name = element[0]
            val = element[1]
            if len(element) > 2:
                converter = element[2]
            else:
                converter = None

            # Elements with a None value are omitted entirely.
            if val is not None:
                if converter is not None:
                    text = _str(converter(_str(val)))
                else:
                    text = _str(val)

                xml += ''.join(['<', name, '>', text, '</', name, '>'])
        return xml

    @staticmethod
    def doc_from_xml(document_element_name, inner_xml):
        '''Wraps the specified xml in an xml root element with default azure
        namespaces'''
        xml = ''.join(['<', document_element_name,
                      ' xmlns:i="http://www.w3.org/2001/XMLSchema-instance"',
                      ' xmlns="http://schemas.microsoft.com/windowsazure">'])
        xml += inner_xml
        xml += ''.join(['</', document_element_name, '>'])
        return xml

    @staticmethod
    def doc_from_data(document_element_name, data, extended_properties=None):
        '''Serializes data (see data_to_xml) plus optional extended
        properties into a complete, namespaced document.'''
        xml = _XmlSerializer.data_to_xml(data)
        if extended_properties is not None:
            xml += _XmlSerializer.extended_properties_dict_to_xml_fragment(
                extended_properties)
        return _XmlSerializer.doc_from_xml(document_element_name, xml)

    @staticmethod
    def extended_properties_dict_to_xml_fragment(extended_properties):
        '''Serializes a dict into an <ExtendedProperties> fragment; returns
        an empty string for None or an empty dict.'''
        xml = ''
        if extended_properties is not None and len(extended_properties) > 0:
            xml += '<ExtendedProperties>'
            for key, val in extended_properties.items():
                xml += ''.join(['<ExtendedProperty>',
                                '<Name>', _str(key), '</Name>',
                                '<Value>', _str(val), '</Value>',
                                '</ExtendedProperty>'])
            xml += '</ExtendedProperties>'
        return xml


def _parse_bool(value):
    # Case-insensitive 'true' maps to True; every other string is False.
    if value.lower() == 'true':
        return True
    return False


class _ServiceBusManagementXmlSerializer(object):
    '''Serializes requests to / parses Atom responses from the service bus
    management endpoints.'''

    @staticmethod
    def namespace_to_xml(region):
        '''Converts a service bus namespace description to xml.

        Produces a NamespaceDescription fragment carrying the region,
        wrapped in an Atom entry by _create_entry.
        '''
        body = '<NamespaceDescription xmlns="http://schemas.microsoft.com/netservices/2010/10/servicebus/connect">'
        body += ''.join(['<Region>', region, '</Region>'])
        body += '</NamespaceDescription>'
        return _create_entry(body)

    @staticmethod
    def xml_to_namespace(xmlstr):
        '''Converts an Atom entry xml response to a ServiceBusNamespace.

        Reads entry/content/NamespaceDescription child elements and maps
        them onto the ServiceBusNamespace attributes listed in `mappings`,
        applying the optional per-field conversion function.
        '''
        xmldoc = minidom.parseString(xmlstr)
        namespace = ServiceBusNamespace()

        # (xml element name, attribute name, optional converter)
        mappings = (
            ('Name', 'name', None),
            ('Region', 'region', None),
            ('DefaultKey', 'default_key', None),
            ('Status', 'status', None),
            ('CreatedAt', 'created_at', None),
            ('AcsManagementEndpoint', 'acs_management_endpoint', None),
            ('ServiceBusEndpoint', 'servicebus_endpoint', None),
            ('ConnectionString', 'connection_string', None),
            ('SubscriptionId', 'subscription_id', None),
            ('Enabled', 'enabled', _parse_bool),
        )

        for desc in _get_children_from_path(xmldoc,
                                            'entry',
                                            'content',
                                            'NamespaceDescription'):
            for xml_name, field_name, conversion_func in mappings:
                node_value = _get_first_child_node_value(desc, xml_name)
                if node_value is not None:
                    if conversion_func is not None:
                        node_value = conversion_func(node_value)
                    setattr(namespace, field_name, node_value)

        return namespace

    @staticmethod
    def xml_to_region(xmlstr):
        '''Converts an Atom entry xml response to a ServiceBusRegion.

        Reads Code and FullName from entry/content/RegionCodeDescription.
        '''
        xmldoc = minidom.parseString(xmlstr)
        region = ServiceBusRegion()

        for desc in _get_children_from_path(xmldoc,
                                            'entry',
                                            'content',
                                            'RegionCodeDescription'):
            node_value = _get_first_child_node_value(desc, 'Code')
            if node_value is not None:
                region.code = node_value
            node_value = _get_first_child_node_value(desc, 'FullName')
            if node_value is not None:
                region.fullname = node_value

        return region

    @staticmethod
    def xml_to_namespace_availability(xmlstr):
        '''Converts an Atom entry xml response to an AvailabilityResponse.

        Reads the boolean Result from entry/content/NamespaceAvailability.
        '''
        xmldoc = minidom.parseString(xmlstr)
        availability = AvailabilityResponse()

        for desc in _get_children_from_path(xmldoc,
                                            'entry',
                                            'content',
                                            'NamespaceAvailability'):
            node_value = _get_first_child_node_value(desc, 'Result')
            if node_value is not None:
                availability.result = _parse_bool(node_value)

        return availability


# Imported at the bottom to avoid circular imports with the submodules,
# which themselves import names from this package.
from azure.servicemanagement.servicemanagementservice import (
    ServiceManagementService)
from azure.servicemanagement.servicebusmanagementservice import (
    ServiceBusManagementService)
from azure.servicemanagement.websitemanagementservice import (
    WebsiteManagementService)



================================================
FILE: OSPatching/azure/servicemanagement/servicebusmanagementservice.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-------------------------------------------------------------------------- from azure import ( MANAGEMENT_HOST, _convert_response_to_feeds, _str, _validate_not_none, ) from azure.servicemanagement import ( _ServiceBusManagementXmlSerializer, QueueDescription, TopicDescription, NotificationHubDescription, RelayDescription, ) from azure.servicemanagement.servicemanagementclient import ( _ServiceManagementClient, ) class ServiceBusManagementService(_ServiceManagementClient): def __init__(self, subscription_id=None, cert_file=None, host=MANAGEMENT_HOST): super(ServiceBusManagementService, self).__init__( subscription_id, cert_file, host) #--Operations for service bus ---------------------------------------- def get_regions(self): ''' Get list of available service bus regions. ''' response = self._perform_get( self._get_path('services/serviceBus/Regions/', None), None) return _convert_response_to_feeds( response, _ServiceBusManagementXmlSerializer.xml_to_region) def list_namespaces(self): ''' List the service bus namespaces defined on the account. ''' response = self._perform_get( self._get_path('services/serviceBus/Namespaces/', None), None) return _convert_response_to_feeds( response, _ServiceBusManagementXmlSerializer.xml_to_namespace) def get_namespace(self, name): ''' Get details about a specific namespace. name: Name of the service bus namespace. ''' response = self._perform_get( self._get_path('services/serviceBus/Namespaces', name), None) return _ServiceBusManagementXmlSerializer.xml_to_namespace( response.body) def create_namespace(self, name, region): ''' Create a new service bus namespace. name: Name of the service bus namespace to create. region: Region to create the namespace in. ''' _validate_not_none('name', name) return self._perform_put( self._get_path('services/serviceBus/Namespaces', name), _ServiceBusManagementXmlSerializer.namespace_to_xml(region)) def delete_namespace(self, name): ''' Delete a service bus namespace. 
name: Name of the service bus namespace to delete. ''' _validate_not_none('name', name) return self._perform_delete( self._get_path('services/serviceBus/Namespaces', name), None) def check_namespace_availability(self, name): ''' Checks to see if the specified service bus namespace is available, or if it has already been taken. name: Name of the service bus namespace to validate. ''' _validate_not_none('name', name) response = self._perform_get( self._get_path('services/serviceBus/CheckNamespaceAvailability', None) + '/?namespace=' + _str(name), None) return _ServiceBusManagementXmlSerializer.xml_to_namespace_availability( response.body) def list_queues(self, name): ''' Enumerates the queues in the service namespace. name: Name of the service bus namespace. ''' _validate_not_none('name', name) response = self._perform_get( self._get_list_queues_path(name), None) return _convert_response_to_feeds(response, QueueDescription) def list_topics(self, name): ''' Retrieves the topics in the service namespace. name: Name of the service bus namespace. ''' response = self._perform_get( self._get_list_topics_path(name), None) return _convert_response_to_feeds(response, TopicDescription) def list_notification_hubs(self, name): ''' Retrieves the notification hubs in the service namespace. name: Name of the service bus namespace. ''' response = self._perform_get( self._get_list_notification_hubs_path(name), None) return _convert_response_to_feeds(response, NotificationHubDescription) def list_relays(self, name): ''' Retrieves the relays in the service namespace. name: Name of the service bus namespace. 
''' response = self._perform_get( self._get_list_relays_path(name), None) return _convert_response_to_feeds(response, RelayDescription) #--Helper functions -------------------------------------------------- def _get_list_queues_path(self, namespace_name): return self._get_path('services/serviceBus/Namespaces/', namespace_name) + '/Queues' def _get_list_topics_path(self, namespace_name): return self._get_path('services/serviceBus/Namespaces/', namespace_name) + '/Topics' def _get_list_notification_hubs_path(self, namespace_name): return self._get_path('services/serviceBus/Namespaces/', namespace_name) + '/NotificationHubs' def _get_list_relays_path(self, namespace_name): return self._get_path('services/serviceBus/Namespaces/', namespace_name) + '/Relays' ================================================ FILE: OSPatching/azure/servicemanagement/servicemanagementclient.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#-------------------------------------------------------------------------- import os from azure import ( WindowsAzureError, MANAGEMENT_HOST, _get_request_body, _parse_response, _str, _update_request_uri_query, ) from azure.http import ( HTTPError, HTTPRequest, ) from azure.http.httpclient import _HTTPClient from azure.servicemanagement import ( AZURE_MANAGEMENT_CERTFILE, AZURE_MANAGEMENT_SUBSCRIPTIONID, _management_error_handler, _parse_response_for_async_op, _update_management_header, ) class _ServiceManagementClient(object): def __init__(self, subscription_id=None, cert_file=None, host=MANAGEMENT_HOST): self.requestid = None self.subscription_id = subscription_id self.cert_file = cert_file self.host = host if not self.cert_file: if AZURE_MANAGEMENT_CERTFILE in os.environ: self.cert_file = os.environ[AZURE_MANAGEMENT_CERTFILE] if not self.subscription_id: if AZURE_MANAGEMENT_SUBSCRIPTIONID in os.environ: self.subscription_id = os.environ[ AZURE_MANAGEMENT_SUBSCRIPTIONID] if not self.cert_file or not self.subscription_id: raise WindowsAzureError( 'You need to provide subscription id and certificate file') self._httpclient = _HTTPClient( service_instance=self, cert_file=self.cert_file) self._filter = self._httpclient.perform_request def with_filter(self, filter): '''Returns a new service which will process requests with the specified filter. Filtering operations can include logging, automatic retrying, etc... The filter is a lambda which receives the HTTPRequest and another lambda. The filter can perform any pre-processing on the request, pass it off to the next lambda, and then perform any post-processing on the response.''' res = type(self)(self.subscription_id, self.cert_file, self.host) old_filter = self._filter def new_filter(request): return filter(request, old_filter) res._filter = new_filter return res def set_proxy(self, host, port, user=None, password=None): ''' Sets the proxy server host and port for the HTTP CONNECT Tunnelling. 
host: Address of the proxy. Ex: '192.168.0.100' port: Port of the proxy. Ex: 6000 user: User for proxy authorization. password: Password for proxy authorization. ''' self._httpclient.set_proxy(host, port, user, password) #--Helper functions -------------------------------------------------- def _perform_request(self, request): try: resp = self._filter(request) except HTTPError as ex: return _management_error_handler(ex) return resp def _perform_get(self, path, response_type): request = HTTPRequest() request.method = 'GET' request.host = self.host request.path = path request.path, request.query = _update_request_uri_query(request) request.headers = _update_management_header(request) response = self._perform_request(request) if response_type is not None: return _parse_response(response, response_type) return response def _perform_put(self, path, body, async=False): request = HTTPRequest() request.method = 'PUT' request.host = self.host request.path = path request.body = _get_request_body(body) request.path, request.query = _update_request_uri_query(request) request.headers = _update_management_header(request) response = self._perform_request(request) if async: return _parse_response_for_async_op(response) return None def _perform_post(self, path, body, response_type=None, async=False): request = HTTPRequest() request.method = 'POST' request.host = self.host request.path = path request.body = _get_request_body(body) request.path, request.query = _update_request_uri_query(request) request.headers = _update_management_header(request) response = self._perform_request(request) if response_type is not None: return _parse_response(response, response_type) if async: return _parse_response_for_async_op(response) return None def _perform_delete(self, path, async=False): request = HTTPRequest() request.method = 'DELETE' request.host = self.host request.path = path request.path, request.query = _update_request_uri_query(request) request.headers = 
_update_management_header(request) response = self._perform_request(request) if async: return _parse_response_for_async_op(response) return None def _get_path(self, resource, name): path = '/' + self.subscription_id + '/' + resource if name is not None: path += '/' + _str(name) return path ================================================ FILE: OSPatching/azure/servicemanagement/servicemanagementservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#-------------------------------------------------------------------------- from azure import ( WindowsAzureError, MANAGEMENT_HOST, _str, _validate_not_none, ) from azure.servicemanagement import ( AffinityGroups, AffinityGroup, AvailabilityResponse, Certificate, Certificates, DataVirtualHardDisk, Deployment, Disk, Disks, Locations, Operation, HostedService, HostedServices, Images, OperatingSystems, OperatingSystemFamilies, OSImage, PersistentVMRole, StorageService, StorageServices, Subscription, SubscriptionCertificate, SubscriptionCertificates, VirtualNetworkSites, _XmlSerializer, ) from azure.servicemanagement.servicemanagementclient import ( _ServiceManagementClient, ) class ServiceManagementService(_ServiceManagementClient): def __init__(self, subscription_id=None, cert_file=None, host=MANAGEMENT_HOST): super(ServiceManagementService, self).__init__( subscription_id, cert_file, host) #--Operations for storage accounts ----------------------------------- def list_storage_accounts(self): ''' Lists the storage accounts available under the current subscription. ''' return self._perform_get(self._get_storage_service_path(), StorageServices) def get_storage_account_properties(self, service_name): ''' Returns system properties for the specified storage account. service_name: Name of the storage service account. ''' _validate_not_none('service_name', service_name) return self._perform_get(self._get_storage_service_path(service_name), StorageService) def get_storage_account_keys(self, service_name): ''' Returns the primary and secondary access keys for the specified storage account. service_name: Name of the storage service account. ''' _validate_not_none('service_name', service_name) return self._perform_get( self._get_storage_service_path(service_name) + '/keys', StorageService) def regenerate_storage_account_keys(self, service_name, key_type): ''' Regenerates the primary or secondary access key for the specified storage account. 
service_name: Name of the storage service account. key_type: Specifies which key to regenerate. Valid values are: Primary, Secondary ''' _validate_not_none('service_name', service_name) _validate_not_none('key_type', key_type) return self._perform_post( self._get_storage_service_path( service_name) + '/keys?action=regenerate', _XmlSerializer.regenerate_keys_to_xml( key_type), StorageService) def create_storage_account(self, service_name, description, label, affinity_group=None, location=None, geo_replication_enabled=True, extended_properties=None): ''' Creates a new storage account in Windows Azure. service_name: A name for the storage account that is unique within Windows Azure. Storage account names must be between 3 and 24 characters in length and use numbers and lower-case letters only. description: A description for the storage account. The description may be up to 1024 characters in length. label: A name for the storage account. The name may be up to 100 characters in length. The name can be used to identify the storage account for your tracking purposes. affinity_group: The name of an existing affinity group in the specified subscription. You can specify either a location or affinity_group, but not both. location: The location where the storage account is created. You can specify either a location or affinity_group, but not both. geo_replication_enabled: Specifies whether the storage account is created with the geo-replication enabled. If the element is not included in the request body, the default value is true. If set to true, the data in the storage account is replicated across more than one geographic location so as to enable resilience in the face of catastrophic service loss. extended_properties: Dictionary containing name/value pairs of storage account properties. You can have a maximum of 50 extended property name/value pairs. 
The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. ''' _validate_not_none('service_name', service_name) _validate_not_none('description', description) _validate_not_none('label', label) if affinity_group is None and location is None: raise WindowsAzureError( 'location or affinity_group must be specified') if affinity_group is not None and location is not None: raise WindowsAzureError( 'Only one of location or affinity_group needs to be specified') return self._perform_post( self._get_storage_service_path(), _XmlSerializer.create_storage_service_input_to_xml( service_name, description, label, affinity_group, location, geo_replication_enabled, extended_properties), async=True) def update_storage_account(self, service_name, description=None, label=None, geo_replication_enabled=None, extended_properties=None): ''' Updates the label, the description, and enables or disables the geo-replication status for a storage account in Windows Azure. service_name: Name of the storage service account. description: A description for the storage account. The description may be up to 1024 characters in length. label: A name for the storage account. The name may be up to 100 characters in length. The name can be used to identify the storage account for your tracking purposes. geo_replication_enabled: Specifies whether the storage account is created with the geo-replication enabled. If the element is not included in the request body, the default value is true. If set to true, the data in the storage account is replicated across more than one geographic location so as to enable resilience in the face of catastrophic service loss. extended_properties: Dictionary containing name/value pairs of storage account properties. You can have a maximum of 50 extended property name/value pairs. 
The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. ''' _validate_not_none('service_name', service_name) return self._perform_put( self._get_storage_service_path(service_name), _XmlSerializer.update_storage_service_input_to_xml( description, label, geo_replication_enabled, extended_properties)) def delete_storage_account(self, service_name): ''' Deletes the specified storage account from Windows Azure. service_name: Name of the storage service account. ''' _validate_not_none('service_name', service_name) return self._perform_delete( self._get_storage_service_path(service_name)) def check_storage_account_name_availability(self, service_name): ''' Checks to see if the specified storage account name is available, or if it has already been taken. service_name: Name of the storage service account. ''' _validate_not_none('service_name', service_name) return self._perform_get( self._get_storage_service_path() + '/operations/isavailable/' + _str(service_name) + '', AvailabilityResponse) #--Operations for hosted services ------------------------------------ def list_hosted_services(self): ''' Lists the hosted services available under the current subscription. ''' return self._perform_get(self._get_hosted_service_path(), HostedServices) def get_hosted_service_properties(self, service_name, embed_detail=False): ''' Retrieves system properties for the specified hosted service. These properties include the service name and service type; the name of the affinity group to which the service belongs, or its location if it is not part of an affinity group; and optionally, information on the service's deployments. service_name: Name of the hosted service. embed_detail: When True, the management service returns properties for all deployments of the service, as well as for the service itself. 
''' _validate_not_none('service_name', service_name) _validate_not_none('embed_detail', embed_detail) return self._perform_get( self._get_hosted_service_path(service_name) + '?embed-detail=' + _str(embed_detail).lower(), HostedService) def create_hosted_service(self, service_name, label, description=None, location=None, affinity_group=None, extended_properties=None): ''' Creates a new hosted service in Windows Azure. service_name: A name for the hosted service that is unique within Windows Azure. This name is the DNS prefix name and can be used to access the hosted service. label: A name for the hosted service. The name can be up to 100 characters in length. The name can be used to identify the storage account for your tracking purposes. description: A description for the hosted service. The description can be up to 1024 characters in length. location: The location where the hosted service will be created. You can specify either a location or affinity_group, but not both. affinity_group: The name of an existing affinity group associated with this subscription. This name is a GUID and can be retrieved by examining the name element of the response body returned by list_affinity_groups. You can specify either a location or affinity_group, but not both. extended_properties: Dictionary containing name/value pairs of storage account properties. You can have a maximum of 50 extended property name/value pairs. The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. 
''' _validate_not_none('service_name', service_name) _validate_not_none('label', label) if affinity_group is None and location is None: raise WindowsAzureError( 'location or affinity_group must be specified') if affinity_group is not None and location is not None: raise WindowsAzureError( 'Only one of location or affinity_group needs to be specified') return self._perform_post(self._get_hosted_service_path(), _XmlSerializer.create_hosted_service_to_xml( service_name, label, description, location, affinity_group, extended_properties)) def update_hosted_service(self, service_name, label=None, description=None, extended_properties=None): ''' Updates the label and/or the description for a hosted service in Windows Azure. service_name: Name of the hosted service. label: A name for the hosted service. The name may be up to 100 characters in length. You must specify a value for either Label or Description, or for both. It is recommended that the label be unique within the subscription. The name can be used identify the hosted service for your tracking purposes. description: A description for the hosted service. The description may be up to 1024 characters in length. You must specify a value for either Label or Description, or for both. extended_properties: Dictionary containing name/value pairs of storage account properties. You can have a maximum of 50 extended property name/value pairs. The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. ''' _validate_not_none('service_name', service_name) return self._perform_put(self._get_hosted_service_path(service_name), _XmlSerializer.update_hosted_service_to_xml( label, description, extended_properties)) def delete_hosted_service(self, service_name): ''' Deletes the specified hosted service from Windows Azure. service_name: Name of the hosted service. 
''' _validate_not_none('service_name', service_name) return self._perform_delete(self._get_hosted_service_path(service_name)) def get_deployment_by_slot(self, service_name, deployment_slot): ''' Returns configuration information, status, and system properties for a deployment. service_name: Name of the hosted service. deployment_slot: The environment to which the hosted service is deployed. Valid values are: staging, production ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_slot', deployment_slot) return self._perform_get( self._get_deployment_path_using_slot( service_name, deployment_slot), Deployment) def get_deployment_by_name(self, service_name, deployment_name): ''' Returns configuration information, status, and system properties for a deployment. service_name: Name of the hosted service. deployment_name: The name of the deployment. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) return self._perform_get( self._get_deployment_path_using_name( service_name, deployment_name), Deployment) def create_deployment(self, service_name, deployment_slot, name, package_url, label, configuration, start_deployment=False, treat_warnings_as_error=False, extended_properties=None): ''' Uploads a new service package and creates a new deployment on staging or production. service_name: Name of the hosted service. deployment_slot: The environment to which the hosted service is deployed. Valid values are: staging, production name: The name for the deployment. The deployment name must be unique among other deployments for the hosted service. package_url: A URL that refers to the location of the service package in the Blob service. The service package can be located either in a storage account beneath the same subscription or a Shared Access Signature (SAS) URI from any storage account. label: A name for the hosted service. The name can be up to 100 characters in length. 
It is recommended that the label be unique within the subscription. The name can be used to identify the hosted service for your tracking purposes. configuration: The base-64 encoded service configuration file for the deployment. start_deployment: Indicates whether to start the deployment immediately after it is created. If false, the service model is still deployed to the virtual machines but the code is not run immediately. Instead, the service is Suspended until you call Update Deployment Status and set the status to Running, at which time the service will be started. A deployed service still incurs charges, even if it is suspended. treat_warnings_as_error: Indicates whether to treat package validation warnings as errors. If set to true, the Created Deployment operation fails if there are validation warnings on the service package. extended_properties: Dictionary containing name/value pairs of storage account properties. You can have a maximum of 50 extended property name/value pairs. The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_slot', deployment_slot) _validate_not_none('name', name) _validate_not_none('package_url', package_url) _validate_not_none('label', label) _validate_not_none('configuration', configuration) return self._perform_post( self._get_deployment_path_using_slot( service_name, deployment_slot), _XmlSerializer.create_deployment_to_xml( name, package_url, label, configuration, start_deployment, treat_warnings_as_error, extended_properties), async=True) def delete_deployment(self, service_name, deployment_name): ''' Deletes the specified deployment. service_name: Name of the hosted service. deployment_name: The name of the deployment. 
''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) return self._perform_delete( self._get_deployment_path_using_name( service_name, deployment_name), async=True) def swap_deployment(self, service_name, production, source_deployment): ''' Initiates a virtual IP swap between the staging and production deployment environments for a service. If the service is currently running in the staging environment, it will be swapped to the production environment. If it is running in the production environment, it will be swapped to staging. service_name: Name of the hosted service. production: The name of the production deployment. source_deployment: The name of the source deployment. ''' _validate_not_none('service_name', service_name) _validate_not_none('production', production) _validate_not_none('source_deployment', source_deployment) return self._perform_post(self._get_hosted_service_path(service_name), _XmlSerializer.swap_deployment_to_xml( production, source_deployment), async=True) def change_deployment_configuration(self, service_name, deployment_name, configuration, treat_warnings_as_error=False, mode='Auto', extended_properties=None): ''' Initiates a change to the deployment configuration. service_name: Name of the hosted service. deployment_name: The name of the deployment. configuration: The base-64 encoded service configuration file for the deployment. treat_warnings_as_error: Indicates whether to treat package validation warnings as errors. If set to true, the Created Deployment operation fails if there are validation warnings on the service package. mode: If set to Manual, WalkUpgradeDomain must be called to apply the update. If set to Auto, the Windows Azure platform will automatically apply the update To each upgrade domain for the service. Possible values are: Auto, Manual extended_properties: Dictionary containing name/value pairs of storage account properties. 
You can have a maximum of 50 extended property name/value pairs. The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('configuration', configuration) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + '/?comp=config', _XmlSerializer.change_deployment_to_xml( configuration, treat_warnings_as_error, mode, extended_properties), async=True) def update_deployment_status(self, service_name, deployment_name, status): ''' Initiates a change in deployment status. service_name: Name of the hosted service. deployment_name: The name of the deployment. status: The change to initiate to the deployment status. Possible values include: Running, Suspended ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('status', status) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + '/?comp=status', _XmlSerializer.update_deployment_status_to_xml( status), async=True) def upgrade_deployment(self, service_name, deployment_name, mode, package_url, configuration, label, force, role_to_upgrade=None, extended_properties=None): ''' Initiates an upgrade. service_name: Name of the hosted service. deployment_name: The name of the deployment. mode: If set to Manual, WalkUpgradeDomain must be called to apply the update. If set to Auto, the Windows Azure platform will automatically apply the update To each upgrade domain for the service. Possible values are: Auto, Manual package_url: A URL that refers to the location of the service package in the Blob service. 
The service package can be located either in a storage account beneath the same subscription or a Shared Access Signature (SAS) URI from any storage account. configuration: The base-64 encoded service configuration file for the deployment. label: A name for the hosted service. The name can be up to 100 characters in length. It is recommended that the label be unique within the subscription. The name can be used to identify the hosted service for your tracking purposes. force: Specifies whether the rollback should proceed even when it will cause local data to be lost from some role instances. True if the rollback should proceed; otherwise false if the rollback should fail. role_to_upgrade: The name of the specific role to upgrade. extended_properties: Dictionary containing name/value pairs of storage account properties. You can have a maximum of 50 extended property name/value pairs. The maximum length of the Name element is 64 characters, only alphanumeric characters and underscores are valid in the Name, and the name must start with a letter. The value has a maximum length of 255 characters. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('mode', mode) _validate_not_none('package_url', package_url) _validate_not_none('configuration', configuration) _validate_not_none('label', label) _validate_not_none('force', force) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + '/?comp=upgrade', _XmlSerializer.upgrade_deployment_to_xml( mode, package_url, configuration, label, role_to_upgrade, force, extended_properties), async=True) def walk_upgrade_domain(self, service_name, deployment_name, upgrade_domain): ''' Specifies the next upgrade domain to be walked during manual in-place upgrade or configuration change. service_name: Name of the hosted service. deployment_name: The name of the deployment. 
upgrade_domain: An integer value that identifies the upgrade domain to walk. Upgrade domains are identified with a zero-based index: the first upgrade domain has an ID of 0, the second has an ID of 1, and so on. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('upgrade_domain', upgrade_domain) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + '/?comp=walkupgradedomain', _XmlSerializer.walk_upgrade_domain_to_xml( upgrade_domain), async=True) def rollback_update_or_upgrade(self, service_name, deployment_name, mode, force): ''' Cancels an in progress configuration change (update) or upgrade and returns the deployment to its state before the upgrade or configuration change was started. service_name: Name of the hosted service. deployment_name: The name of the deployment. mode: Specifies whether the rollback should proceed automatically. auto - The rollback proceeds without further user input. manual - You must call the Walk Upgrade Domain operation to apply the rollback to each upgrade domain. force: Specifies whether the rollback should proceed even when it will cause local data to be lost from some role instances. True if the rollback should proceed; otherwise false if the rollback should fail. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('mode', mode) _validate_not_none('force', force) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + '/?comp=rollback', _XmlSerializer.rollback_upgrade_to_xml( mode, force), async=True) def reboot_role_instance(self, service_name, deployment_name, role_instance_name): ''' Requests a reboot of a role instance that is running in a deployment. service_name: Name of the hosted service. deployment_name: The name of the deployment. role_instance_name: The name of the role instance. 
''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_instance_name', role_instance_name) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + \ '/roleinstances/' + _str(role_instance_name) + \ '?comp=reboot', '', async=True) def reimage_role_instance(self, service_name, deployment_name, role_instance_name): ''' Requests a reimage of a role instance that is running in a deployment. service_name: Name of the hosted service. deployment_name: The name of the deployment. role_instance_name: The name of the role instance. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_instance_name', role_instance_name) return self._perform_post( self._get_deployment_path_using_name( service_name, deployment_name) + \ '/roleinstances/' + _str(role_instance_name) + \ '?comp=reimage', '', async=True) def check_hosted_service_name_availability(self, service_name): ''' Checks to see if the specified hosted service name is available, or if it has already been taken. service_name: Name of the hosted service. ''' _validate_not_none('service_name', service_name) return self._perform_get( '/' + self.subscription_id + '/services/hostedservices/operations/isavailable/' + _str(service_name) + '', AvailabilityResponse) #--Operations for service certificates ------------------------------- def list_service_certificates(self, service_name): ''' Lists all of the service certificates associated with the specified hosted service. service_name: Name of the hosted service. 
''' _validate_not_none('service_name', service_name) return self._perform_get( '/' + self.subscription_id + '/services/hostedservices/' + _str(service_name) + '/certificates', Certificates) def get_service_certificate(self, service_name, thumbalgorithm, thumbprint): ''' Returns the public data for the specified X.509 certificate associated with a hosted service. service_name: Name of the hosted service. thumbalgorithm: The algorithm for the certificate's thumbprint. thumbprint: The hexadecimal representation of the thumbprint. ''' _validate_not_none('service_name', service_name) _validate_not_none('thumbalgorithm', thumbalgorithm) _validate_not_none('thumbprint', thumbprint) return self._perform_get( '/' + self.subscription_id + '/services/hostedservices/' + _str(service_name) + '/certificates/' + _str(thumbalgorithm) + '-' + _str(thumbprint) + '', Certificate) def add_service_certificate(self, service_name, data, certificate_format, password): ''' Adds a certificate to a hosted service. service_name: Name of the hosted service. data: The base-64 encoded form of the pfx file. certificate_format: The service certificate format. The only supported value is pfx. password: The certificate password. ''' _validate_not_none('service_name', service_name) _validate_not_none('data', data) _validate_not_none('certificate_format', certificate_format) _validate_not_none('password', password) return self._perform_post( '/' + self.subscription_id + '/services/hostedservices/' + _str(service_name) + '/certificates', _XmlSerializer.certificate_file_to_xml( data, certificate_format, password), async=True) def delete_service_certificate(self, service_name, thumbalgorithm, thumbprint): ''' Deletes a service certificate from the certificate store of a hosted service. service_name: Name of the hosted service. thumbalgorithm: The algorithm for the certificate's thumbprint. thumbprint: The hexadecimal representation of the thumbprint. 
''' _validate_not_none('service_name', service_name) _validate_not_none('thumbalgorithm', thumbalgorithm) _validate_not_none('thumbprint', thumbprint) return self._perform_delete( '/' + self.subscription_id + '/services/hostedservices/' + _str(service_name) + '/certificates/' + _str(thumbalgorithm) + '-' + _str(thumbprint), async=True) #--Operations for management certificates ---------------------------- def list_management_certificates(self): ''' The List Management Certificates operation lists and returns basic information about all of the management certificates associated with the specified subscription. Management certificates, which are also known as subscription certificates, authenticate clients attempting to connect to resources associated with your Windows Azure subscription. ''' return self._perform_get('/' + self.subscription_id + '/certificates', SubscriptionCertificates) def get_management_certificate(self, thumbprint): ''' The Get Management Certificate operation retrieves information about the management certificate with the specified thumbprint. Management certificates, which are also known as subscription certificates, authenticate clients attempting to connect to resources associated with your Windows Azure subscription. thumbprint: The thumbprint value of the certificate. ''' _validate_not_none('thumbprint', thumbprint) return self._perform_get( '/' + self.subscription_id + '/certificates/' + _str(thumbprint), SubscriptionCertificate) def add_management_certificate(self, public_key, thumbprint, data): ''' The Add Management Certificate operation adds a certificate to the list of management certificates. Management certificates, which are also known as subscription certificates, authenticate clients attempting to connect to resources associated with your Windows Azure subscription. public_key: A base64 representation of the management certificate public key. thumbprint: The thumb print that uniquely identifies the management certificate. 
data: The certificate's raw data in base-64 encoded .cer format. ''' _validate_not_none('public_key', public_key) _validate_not_none('thumbprint', thumbprint) _validate_not_none('data', data) return self._perform_post( '/' + self.subscription_id + '/certificates', _XmlSerializer.subscription_certificate_to_xml( public_key, thumbprint, data)) def delete_management_certificate(self, thumbprint): ''' The Delete Management Certificate operation deletes a certificate from the list of management certificates. Management certificates, which are also known as subscription certificates, authenticate clients attempting to connect to resources associated with your Windows Azure subscription. thumbprint: The thumb print that uniquely identifies the management certificate. ''' _validate_not_none('thumbprint', thumbprint) return self._perform_delete( '/' + self.subscription_id + '/certificates/' + _str(thumbprint)) #--Operations for affinity groups ------------------------------------ def list_affinity_groups(self): ''' Lists the affinity groups associated with the specified subscription. ''' return self._perform_get( '/' + self.subscription_id + '/affinitygroups', AffinityGroups) def get_affinity_group_properties(self, affinity_group_name): ''' Returns the system properties associated with the specified affinity group. affinity_group_name: The name of the affinity group. ''' _validate_not_none('affinity_group_name', affinity_group_name) return self._perform_get( '/' + self.subscription_id + '/affinitygroups/' + _str(affinity_group_name) + '', AffinityGroup) def create_affinity_group(self, name, label, location, description=None): ''' Creates a new affinity group for the specified subscription. name: A name for the affinity group that is unique to the subscription. label: A name for the affinity group. The name can be up to 100 characters in length. location: The data center location where the affinity group will be created. 
To list available locations, use the list_location function. description: A description for the affinity group. The description can be up to 1024 characters in length. ''' _validate_not_none('name', name) _validate_not_none('label', label) _validate_not_none('location', location) return self._perform_post( '/' + self.subscription_id + '/affinitygroups', _XmlSerializer.create_affinity_group_to_xml(name, label, description, location)) def update_affinity_group(self, affinity_group_name, label, description=None): ''' Updates the label and/or the description for an affinity group for the specified subscription. affinity_group_name: The name of the affinity group. label: A name for the affinity group. The name can be up to 100 characters in length. description: A description for the affinity group. The description can be up to 1024 characters in length. ''' _validate_not_none('affinity_group_name', affinity_group_name) _validate_not_none('label', label) return self._perform_put( '/' + self.subscription_id + '/affinitygroups/' + _str(affinity_group_name), _XmlSerializer.update_affinity_group_to_xml(label, description)) def delete_affinity_group(self, affinity_group_name): ''' Deletes an affinity group in the specified subscription. affinity_group_name: The name of the affinity group. ''' _validate_not_none('affinity_group_name', affinity_group_name) return self._perform_delete('/' + self.subscription_id + \ '/affinitygroups/' + \ _str(affinity_group_name)) #--Operations for locations ------------------------------------------ def list_locations(self): ''' Lists all of the data center locations that are valid for your subscription. ''' return self._perform_get('/' + self.subscription_id + '/locations', Locations) #--Operations for tracking asynchronous requests --------------------- def get_operation_status(self, request_id): ''' Returns the status of the specified operation. 
After calling an asynchronous operation, you can call Get Operation Status to determine whether the operation has succeeded, failed, or is still in progress. request_id: The request ID for the request you wish to track. ''' _validate_not_none('request_id', request_id) return self._perform_get( '/' + self.subscription_id + '/operations/' + _str(request_id), Operation) #--Operations for retrieving operating system information ------------ def list_operating_systems(self): ''' Lists the versions of the guest operating system that are currently available in Windows Azure. ''' return self._perform_get( '/' + self.subscription_id + '/operatingsystems', OperatingSystems) def list_operating_system_families(self): ''' Lists the guest operating system families available in Windows Azure, and also lists the operating system versions available for each family. ''' return self._perform_get( '/' + self.subscription_id + '/operatingsystemfamilies', OperatingSystemFamilies) #--Operations for retrieving subscription history -------------------- def get_subscription(self): ''' Returns account and resource allocation information on the specified subscription. ''' return self._perform_get('/' + self.subscription_id + '', Subscription) #--Operations for virtual machines ----------------------------------- def get_role(self, service_name, deployment_name, role_name): ''' Retrieves the specified virtual machine. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. 
''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) return self._perform_get( self._get_role_path(service_name, deployment_name, role_name), PersistentVMRole) def create_virtual_machine_deployment(self, service_name, deployment_name, deployment_slot, label, role_name, system_config, os_virtual_hard_disk, network_config=None, availability_set_name=None, data_virtual_hard_disks=None, role_size=None, role_type='PersistentVMRole', virtual_network_name=None): ''' Provisions a virtual machine based on the supplied configuration. service_name: Name of the hosted service. deployment_name: The name for the deployment. The deployment name must be unique among other deployments for the hosted service. deployment_slot: The environment to which the hosted service is deployed. Valid values are: staging, production label: Specifies an identifier for the deployment. The label can be up to 100 characters long. The label can be used for tracking purposes. role_name: The name of the role. system_config: Contains the metadata required to provision a virtual machine from a Windows or Linux OS image. Use an instance of WindowsConfigurationSet or LinuxConfigurationSet. os_virtual_hard_disk: Contains the parameters Windows Azure uses to create the operating system disk for the virtual machine. network_config: Encapsulates the metadata required to create the virtual network configuration for a virtual machine. If you do not include a network configuration set you will not be able to access the VM through VIPs over the internet. If your virtual machine belongs to a virtual network you can not specify which subnet address space it resides under. availability_set_name: Specifies the name of an availability set to which to add the virtual machine. This value controls the virtual machine allocation in the Windows Azure environment. 
Virtual machines specified in the same availability set are allocated to different nodes to maximize availability. data_virtual_hard_disks: Contains the parameters Windows Azure uses to create a data disk for a virtual machine. role_size: The size of the virtual machine to allocate. The default value is Small. Possible values are: ExtraSmall, Small, Medium, Large, ExtraLarge. The specified value must be compatible with the disk selected in the OSVirtualHardDisk values. role_type: The type of the role for the virtual machine. The only supported value is PersistentVMRole. virtual_network_name: Specifies the name of an existing virtual network to which the deployment will belong. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('deployment_slot', deployment_slot) _validate_not_none('label', label) _validate_not_none('role_name', role_name) _validate_not_none('system_config', system_config) _validate_not_none('os_virtual_hard_disk', os_virtual_hard_disk) return self._perform_post( self._get_deployment_path_using_name(service_name), _XmlSerializer.virtual_machine_deployment_to_xml( deployment_name, deployment_slot, label, role_name, system_config, os_virtual_hard_disk, role_type, network_config, availability_set_name, data_virtual_hard_disks, role_size, virtual_network_name), async=True) def add_role(self, service_name, deployment_name, role_name, system_config, os_virtual_hard_disk, network_config=None, availability_set_name=None, data_virtual_hard_disks=None, role_size=None, role_type='PersistentVMRole'): ''' Adds a virtual machine to an existing deployment. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. system_config: Contains the metadata required to provision a virtual machine from a Windows or Linux OS image. Use an instance of WindowsConfigurationSet or LinuxConfigurationSet. 
os_virtual_hard_disk: Contains the parameters Windows Azure uses to create the operating system disk for the virtual machine. network_config: Encapsulates the metadata required to create the virtual network configuration for a virtual machine. If you do not include a network configuration set you will not be able to access the VM through VIPs over the internet. If your virtual machine belongs to a virtual network you can not specify which subnet address space it resides under. availability_set_name: Specifies the name of an availability set to which to add the virtual machine. This value controls the virtual machine allocation in the Windows Azure environment. Virtual machines specified in the same availability set are allocated to different nodes to maximize availability. data_virtual_hard_disks: Contains the parameters Windows Azure uses to create a data disk for a virtual machine. role_size: The size of the virtual machine to allocate. The default value is Small. Possible values are: ExtraSmall, Small, Medium, Large, ExtraLarge. The specified value must be compatible with the disk selected in the OSVirtualHardDisk values. role_type: The type of the role for the virtual machine. The only supported value is PersistentVMRole. 
''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) _validate_not_none('system_config', system_config) _validate_not_none('os_virtual_hard_disk', os_virtual_hard_disk) return self._perform_post( self._get_role_path(service_name, deployment_name), _XmlSerializer.add_role_to_xml( role_name, system_config, os_virtual_hard_disk, role_type, network_config, availability_set_name, data_virtual_hard_disks, role_size), async=True) def update_role(self, service_name, deployment_name, role_name, os_virtual_hard_disk=None, network_config=None, availability_set_name=None, data_virtual_hard_disks=None, role_size=None, role_type='PersistentVMRole'): ''' Updates the specified virtual machine. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. os_virtual_hard_disk: Contains the parameters Windows Azure uses to create the operating system disk for the virtual machine. network_config: Encapsulates the metadata required to create the virtual network configuration for a virtual machine. If you do not include a network configuration set you will not be able to access the VM through VIPs over the internet. If your virtual machine belongs to a virtual network you can not specify which subnet address space it resides under. availability_set_name: Specifies the name of an availability set to which to add the virtual machine. This value controls the virtual machine allocation in the Windows Azure environment. Virtual machines specified in the same availability set are allocated to different nodes to maximize availability. data_virtual_hard_disks: Contains the parameters Windows Azure uses to create a data disk for a virtual machine. role_size: The size of the virtual machine to allocate. The default value is Small. Possible values are: ExtraSmall, Small, Medium, Large, ExtraLarge. 
The specified value must be compatible with the disk selected in the OSVirtualHardDisk values. role_type: The type of the role for the virtual machine. The only supported value is PersistentVMRole. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) return self._perform_put( self._get_role_path(service_name, deployment_name, role_name), _XmlSerializer.update_role_to_xml( role_name, os_virtual_hard_disk, role_type, network_config, availability_set_name, data_virtual_hard_disks, role_size), async=True) def delete_role(self, service_name, deployment_name, role_name): ''' Deletes the specified virtual machine. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) return self._perform_delete( self._get_role_path(service_name, deployment_name, role_name), async=True) def capture_role(self, service_name, deployment_name, role_name, post_capture_action, target_image_name, target_image_label, provisioning_configuration=None): ''' The Capture Role operation captures a virtual machine image to your image gallery. From the captured image, you can create additional customized virtual machines. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. post_capture_action: Specifies the action after capture operation completes. Possible values are: Delete, Reprovision. target_image_name: Specifies the image name of the captured virtual machine. target_image_label: Specifies the friendly name of the captured virtual machine. provisioning_configuration: Use an instance of WindowsConfigurationSet or LinuxConfigurationSet. 
''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) _validate_not_none('post_capture_action', post_capture_action) _validate_not_none('target_image_name', target_image_name) _validate_not_none('target_image_label', target_image_label) return self._perform_post( self._get_role_instance_operations_path( service_name, deployment_name, role_name), _XmlSerializer.capture_role_to_xml( post_capture_action, target_image_name, target_image_label, provisioning_configuration), async=True) def start_role(self, service_name, deployment_name, role_name): ''' Starts the specified virtual machine. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) return self._perform_post( self._get_role_instance_operations_path( service_name, deployment_name, role_name), _XmlSerializer.start_role_operation_to_xml(), async=True) def start_roles(self, service_name, deployment_name, role_names): ''' Starts the specified virtual machines. service_name: The name of the service. deployment_name: The name of the deployment. role_names: The names of the roles, as an enumerable of strings. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_names', role_names) return self._perform_post( self._get_roles_operations_path(service_name, deployment_name), _XmlSerializer.start_roles_operation_to_xml(role_names), async=True) def restart_role(self, service_name, deployment_name, role_name): ''' Restarts the specified virtual machine. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. 
''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) return self._perform_post( self._get_role_instance_operations_path( service_name, deployment_name, role_name), _XmlSerializer.restart_role_operation_to_xml( ), async=True) def shutdown_role(self, service_name, deployment_name, role_name, post_shutdown_action='Stopped'): ''' Shuts down the specified virtual machine. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. post_shutdown_action: Specifies how the Virtual Machine should be shut down. Values are: Stopped Shuts down the Virtual Machine but retains the compute resources. You will continue to be billed for the resources that the stopped machine uses. StoppedDeallocated Shuts down the Virtual Machine and releases the compute resources. You are not billed for the compute resources that this Virtual Machine uses. If a static Virtual Network IP address is assigned to the Virtual Machine, it is reserved. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) _validate_not_none('post_shutdown_action', post_shutdown_action) return self._perform_post( self._get_role_instance_operations_path( service_name, deployment_name, role_name), _XmlSerializer.shutdown_role_operation_to_xml(post_shutdown_action), async=True) def shutdown_roles(self, service_name, deployment_name, role_names, post_shutdown_action='Stopped'): ''' Shuts down the specified virtual machines. service_name: The name of the service. deployment_name: The name of the deployment. role_names: The names of the roles, as an enumerable of strings. post_shutdown_action: Specifies how the Virtual Machine should be shut down. Values are: Stopped Shuts down the Virtual Machine but retains the compute resources. 
You will continue to be billed for the resources that the stopped machine uses. StoppedDeallocated Shuts down the Virtual Machine and releases the compute resources. You are not billed for the compute resources that this Virtual Machine uses. If a static Virtual Network IP address is assigned to the Virtual Machine, it is reserved. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_names', role_names) _validate_not_none('post_shutdown_action', post_shutdown_action) return self._perform_post( self._get_roles_operations_path(service_name, deployment_name), _XmlSerializer.shutdown_roles_operation_to_xml( role_names, post_shutdown_action), async=True) #--Operations for virtual machine images ----------------------------- def list_os_images(self): ''' Retrieves a list of the OS images from the image repository. ''' return self._perform_get(self._get_image_path(), Images) def get_os_image(self, image_name): ''' Retrieves an OS image from the image repository. ''' return self._perform_get(self._get_image_path(image_name), OSImage) def add_os_image(self, label, media_link, name, os): ''' Adds an OS image that is currently stored in a storage account in your subscription to the image repository. label: Specifies the friendly name of the image. media_link: Specifies the location of the blob in Windows Azure blob store where the media for the image is located. The blob location must belong to a storage account in the subscription specified by the value in the operation call. Example: http://example.blob.core.windows.net/disks/mydisk.vhd name: Specifies a name for the OS image that Windows Azure uses to identify the image when creating one or more virtual machines. os: The operating system type of the OS image. 
Possible values are: Linux, Windows ''' _validate_not_none('label', label) _validate_not_none('media_link', media_link) _validate_not_none('name', name) _validate_not_none('os', os) return self._perform_post(self._get_image_path(), _XmlSerializer.os_image_to_xml( label, media_link, name, os), async=True) def update_os_image(self, image_name, label, media_link, name, os): ''' Updates an OS image that in your image repository. image_name: The name of the image to update. label: Specifies the friendly name of the image to be updated. You cannot use this operation to update images provided by the Windows Azure platform. media_link: Specifies the location of the blob in Windows Azure blob store where the media for the image is located. The blob location must belong to a storage account in the subscription specified by the value in the operation call. Example: http://example.blob.core.windows.net/disks/mydisk.vhd name: Specifies a name for the OS image that Windows Azure uses to identify the image when creating one or more VM Roles. os: The operating system type of the OS image. Possible values are: Linux, Windows ''' _validate_not_none('image_name', image_name) _validate_not_none('label', label) _validate_not_none('media_link', media_link) _validate_not_none('name', name) _validate_not_none('os', os) return self._perform_put(self._get_image_path(image_name), _XmlSerializer.os_image_to_xml( label, media_link, name, os), async=True) def delete_os_image(self, image_name, delete_vhd=False): ''' Deletes the specified OS image from your image repository. image_name: The name of the image. delete_vhd: Deletes the underlying vhd blob in Azure storage. 
''' _validate_not_none('image_name', image_name) path = self._get_image_path(image_name) if delete_vhd: path += '?comp=media' return self._perform_delete(path, async=True) #--Operations for virtual machine disks ------------------------------ def get_data_disk(self, service_name, deployment_name, role_name, lun): ''' Retrieves the specified data disk from a virtual machine. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. lun: The Logical Unit Number (LUN) for the disk. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) _validate_not_none('lun', lun) return self._perform_get( self._get_data_disk_path( service_name, deployment_name, role_name, lun), DataVirtualHardDisk) def add_data_disk(self, service_name, deployment_name, role_name, lun, host_caching=None, media_link=None, disk_label=None, disk_name=None, logical_disk_size_in_gb=None, source_media_link=None): ''' Adds a data disk to a virtual machine. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. lun: Specifies the Logical Unit Number (LUN) for the disk. The LUN specifies the slot in which the data drive appears when mounted for usage by the virtual machine. Valid LUN values are 0 through 15. host_caching: Specifies the platform caching behavior of data disk blob for read/write efficiency. The default vault is ReadOnly. Possible values are: None, ReadOnly, ReadWrite media_link: Specifies the location of the blob in Windows Azure blob store where the media for the disk is located. The blob location must belong to the storage account in the subscription specified by the value in the operation call. Example: http://example.blob.core.windows.net/disks/mydisk.vhd disk_label: Specifies the description of the data disk. 
When you attach a disk, either by directly referencing a media using the MediaLink element or specifying the target disk size, you can use the DiskLabel element to customize the name property of the target data disk. disk_name: Specifies the name of the disk. Windows Azure uses the specified disk to create the data disk for the machine and populates this field with the disk name. logical_disk_size_in_gb: Specifies the size, in GB, of an empty disk to be attached to the role. The disk can be created as part of disk attach or create VM role call by specifying the value for this property. Windows Azure creates the empty disk based on size preference and attaches the newly created disk to the Role. source_media_link: Specifies the location of a blob in account storage which is mounted as a data disk when the virtual machine is created. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) _validate_not_none('lun', lun) return self._perform_post( self._get_data_disk_path(service_name, deployment_name, role_name), _XmlSerializer.data_virtual_hard_disk_to_xml( host_caching, disk_label, disk_name, lun, logical_disk_size_in_gb, media_link, source_media_link), async=True) def update_data_disk(self, service_name, deployment_name, role_name, lun, host_caching=None, media_link=None, updated_lun=None, disk_label=None, disk_name=None, logical_disk_size_in_gb=None): ''' Updates the specified data disk attached to the specified virtual machine. service_name: The name of the service. deployment_name: The name of the deployment. role_name: The name of the role. lun: Specifies the Logical Unit Number (LUN) for the disk. The LUN specifies the slot in which the data drive appears when mounted for usage by the virtual machine. Valid LUN values are 0 through 15. host_caching: Specifies the platform caching behavior of data disk blob for read/write efficiency. The default vault is ReadOnly. 
Possible values are: None, ReadOnly, ReadWrite media_link: Specifies the location of the blob in Windows Azure blob store where the media for the disk is located. The blob location must belong to the storage account in the subscription specified by the value in the operation call. Example: http://example.blob.core.windows.net/disks/mydisk.vhd updated_lun: Specifies the Logical Unit Number (LUN) for the disk. The LUN specifies the slot in which the data drive appears when mounted for usage by the virtual machine. Valid LUN values are 0 through 15. disk_label: Specifies the description of the data disk. When you attach a disk, either by directly referencing a media using the MediaLink element or specifying the target disk size, you can use the DiskLabel element to customize the name property of the target data disk. disk_name: Specifies the name of the disk. Windows Azure uses the specified disk to create the data disk for the machine and populates this field with the disk name. logical_disk_size_in_gb: Specifies the size, in GB, of an empty disk to be attached to the role. The disk can be created as part of disk attach or create VM role call by specifying the value for this property. Windows Azure creates the empty disk based on size preference and attaches the newly created disk to the Role. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) _validate_not_none('lun', lun) return self._perform_put( self._get_data_disk_path( service_name, deployment_name, role_name, lun), _XmlSerializer.data_virtual_hard_disk_to_xml( host_caching, disk_label, disk_name, updated_lun, logical_disk_size_in_gb, media_link, None), async=True) def delete_data_disk(self, service_name, deployment_name, role_name, lun, delete_vhd=False): ''' Removes the specified data disk from a virtual machine. service_name: The name of the service. deployment_name: The name of the deployment. 
role_name: The name of the role. lun: The Logical Unit Number (LUN) for the disk. delete_vhd: Deletes the underlying vhd blob in Azure storage. ''' _validate_not_none('service_name', service_name) _validate_not_none('deployment_name', deployment_name) _validate_not_none('role_name', role_name) _validate_not_none('lun', lun) path = self._get_data_disk_path(service_name, deployment_name, role_name, lun) if delete_vhd: path += '?comp=media' return self._perform_delete(path, async=True) #--Operations for virtual machine disks ------------------------------ def list_disks(self): ''' Retrieves a list of the disks in your image repository. ''' return self._perform_get(self._get_disk_path(), Disks) def get_disk(self, disk_name): ''' Retrieves a disk from your image repository. ''' return self._perform_get(self._get_disk_path(disk_name), Disk) def add_disk(self, has_operating_system, label, media_link, name, os): ''' Adds a disk to the user image repository. The disk can be an OS disk or a data disk. has_operating_system: Specifies whether the disk contains an operation system. Only a disk with an operating system installed can be mounted as OS Drive. label: Specifies the description of the disk. media_link: Specifies the location of the blob in Windows Azure blob store where the media for the disk is located. The blob location must belong to the storage account in the current subscription specified by the value in the operation call. Example: http://example.blob.core.windows.net/disks/mydisk.vhd name: Specifies a name for the disk. Windows Azure uses the name to identify the disk when creating virtual machines from the disk. os: The OS type of the disk. 
Possible values are: Linux, Windows ''' _validate_not_none('has_operating_system', has_operating_system) _validate_not_none('label', label) _validate_not_none('media_link', media_link) _validate_not_none('name', name) _validate_not_none('os', os) return self._perform_post(self._get_disk_path(), _XmlSerializer.disk_to_xml( has_operating_system, label, media_link, name, os)) def update_disk(self, disk_name, has_operating_system, label, media_link, name, os): ''' Updates an existing disk in your image repository. disk_name: The name of the disk to update. has_operating_system: Specifies whether the disk contains an operation system. Only a disk with an operating system installed can be mounted as OS Drive. label: Specifies the description of the disk. media_link: Specifies the location of the blob in Windows Azure blob store where the media for the disk is located. The blob location must belong to the storage account in the current subscription specified by the value in the operation call. Example: http://example.blob.core.windows.net/disks/mydisk.vhd name: Specifies a name for the disk. Windows Azure uses the name to identify the disk when creating virtual machines from the disk. os: The OS type of the disk. Possible values are: Linux, Windows ''' _validate_not_none('disk_name', disk_name) _validate_not_none('has_operating_system', has_operating_system) _validate_not_none('label', label) _validate_not_none('media_link', media_link) _validate_not_none('name', name) _validate_not_none('os', os) return self._perform_put(self._get_disk_path(disk_name), _XmlSerializer.disk_to_xml( has_operating_system, label, media_link, name, os)) def delete_disk(self, disk_name, delete_vhd=False): ''' Deletes the specified data or operating system disk from your image repository. disk_name: The name of the disk to delete. delete_vhd: Deletes the underlying vhd blob in Azure storage. 
''' _validate_not_none('disk_name', disk_name) path = self._get_disk_path(disk_name) if delete_vhd: path += '?comp=media' return self._perform_delete(path) #--Operations for virtual networks ------------------------------ def list_virtual_network_sites(self): ''' Retrieves a list of the virtual networks. ''' return self._perform_get(self._get_virtual_network_site_path(), VirtualNetworkSites) #--Helper functions -------------------------------------------------- def _get_virtual_network_site_path(self): return self._get_path('services/networking/virtualnetwork', None) def _get_storage_service_path(self, service_name=None): return self._get_path('services/storageservices', service_name) def _get_hosted_service_path(self, service_name=None): return self._get_path('services/hostedservices', service_name) def _get_deployment_path_using_slot(self, service_name, slot=None): return self._get_path('services/hostedservices/' + _str(service_name) + '/deploymentslots', slot) def _get_deployment_path_using_name(self, service_name, deployment_name=None): return self._get_path('services/hostedservices/' + _str(service_name) + '/deployments', deployment_name) def _get_role_path(self, service_name, deployment_name, role_name=None): return self._get_path('services/hostedservices/' + _str(service_name) + '/deployments/' + deployment_name + '/roles', role_name) def _get_role_instance_operations_path(self, service_name, deployment_name, role_name=None): return self._get_path('services/hostedservices/' + _str(service_name) + '/deployments/' + deployment_name + '/roleinstances', role_name) + '/Operations' def _get_roles_operations_path(self, service_name, deployment_name): return self._get_path('services/hostedservices/' + _str(service_name) + '/deployments/' + deployment_name + '/roles/Operations', None) def _get_data_disk_path(self, service_name, deployment_name, role_name, lun=None): return self._get_path('services/hostedservices/' + _str(service_name) + '/deployments/' + 
_str(deployment_name) + '/roles/' + _str(role_name) + '/DataDisks', lun) def _get_disk_path(self, disk_name=None): return self._get_path('services/disks', disk_name) def _get_image_path(self, image_name=None): return self._get_path('services/images', image_name) ================================================ FILE: OSPatching/azure/servicemanagement/sqldatabasemanagementservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- from azure import ( MANAGEMENT_HOST, _parse_service_resources_response, ) from azure.servicemanagement import ( Servers, Database, ) from azure.servicemanagement.servicemanagementclient import ( _ServiceManagementClient, ) class SqlDatabaseManagementService(_ServiceManagementClient): ''' Note that this class is a preliminary work on SQL Database management. Since it lack a lot a features, final version can be slightly different from the current one. ''' def __init__(self, subscription_id=None, cert_file=None, host=MANAGEMENT_HOST): super(SqlDatabaseManagementService, self).__init__( subscription_id, cert_file, host) #--Operations for sql servers ---------------------------------------- def list_servers(self): ''' List the SQL servers defined on the account. 
''' return self._perform_get(self._get_list_servers_path(), Servers) #--Operations for sql databases ---------------------------------------- def list_databases(self, name): ''' List the SQL databases defined on the specified server name ''' response = self._perform_get(self._get_list_databases_path(name), None) return _parse_service_resources_response(response, Database) #--Helper functions -------------------------------------------------- def _get_list_servers_path(self): return self._get_path('services/sqlservers/servers', None) def _get_list_databases_path(self, name): # *contentview=generic is mandatory* return self._get_path('services/sqlservers/servers/', name) + '/databases?contentview=generic' ================================================ FILE: OSPatching/azure/servicemanagement/websitemanagementservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#-------------------------------------------------------------------------- from azure import ( MANAGEMENT_HOST, _str, ) from azure.servicemanagement import ( WebSpaces, WebSpace, Sites, Site, MetricResponses, MetricDefinitions, PublishData, _XmlSerializer, ) from azure.servicemanagement.servicemanagementclient import ( _ServiceManagementClient, ) class WebsiteManagementService(_ServiceManagementClient): ''' Note that this class is a preliminary work on WebSite management. Since it lack a lot a features, final version can be slightly different from the current one. ''' def __init__(self, subscription_id=None, cert_file=None, host=MANAGEMENT_HOST): super(WebsiteManagementService, self).__init__( subscription_id, cert_file, host) #--Operations for web sites ---------------------------------------- def list_webspaces(self): ''' List the webspaces defined on the account. ''' return self._perform_get(self._get_list_webspaces_path(), WebSpaces) def get_webspace(self, webspace_name): ''' Get details of a specific webspace. webspace_name: The name of the webspace. ''' return self._perform_get(self._get_webspace_details_path(webspace_name), WebSpace) def list_sites(self, webspace_name): ''' List the web sites defined on this webspace. webspace_name: The name of the webspace. ''' return self._perform_get(self._get_sites_path(webspace_name), Sites) def get_site(self, webspace_name, website_name): ''' List the web sites defined on this webspace. webspace_name: The name of the webspace. website_name: The name of the website. ''' return self._perform_get(self._get_sites_details_path(webspace_name, website_name), Site) def create_site(self, webspace_name, website_name, geo_region, host_names, plan='VirtualDedicatedPlan', compute_mode='Shared', server_farm=None, site_mode=None): ''' Create a website. webspace_name: The name of the webspace. website_name: The name of the website. geo_region: The geographical region of the webspace that will be created. 
host_names: An array of fully qualified domain names for website. Only one hostname can be specified in the azurewebsites.net domain. The hostname should match the name of the website. Custom domains can only be specified for Shared or Standard websites. plan: This value must be 'VirtualDedicatedPlan'. compute_mode: This value should be 'Shared' for the Free or Paid Shared offerings, or 'Dedicated' for the Standard offering. The default value is 'Shared'. If you set it to 'Dedicated', you must specify a value for the server_farm parameter. server_farm: The name of the Server Farm associated with this website. This is a required value for Standard mode. site_mode: Can be None, 'Limited' or 'Basic'. This value is 'Limited' for the Free offering, and 'Basic' for the Paid Shared offering. Standard mode does not use the site_mode parameter; it uses the compute_mode parameter. ''' xml = _XmlSerializer.create_website_to_xml(webspace_name, website_name, geo_region, plan, host_names, compute_mode, server_farm, site_mode) return self._perform_post( self._get_sites_path(webspace_name), xml, Site) def delete_site(self, webspace_name, website_name, delete_empty_server_farm=False, delete_metrics=False): ''' Delete a website. webspace_name: The name of the webspace. website_name: The name of the website. delete_empty_server_farm: If the site being deleted is the last web site in a server farm, you can delete the server farm by setting this to True. delete_metrics: To also delete the metrics for the site that you are deleting, you can set this to True. ''' path = self._get_sites_details_path(webspace_name, website_name) query = '' if delete_empty_server_farm: query += '&deleteEmptyServerFarm=true' if delete_metrics: query += '&deleteMetrics=true' if query: path = path + '?' + query.lstrip('&') return self._perform_delete(path) def restart_site(self, webspace_name, website_name): ''' Restart a web site. webspace_name: The name of the webspace. website_name: The name of the website. 
''' return self._perform_post( self._get_restart_path(webspace_name, website_name), '') def get_historical_usage_metrics(self, webspace_name, website_name, metrics = None, start_time=None, end_time=None, time_grain=None): ''' Get historical usage metrics. webspace_name: The name of the webspace. website_name: The name of the website. metrics: Optional. List of metrics name. Otherwise, all metrics returned. start_time: Optional. An ISO8601 date. Otherwise, current hour is used. end_time: Optional. An ISO8601 date. Otherwise, current time is used. time_grain: Optional. A rollup name, as P1D. OTherwise, default rollup for the metrics is used. More information and metrics name at: http://msdn.microsoft.com/en-us/library/azure/dn166964.aspx ''' metrics = ('names='+','.join(metrics)) if metrics else '' start_time = ('StartTime='+start_time) if start_time else '' end_time = ('EndTime='+end_time) if end_time else '' time_grain = ('TimeGrain='+time_grain) if time_grain else '' parameters = ('&'.join(v for v in (metrics, start_time, end_time, time_grain) if v)) parameters = '?'+parameters if parameters else '' return self._perform_get(self._get_historical_usage_metrics_path(webspace_name, website_name) + parameters, MetricResponses) def get_metric_definitions(self, webspace_name, website_name): ''' Get metric definitions of metrics available of this web site. webspace_name: The name of the webspace. website_name: The name of the website. ''' return self._perform_get(self._get_metric_definitions_path(webspace_name, website_name), MetricDefinitions) def get_publish_profile_xml(self, webspace_name, website_name): ''' Get a site's publish profile as a string webspace_name: The name of the webspace. website_name: The name of the website. 
''' return self._perform_get(self._get_publishxml_path(webspace_name, website_name), None).body.decode("utf-8") def get_publish_profile(self, webspace_name, website_name): ''' Get a site's publish profile as an object webspace_name: The name of the webspace. website_name: The name of the website. ''' return self._perform_get(self._get_publishxml_path(webspace_name, website_name), PublishData) #--Helper functions -------------------------------------------------- def _get_list_webspaces_path(self): return self._get_path('services/webspaces', None) def _get_webspace_details_path(self, webspace_name): return self._get_path('services/webspaces/', webspace_name) def _get_sites_path(self, webspace_name): return self._get_path('services/webspaces/', webspace_name) + '/sites' def _get_sites_details_path(self, webspace_name, website_name): return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) def _get_restart_path(self, webspace_name, website_name): return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) + '/restart/' def _get_historical_usage_metrics_path(self, webspace_name, website_name): return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) + '/metrics/' def _get_metric_definitions_path(self, webspace_name, website_name): return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) + '/metricdefinitions/' def _get_publishxml_path(self, webspace_name, website_name): return self._get_path('services/webspaces/', webspace_name) + '/sites/' + _str(website_name) + '/publishxml/' ================================================ FILE: OSPatching/azure/storage/__init__.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- import sys import types from datetime import datetime from xml.dom import minidom from azure import (WindowsAzureData, WindowsAzureError, METADATA_NS, xml_escape, _create_entry, _decode_base64_to_text, _decode_base64_to_bytes, _encode_base64, _fill_data_minidom, _fill_instance_element, _get_child_nodes, _get_child_nodesNS, _get_children_from_path, _get_entry_properties, _general_error_handler, _list_of, _parse_response_for_dict, _sign_string, _unicode_type, _ERROR_CANNOT_SERIALIZE_VALUE_TO_ENTITY, ) # x-ms-version for storage service. X_MS_VERSION = '2012-02-12' class EnumResultsBase(object): ''' base class for EnumResults. ''' def __init__(self): self.prefix = u'' self.marker = u'' self.max_results = 0 self.next_marker = u'' class ContainerEnumResults(EnumResultsBase): ''' Blob Container list. ''' def __init__(self): EnumResultsBase.__init__(self) self.containers = _list_of(Container) def __iter__(self): return iter(self.containers) def __len__(self): return len(self.containers) def __getitem__(self, index): return self.containers[index] class Container(WindowsAzureData): ''' Blob container class. ''' def __init__(self): self.name = u'' self.url = u'' self.properties = Properties() self.metadata = {} class Properties(WindowsAzureData): ''' Blob container's properties class. 
''' def __init__(self): self.last_modified = u'' self.etag = u'' class RetentionPolicy(WindowsAzureData): ''' RetentionPolicy in service properties. ''' def __init__(self): self.enabled = False self.__dict__['days'] = None def get_days(self): # convert days to int value return int(self.__dict__['days']) def set_days(self, value): ''' set default days if days is set to empty. ''' self.__dict__['days'] = value days = property(fget=get_days, fset=set_days) class Logging(WindowsAzureData): ''' Logging class in service properties. ''' def __init__(self): self.version = u'1.0' self.delete = False self.read = False self.write = False self.retention_policy = RetentionPolicy() class Metrics(WindowsAzureData): ''' Metrics class in service properties. ''' def __init__(self): self.version = u'1.0' self.enabled = False self.include_apis = None self.retention_policy = RetentionPolicy() class StorageServiceProperties(WindowsAzureData): ''' Storage Service Propeties class. ''' def __init__(self): self.logging = Logging() self.metrics = Metrics() class AccessPolicy(WindowsAzureData): ''' Access Policy class in service properties. ''' def __init__(self, start=u'', expiry=u'', permission='u'): self.start = start self.expiry = expiry self.permission = permission class SignedIdentifier(WindowsAzureData): ''' Signed Identifier class for service properties. ''' def __init__(self): self.id = u'' self.access_policy = AccessPolicy() class SignedIdentifiers(WindowsAzureData): ''' SignedIdentifier list. 
''' def __init__(self): self.signed_identifiers = _list_of(SignedIdentifier) def __iter__(self): return iter(self.signed_identifiers) def __len__(self): return len(self.signed_identifiers) def __getitem__(self, index): return self.signed_identifiers[index] class BlobEnumResults(EnumResultsBase): ''' Blob list.''' def __init__(self): EnumResultsBase.__init__(self) self.blobs = _list_of(Blob) self.prefixes = _list_of(BlobPrefix) self.delimiter = '' def __iter__(self): return iter(self.blobs) def __len__(self): return len(self.blobs) def __getitem__(self, index): return self.blobs[index] class BlobResult(bytes): def __new__(cls, blob, properties): return bytes.__new__(cls, blob if blob else b'') def __init__(self, blob, properties): self.properties = properties class Blob(WindowsAzureData): ''' Blob class. ''' def __init__(self): self.name = u'' self.snapshot = u'' self.url = u'' self.properties = BlobProperties() self.metadata = {} class BlobProperties(WindowsAzureData): ''' Blob Properties ''' def __init__(self): self.last_modified = u'' self.etag = u'' self.content_length = 0 self.content_type = u'' self.content_encoding = u'' self.content_language = u'' self.content_md5 = u'' self.xms_blob_sequence_number = 0 self.blob_type = u'' self.lease_status = u'' self.lease_state = u'' self.lease_duration = u'' self.copy_id = u'' self.copy_source = u'' self.copy_status = u'' self.copy_progress = u'' self.copy_completion_time = u'' self.copy_status_description = u'' class BlobPrefix(WindowsAzureData): ''' BlobPrefix in Blob. ''' def __init__(self): self.name = '' class BlobBlock(WindowsAzureData): ''' BlobBlock class ''' def __init__(self, id=None, size=None): self.id = id self.size = size class BlobBlockList(WindowsAzureData): ''' BlobBlockList class ''' def __init__(self): self.committed_blocks = [] self.uncommitted_blocks = [] class PageRange(WindowsAzureData): ''' Page Range for page blob. 
''' def __init__(self): self.start = 0 self.end = 0 class PageList(object): ''' Page list for page blob. ''' def __init__(self): self.page_ranges = _list_of(PageRange) def __iter__(self): return iter(self.page_ranges) def __len__(self): return len(self.page_ranges) def __getitem__(self, index): return self.page_ranges[index] class QueueEnumResults(EnumResultsBase): ''' Queue list''' def __init__(self): EnumResultsBase.__init__(self) self.queues = _list_of(Queue) def __iter__(self): return iter(self.queues) def __len__(self): return len(self.queues) def __getitem__(self, index): return self.queues[index] class Queue(WindowsAzureData): ''' Queue class ''' def __init__(self): self.name = u'' self.url = u'' self.metadata = {} class QueueMessagesList(WindowsAzureData): ''' Queue message list. ''' def __init__(self): self.queue_messages = _list_of(QueueMessage) def __iter__(self): return iter(self.queue_messages) def __len__(self): return len(self.queue_messages) def __getitem__(self, index): return self.queue_messages[index] class QueueMessage(WindowsAzureData): ''' Queue message class. ''' def __init__(self): self.message_id = u'' self.insertion_time = u'' self.expiration_time = u'' self.pop_receipt = u'' self.time_next_visible = u'' self.dequeue_count = u'' self.message_text = u'' class Entity(WindowsAzureData): ''' Entity class. The attributes of entity will be created dynamically. ''' pass class EntityProperty(WindowsAzureData): ''' Entity property. contains type and value. ''' def __init__(self, type=None, value=None): self.type = type self.value = value class Table(WindowsAzureData): ''' Only for intellicens and telling user the return type. 
''' pass def _parse_blob_enum_results_list(response): respbody = response.body return_obj = BlobEnumResults() doc = minidom.parseString(respbody) for enum_results in _get_child_nodes(doc, 'EnumerationResults'): for child in _get_children_from_path(enum_results, 'Blobs', 'Blob'): return_obj.blobs.append(_fill_instance_element(child, Blob)) for child in _get_children_from_path(enum_results, 'Blobs', 'BlobPrefix'): return_obj.prefixes.append( _fill_instance_element(child, BlobPrefix)) for name, value in vars(return_obj).items(): if name == 'blobs' or name == 'prefixes': continue value = _fill_data_minidom(enum_results, name, value) if value is not None: setattr(return_obj, name, value) return return_obj def _update_storage_header(request): ''' add additional headers for storage request. ''' if request.body: assert isinstance(request.body, bytes) # if it is PUT, POST, MERGE, DELETE, need to add content-lengt to header. if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']: request.headers.append(('Content-Length', str(len(request.body)))) # append addtional headers base on the service request.headers.append(('x-ms-version', X_MS_VERSION)) # append x-ms-meta name, values to header for name, value in request.headers: if 'x-ms-meta-name-values' in name and value: for meta_name, meta_value in value.items(): request.headers.append(('x-ms-meta-' + meta_name, meta_value)) request.headers.remove((name, value)) break return request def _update_storage_blob_header(request, account_name, account_key): ''' add additional headers for storage blob request. 
''' request = _update_storage_header(request) current_time = datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT') request.headers.append(('x-ms-date', current_time)) request.headers.append( ('Content-Type', 'application/octet-stream Charset=UTF-8')) request.headers.append(('Authorization', _sign_storage_blob_request(request, account_name, account_key))) return request.headers def _update_storage_queue_header(request, account_name, account_key): ''' add additional headers for storage queue request. ''' return _update_storage_blob_header(request, account_name, account_key) def _update_storage_table_header(request): ''' add additional headers for storage table request. ''' request = _update_storage_header(request) for name, _ in request.headers: if name.lower() == 'content-type': break else: request.headers.append(('Content-Type', 'application/atom+xml')) request.headers.append(('DataServiceVersion', '2.0;NetFx')) request.headers.append(('MaxDataServiceVersion', '2.0;NetFx')) current_time = datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT') request.headers.append(('x-ms-date', current_time)) request.headers.append(('Date', current_time)) return request.headers def _sign_storage_blob_request(request, account_name, account_key): ''' Returns the signed string for blob request which is used to set Authorization header. This is also used to sign queue request. 
''' uri_path = request.path.split('?')[0] # method to sign string_to_sign = request.method + '\n' # get headers to sign headers_to_sign = [ 'content-encoding', 'content-language', 'content-length', 'content-md5', 'content-type', 'date', 'if-modified-since', 'if-match', 'if-none-match', 'if-unmodified-since', 'range'] request_header_dict = dict((name.lower(), value) for name, value in request.headers if value) string_to_sign += '\n'.join(request_header_dict.get(x, '') for x in headers_to_sign) + '\n' # get x-ms header to sign x_ms_headers = [] for name, value in request.headers: if 'x-ms' in name: x_ms_headers.append((name.lower(), value)) x_ms_headers.sort() for name, value in x_ms_headers: if value: string_to_sign += ''.join([name, ':', value, '\n']) # get account_name and uri path to sign string_to_sign += '/' + account_name + uri_path # get query string to sign if it is not table service query_to_sign = request.query query_to_sign.sort() current_name = '' for name, value in query_to_sign: if value: if current_name != name: string_to_sign += '\n' + name + ':' + value else: string_to_sign += '\n' + ',' + value # sign the request auth_string = 'SharedKey ' + account_name + ':' + \ _sign_string(account_key, string_to_sign) return auth_string def _sign_storage_table_request(request, account_name, account_key): uri_path = request.path.split('?')[0] string_to_sign = request.method + '\n' headers_to_sign = ['content-md5', 'content-type', 'date'] request_header_dict = dict((name.lower(), value) for name, value in request.headers if value) string_to_sign += '\n'.join(request_header_dict.get(x, '') for x in headers_to_sign) + '\n' # get account_name and uri path to sign string_to_sign += ''.join(['/', account_name, uri_path]) for name, value in request.query: if name == 'comp' and uri_path == '/': string_to_sign += '?comp=' + value break # sign the request auth_string = 'SharedKey ' + account_name + ':' + \ _sign_string(account_key, string_to_sign) return auth_string def 
_to_python_bool(value): if value.lower() == 'true': return True return False def _to_entity_int(data): int_max = (2 << 30) - 1 if data > (int_max) or data < (int_max + 1) * (-1): return 'Edm.Int64', str(data) else: return 'Edm.Int32', str(data) def _to_entity_bool(value): if value: return 'Edm.Boolean', 'true' return 'Edm.Boolean', 'false' def _to_entity_datetime(value): return 'Edm.DateTime', value.strftime('%Y-%m-%dT%H:%M:%S') def _to_entity_float(value): return 'Edm.Double', str(value) def _to_entity_property(value): if value.type == 'Edm.Binary': return value.type, _encode_base64(value.value) return value.type, str(value.value) def _to_entity_none(value): return None, None def _to_entity_str(value): return 'Edm.String', value # Tables of conversions to and from entity types. We support specific # datatypes, and beyond that the user can use an EntityProperty to get # custom data type support. def _from_entity_binary(value): return EntityProperty('Edm.Binary', _decode_base64_to_bytes(value)) def _from_entity_int(value): return int(value) def _from_entity_datetime(value): format = '%Y-%m-%dT%H:%M:%S' if '.' in value: format = format + '.%f' if value.endswith('Z'): format = format + 'Z' return datetime.strptime(value, format) _ENTITY_TO_PYTHON_CONVERSIONS = { 'Edm.Binary': _from_entity_binary, 'Edm.Int32': _from_entity_int, 'Edm.Int64': _from_entity_int, 'Edm.Double': float, 'Edm.Boolean': _to_python_bool, 'Edm.DateTime': _from_entity_datetime, } # Conversion from Python type to a function which returns a tuple of the # type string and content string. 
_PYTHON_TO_ENTITY_CONVERSIONS = { int: _to_entity_int, bool: _to_entity_bool, datetime: _to_entity_datetime, float: _to_entity_float, EntityProperty: _to_entity_property, str: _to_entity_str, } if sys.version_info < (3,): _PYTHON_TO_ENTITY_CONVERSIONS.update({ long: _to_entity_int, types.NoneType: _to_entity_none, unicode: _to_entity_str, }) def _convert_entity_to_xml(source): ''' Converts an entity object to xml to send. The entity format is: <updated>2008-09-18T23:46:19.3857256Z</updated> <author> <name /> </author> <id /> <content type="application/xml"> <m:properties> <d:Address>Mountain View</d:Address> <d:Age m:type="Edm.Int32">23</d:Age> <d:AmountDue m:type="Edm.Double">200.23</d:AmountDue> <d:BinaryData m:type="Edm.Binary" m:null="true" /> <d:CustomerCode m:type="Edm.Guid">c9da6455-213d-42c9-9a79-3e9149a57833</d:CustomerCode> <d:CustomerSince m:type="Edm.DateTime">2008-07-10T00:00:00</d:CustomerSince> <d:IsActive m:type="Edm.Boolean">true</d:IsActive> <d:NumOfOrders m:type="Edm.Int64">255</d:NumOfOrders> <d:PartitionKey>mypartitionkey</d:PartitionKey> <d:RowKey>myrowkey1</d:RowKey> <d:Timestamp m:type="Edm.DateTime">0001-01-01T00:00:00</d:Timestamp> </m:properties> </content> </entry> ''' # construct the entity body included in <m:properties> and </m:properties> entity_body = '<m:properties xml:space="preserve">{properties}</m:properties>' if isinstance(source, WindowsAzureData): source = vars(source) properties_str = '' # set properties type for types we know if value has no type info. 
# --- tail of _convert_entity_to_xml (its def line is above this chunk) ---
# Serializes each property of `source` into an Atom <d:name> element,
# tagging the EDM type via m:type and null values via m:null.
    # if value has type info, then set the type to value.type
    for name, value in source.items():
        mtype = ''
        conv = _PYTHON_TO_ENTITY_CONVERSIONS.get(type(value))
        if conv is None and sys.version_info >= (3,) and value is None:
            # On Python 3 a bare None has no converter entry; map it to an
            # untyped null property.
            conv = _to_entity_none
        if conv is None:
            raise WindowsAzureError(
                _ERROR_CANNOT_SERIALIZE_VALUE_TO_ENTITY.format(
                    type(value).__name__))
        mtype, value = conv(value)

        # form the property node
        properties_str += ''.join(['<d:', name])
        if value is None:
            properties_str += ' m:null="true" />'
        else:
            if mtype:
                properties_str += ''.join([' m:type="', mtype, '"'])
            properties_str += ''.join(['>', xml_escape(value),
                                       '</d:', name, '>'])

    if sys.version_info < (3,):
        # Python 2 only: the request body must be UTF-8 bytes, not unicode.
        if isinstance(properties_str, unicode):
            properties_str = properties_str.encode('utf-8')

    # generate the entity_body
    entity_body = entity_body.format(properties=properties_str)
    xmlstr = _create_entry(entity_body)
    return xmlstr


def _convert_table_to_xml(table_name):
    '''
    Create xml to send for a given table name. Since xml format for table is
    the same as entity and the only difference is that table has only one
    property 'TableName', so we just call _convert_entity_to_xml.

    table_name: the name of the table
    '''
    return _convert_entity_to_xml({'TableName': table_name})


def _convert_block_list_to_xml(block_id_list):
    '''
    Convert a block list to xml to send.

    block_id_list:
        a str list containing the block ids that are used in put_block_list.
        Only get block from latest blocks.
    '''
    if block_id_list is None:
        return ''
    xml = '<?xml version="1.0" encoding="utf-8"?><BlockList>'
    for value in block_id_list:
        # Block ids are base64-encoded in the request body.
        xml += '<Latest>{0}</Latest>'.format(_encode_base64(value))

    return xml + '</BlockList>'


def _create_blob_result(response):
    # Pair the raw response body (blob content) with the parsed response
    # headers (blob properties) in a single BlobResult.
    blob_properties = _parse_response_for_dict(response)
    return BlobResult(response.body, blob_properties)


def _convert_response_to_block_list(response):
    '''
    Converts xml response to block list class.
    '''
    blob_block_list = BlobBlockList()

    xmldoc = minidom.parseString(response.body)
    # Committed blocks: <BlockList><CommittedBlocks><Block> entries.
    for xml_block in _get_children_from_path(xmldoc,
                                             'BlockList',
                                             'CommittedBlocks',
                                             'Block'):
        xml_block_id = _decode_base64_to_text(
            _get_child_nodes(xml_block, 'Name')[0].firstChild.nodeValue)
        xml_block_size = int(
            _get_child_nodes(xml_block, 'Size')[0].firstChild.nodeValue)
        blob_block_list.committed_blocks.append(
            BlobBlock(xml_block_id, xml_block_size))

    # Uncommitted blocks: <BlockList><UncommittedBlocks><Block> entries.
    for xml_block in _get_children_from_path(xmldoc,
                                             'BlockList',
                                             'UncommittedBlocks',
                                             'Block'):
        xml_block_id = _decode_base64_to_text(
            _get_child_nodes(xml_block, 'Name')[0].firstChild.nodeValue)
        xml_block_size = int(
            _get_child_nodes(xml_block, 'Size')[0].firstChild.nodeValue)
        blob_block_list.uncommitted_blocks.append(
            BlobBlock(xml_block_id, xml_block_size))

    return blob_block_list


def _remove_prefix(name):
    # Strip a namespace prefix such as 'd:' from an element name; a name
    # without a colon is returned unchanged.
    colon = name.find(':')
    if colon != -1:
        return name[colon + 1:]
    return name


def _convert_response_to_entity(response):
    # Tolerate a missing response (e.g. nothing to convert).
    if response is None:
        return response

    return _convert_xml_to_entity(response.body)


def _convert_xml_to_entity(xmlstr):
    ''' Convert xml response to entity.

    The format of entity:
    <entry xmlns:d="http://schemas.microsoft.com/ado/2007/08/dataservices"
           xmlns:m="http://schemas.microsoft.com/ado/2007/08/dataservices/metadata"
           xmlns="http://www.w3.org/2005/Atom">
      <title />
      <updated>2008-09-18T23:46:19.3857256Z</updated>
      <author><name /></author>
      <id />
      <content type="application/xml">
        <m:properties>
          <d:Address>Mountain View</d:Address>
          <d:Age m:type="Edm.Int32">23</d:Age>
          <d:AmountDue m:type="Edm.Double">200.23</d:AmountDue>
          <d:BinaryData m:type="Edm.Binary" m:null="true" />
          <d:CustomerSince m:type="Edm.DateTime">2008-07-10T00:00:00</d:CustomerSince>
          <d:IsActive m:type="Edm.Boolean">true</d:IsActive>
          <d:NumOfOrders m:type="Edm.Int64">255</d:NumOfOrders>
          <d:PartitionKey>mypartitionkey</d:PartitionKey>
          <d:RowKey>myrowkey1</d:RowKey>
          <d:Timestamp m:type="Edm.DateTime">0001-01-01T00:00:00</d:Timestamp>
        </m:properties>
      </content>
    </entry>
    '''
    xmldoc = minidom.parseString(xmlstr)

    xml_properties = None
    for entry in _get_child_nodes(xmldoc, 'entry'):
        for content in _get_child_nodes(entry, 'content'):
            # TODO: Namespace
            xml_properties = _get_child_nodesNS(
                content, METADATA_NS, 'properties')

    if not xml_properties:
        return None

    entity = Entity()
    # extract each property node and get the type from attribute and node
    # value
    for xml_property in xml_properties[0].childNodes:
        name = _remove_prefix(xml_property.nodeName)
        # exclude the Timestamp since it is auto added by azure when
        # inserting entity. We don't want this to mix with real properties
        if name in ['Timestamp']:
            continue

        if xml_property.firstChild:
            value = xml_property.firstChild.nodeValue
        else:
            value = ''

        isnull = xml_property.getAttributeNS(METADATA_NS, 'null')
        mtype = xml_property.getAttributeNS(METADATA_NS, 'type')

        # if not isnull and no type info, then it is a string and we just
        # need the str type to hold the property.
        if not isnull and not mtype:
            _set_entity_attr(entity, name, value)
        elif isnull == 'true':
            # NOTE(review): reconstructed from whitespace-collapsed source —
            # here `property` is built for the null case but not attached via
            # _set_entity_attr, so null properties appear to be dropped.
            # Confirm against the original indentation before "fixing".
            if mtype:
                property = EntityProperty(mtype, None)
            else:
                property = EntityProperty('Edm.String', None)
        else:  # need an object to hold the property
            conv = _ENTITY_TO_PYTHON_CONVERSIONS.get(mtype)
            if conv is not None:
                property = conv(value)
            else:
                property = EntityProperty(mtype, value)
            _set_entity_attr(entity, name, property)

    # extract id, updated and name value from feed entry and set them of
    # rule.
    for name, value in _get_entry_properties(xmlstr, True).items():
        # Only the etag feed property is copied onto the entity.
        if name in ['etag']:
            _set_entity_attr(entity, name, value)

    return entity


def _set_entity_attr(entity, name, value):
    try:
        setattr(entity, name, value)
    except UnicodeEncodeError:
        # Python 2 doesn't support unicode attribute names, so we'll
        # add them and access them directly through the dictionary
        entity.__dict__[name] = value


def _convert_xml_to_table(xmlstr):
    '''
    Converts the xml response to table class.
    Simply call convert_xml_to_entity and extract the table name, and add
    updated and author info
    '''
    table = Table()
    entity = _convert_xml_to_entity(xmlstr)
    setattr(table, 'name', entity.TableName)
    for name, value in _get_entry_properties(xmlstr, False).items():
        setattr(table, name, value)
    return table


def _storage_error_handler(http_error):
    ''' Simple error handler for storage service. '''
    return _general_error_handler(http_error)

# make these available just from storage.
from azure.storage.blobservice import BlobService from azure.storage.queueservice import QueueService from azure.storage.tableservice import TableService from azure.storage.cloudstorageaccount import CloudStorageAccount from azure.storage.sharedaccesssignature import ( SharedAccessSignature, SharedAccessPolicy, Permission, WebResource, ) ================================================ FILE: OSPatching/azure/storage/blobservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#--------------------------------------------------------------------------
from azure import (
    WindowsAzureError,
    BLOB_SERVICE_HOST_BASE,
    DEV_BLOB_HOST,
    _ERROR_VALUE_NEGATIVE,
    _ERROR_PAGE_BLOB_SIZE_ALIGNMENT,
    _convert_class_to_xml,
    _dont_fail_not_exist,
    _dont_fail_on_exist,
    _encode_base64,
    _get_request_body,
    _get_request_body_bytes_only,
    _int_or_none,
    _parse_enum_results_list,
    _parse_response,
    _parse_response_for_dict,
    _parse_response_for_dict_filter,
    _parse_response_for_dict_prefix,
    _parse_simple_list,
    _str,
    _str_or_none,
    _update_request_uri_query_local_storage,
    _validate_type_bytes,
    _validate_not_none,
    )
from azure.http import HTTPRequest
from azure.storage import (
    Container,
    ContainerEnumResults,
    PageList,
    PageRange,
    SignedIdentifiers,
    StorageServiceProperties,
    _convert_block_list_to_xml,
    _convert_response_to_block_list,
    _create_blob_result,
    _parse_blob_enum_results_list,
    _update_storage_blob_header,
    )
from azure.storage.storageclient import _StorageClient
from os import path
import sys

# Python 2/3 compatibility shim for an in-memory bytes buffer.
if sys.version_info >= (3,):
    from io import BytesIO
else:
    from cStringIO import StringIO as BytesIO

# Keep this value sync with _ERROR_PAGE_BLOB_SIZE_ALIGNMENT
_PAGE_SIZE = 512


class BlobService(_StorageClient):

    '''
    This is the main class managing Blob resources.
    '''

    def __init__(self, account_name=None, account_key=None, protocol='https',
                 host_base=BLOB_SERVICE_HOST_BASE, dev_host=DEV_BLOB_HOST):
        '''
        account_name: your storage account name, required for all operations.
        account_key: your storage account key, required for all operations.
        protocol: Optional. Protocol. Defaults to https.
        host_base:
            Optional. Live host base url. Defaults to Azure url. Override this
            for on-premise.
        dev_host: Optional. Dev host url. Defaults to localhost.
        '''
        # Upper bound for a single put_blob payload; larger uploads go
        # through the put_block/put_block_list chunked path.
        self._BLOB_MAX_DATA_SIZE = 64 * 1024 * 1024
        # Chunk size used by the put_*_from_* convenience uploaders.
        self._BLOB_MAX_CHUNK_DATA_SIZE = 4 * 1024 * 1024
        super(BlobService, self).__init__(
            account_name, account_key, protocol, host_base, dev_host)

    def make_blob_url(self, container_name, blob_name, account_name=None,
                      protocol=None, host_base=None):
        '''
        Creates the url to access a blob.

        container_name: Name of container.
        blob_name: Name of blob.
        account_name:
            Name of the storage account. If not specified, uses the account
            specified when BlobService was initialized.
        protocol:
            Protocol to use: 'http' or 'https'. If not specified, uses the
            protocol specified when BlobService was initialized.
        host_base:
            Live host base url. If not specified, uses the host base specified
            when BlobService was initialized.
        '''
        # Fall back to the values this service instance was created with.
        if not account_name:
            account_name = self.account_name
        if not protocol:
            protocol = self.protocol
        if not host_base:
            host_base = self.host_base

        return '{0}://{1}{2}/{3}/{4}'.format(protocol,
                                             account_name,
                                             host_base,
                                             container_name,
                                             blob_name)

    def list_containers(self, prefix=None, marker=None, maxresults=None,
                        include=None):
        '''
        The List Containers operation returns a list of the containers under
        the specified account.

        prefix:
            Optional. Filters the results to return only containers whose
            names begin with the specified prefix.
        marker:
            Optional. A string value that identifies the portion of the list
            to be returned with the next list operation.
        maxresults:
            Optional. Specifies the maximum number of containers to return.
        include:
            Optional. Include this parameter to specify that the container's
            metadata be returned as part of the response body. set this
            parameter to string 'metadata' to get container's metadata.
        '''
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/?comp=list'
        request.query = [
            ('prefix', _str_or_none(prefix)),
            ('marker', _str_or_none(marker)),
            ('maxresults', _int_or_none(maxresults)),
            ('include', _str_or_none(include))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        # Replaces the header list with the final (presumably signed/merged)
        # set of storage headers.
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        return _parse_enum_results_list(response,
                                        ContainerEnumResults,
                                        "Containers",
                                        Container)

    def create_container(self, container_name, x_ms_meta_name_values=None,
                         x_ms_blob_public_access=None, fail_on_exist=False):
        '''
        Creates a new container under the specified account. If the container
        with the same name already exists, the operation fails.

        container_name: Name of container to create.
        x_ms_meta_name_values:
            Optional. A dict with name_value pairs to associate with the
            container as metadata. Example:{'Category':'test'}
        x_ms_blob_public_access:
            Optional. Possible values include: container, blob
        fail_on_exist:
            specify whether to throw an exception when the container exists.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + _str(container_name) + '?restype=container'
        request.headers = [
            ('x-ms-meta-name-values', x_ms_meta_name_values),
            ('x-ms-blob-public-access', _str_or_none(x_ms_blob_public_access))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        if not fail_on_exist:
            # Best-effort mode: report 'already exists' through the return
            # value (_dont_fail_on_exist presumably re-raises other errors).
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_on_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def get_container_properties(self, container_name, x_ms_lease_id=None):
        '''
        Returns all user-defined metadata and system properties for the
        specified container.

        container_name: Name of existing container.
        x_ms_lease_id:
            If specified, get_container_properties only succeeds if the
            container's lease is active and matches this ID.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(container_name) + '?restype=container'
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        # Container properties come back as response headers.
        return _parse_response_for_dict(response)

    def get_container_metadata(self, container_name, x_ms_lease_id=None):
        '''
        Returns all user-defined metadata for the specified container. The
        metadata will be in returned dictionary['x-ms-meta-(name)'].

        container_name: Name of existing container.
        x_ms_lease_id:
            If specified, get_container_metadata only succeeds if the
            container's lease is active and matches this ID.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '?restype=container&comp=metadata'
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        # Keep only the x-ms-meta-* headers from the response.
        return _parse_response_for_dict_prefix(response,
                                               prefixes=['x-ms-meta'])

    def set_container_metadata(self, container_name,
                               x_ms_meta_name_values=None,
                               x_ms_lease_id=None):
        '''
        Sets one or more user-defined name-value pairs for the specified
        container.

        container_name: Name of existing container.
        x_ms_meta_name_values:
            A dict containing name, value for metadata.
            Example: {'category':'test'}
        x_ms_lease_id:
            If specified, set_container_metadata only succeeds if the
            container's lease is active and matches this ID.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '?restype=container&comp=metadata'
        request.headers = [
            ('x-ms-meta-name-values', x_ms_meta_name_values),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def get_container_acl(self, container_name, x_ms_lease_id=None):
        '''
        Gets the permissions for the specified container.

        container_name: Name of existing container.
        x_ms_lease_id:
            If specified, get_container_acl only succeeds if the container's
            lease is active and matches this ID.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '?restype=container&comp=acl'
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        return _parse_response(response, SignedIdentifiers)

    def set_container_acl(self, container_name, signed_identifiers=None,
                          x_ms_blob_public_access=None, x_ms_lease_id=None):
        '''
        Sets the permissions for the specified container.

        container_name: Name of existing container.
        signed_identifiers: SignedIdentifers instance
        x_ms_blob_public_access:
            Optional. Possible values include: container, blob
        x_ms_lease_id:
            If specified, set_container_acl only succeeds if the container's
            lease is active and matches this ID.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '?restype=container&comp=acl'
        request.headers = [
            ('x-ms-blob-public-access', _str_or_none(x_ms_blob_public_access)),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
        ]
        # The ACL is sent as a serialized SignedIdentifiers XML body.
        request.body = _get_request_body(
            _convert_class_to_xml(signed_identifiers))
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def delete_container(self, container_name, fail_not_exist=False,
                         x_ms_lease_id=None):
        '''
        Marks the specified container for deletion.

        container_name: Name of container to delete.
        fail_not_exist:
            Specify whether to throw an exception when the container doesn't
            exist.
        x_ms_lease_id: Required if the container has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(container_name) + '?restype=container'
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        if not fail_not_exist:
            # Mirror create_container: report 'not found' via the return
            # value instead of raising.
            try:
                self._perform_request(request)
                return True
            except WindowsAzureError as ex:
                _dont_fail_not_exist(ex)
                return False
        else:
            self._perform_request(request)
            return True

    def lease_container(self, container_name, x_ms_lease_action,
                        x_ms_lease_id=None, x_ms_lease_duration=60,
                        x_ms_lease_break_period=None,
                        x_ms_proposed_lease_id=None):
        '''
        Establishes and manages a lock on a container for delete operations.
        The lock duration can be 15 to 60 seconds, or can be infinite.

        container_name: Name of existing container.
        x_ms_lease_action:
            Required. Possible values: acquire|renew|release|break|change
        x_ms_lease_id: Required if the container has an active lease.
        x_ms_lease_duration:
            Specifies the duration of the lease, in seconds, or negative one
            (-1) for a lease that never expires. A non-infinite lease can be
            between 15 and 60 seconds. A lease duration cannot be changed
            using renew or change. For backwards compatibility, the default
            is 60, and the value is only used on an acquire operation.
        x_ms_lease_break_period:
            Optional. For a break operation, this is the proposed duration of
            seconds that the lease should continue before it is broken,
            between 0 and 60 seconds. This break period is only used if it is
            shorter than the time remaining on the lease. If longer, the time
            remaining on the lease is used. A new lease will not be available
            before the break period has expired, but the lease may be held
            for longer than the break period. If this header does not appear
            with a break operation, a fixed-duration lease breaks after the
            remaining lease period elapses, and an infinite lease breaks
            immediately.
        x_ms_proposed_lease_id:
            Optional for acquire, required for change. Proposed lease ID, in
            a GUID string format.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('x_ms_lease_action', x_ms_lease_action)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '?restype=container&comp=lease'
        request.headers = [
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
            ('x-ms-lease-action', _str_or_none(x_ms_lease_action)),
            # The duration header is only sent for 'acquire' actions.
            ('x-ms-lease-duration', _str_or_none(
                x_ms_lease_duration if x_ms_lease_action == 'acquire'
                else None)),
            ('x-ms-lease-break-period', _str_or_none(x_ms_lease_break_period)),
            ('x-ms-proposed-lease-id', _str_or_none(x_ms_proposed_lease_id)),
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        # Only the lease id/time headers are relevant to the caller.
        return _parse_response_for_dict_filter(
            response,
            filter=['x-ms-lease-id', 'x-ms-lease-time'])

    def list_blobs(self, container_name, prefix=None, marker=None,
                   maxresults=None, include=None, delimiter=None):
        '''
        Returns the list of blobs under the specified container.

        container_name: Name of existing container.
        prefix:
            Optional. Filters the results to return only blobs whose names
            begin with the specified prefix.
        marker:
            Optional. A string value that identifies the portion of the list
            to be returned with the next list operation. The operation
            returns a marker value within the response body if the list
            returned was not complete. The marker value may then be used in a
            subsequent call to request the next set of list items. The marker
            value is opaque to the client.
        maxresults:
            Optional. Specifies the maximum number of blobs to return,
            including all BlobPrefix elements. If the request does not
            specify maxresults or specifies a value greater than 5,000, the
            server will return up to 5,000 items. Setting maxresults to a
            value less than or equal to zero results in error response code
            400 (Bad Request).
        include:
            Optional. Specifies one or more datasets to include in the
            response. To specify more than one of these options on the URI,
            you must separate each option with a comma. Valid values are:
                snapshots:
                    Specifies that snapshots should be included in the
                    enumeration. Snapshots are listed from oldest to newest
                    in the response.
                metadata:
                    Specifies that blob metadata be returned in the response.
                uncommittedblobs:
                    Specifies that blobs for which blocks have been uploaded,
                    but which have not been committed using Put Block List
                    (REST API), be included in the response.
                copy:
                    Version 2012-02-12 and newer. Specifies that metadata
                    related to any current or previous Copy Blob operation
                    should be included in the response.
        delimiter:
            Optional. When the request includes this parameter, the operation
            returns a BlobPrefix element in the response body that acts as a
            placeholder for all blobs whose names begin with the same
            substring up to the appearance of the delimiter character. The
            delimiter may be a single character or a string.
        '''
        _validate_not_none('container_name', container_name)
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '?restype=container&comp=list'
        request.query = [
            ('prefix', _str_or_none(prefix)),
            ('delimiter', _str_or_none(delimiter)),
            ('marker', _str_or_none(marker)),
            ('maxresults', _int_or_none(maxresults)),
            ('include', _str_or_none(include))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        return _parse_blob_enum_results_list(response)

    def set_blob_service_properties(self, storage_service_properties,
                                    timeout=None):
        '''
        Sets the properties of a storage account's Blob service, including
        Windows Azure Storage Analytics. You can also use this operation to
        set the default request version for all incoming requests that do not
        have a version specified.

        storage_service_properties: a StorageServiceProperties object.
        timeout: Optional. The timeout parameter is expressed in seconds.
        '''
        _validate_not_none('storage_service_properties',
                           storage_service_properties)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/?restype=service&comp=properties'
        request.query = [('timeout', _int_or_none(timeout))]
        request.body = _get_request_body(
            _convert_class_to_xml(storage_service_properties))
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def get_blob_service_properties(self, timeout=None):
        '''
        Gets the properties of a storage account's Blob service, including
        Windows Azure Storage Analytics.

        timeout: Optional. The timeout parameter is expressed in seconds.
        '''
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/?restype=service&comp=properties'
        request.query = [('timeout', _int_or_none(timeout))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        return _parse_response(response, StorageServiceProperties)

    def get_blob_properties(self, container_name, blob_name,
                            x_ms_lease_id=None):
        '''
        Returns all user-defined metadata, standard HTTP properties, and
        system properties for the blob.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        x_ms_lease_id: Required if the blob has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        request = HTTPRequest()
        # HEAD: properties are returned entirely in response headers.
        request.method = 'HEAD'
        request.host = self._get_host()
        request.path = '/' + _str(container_name) + '/' + _str(blob_name) + ''
        request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        return _parse_response_for_dict(response)

    def set_blob_properties(self, container_name, blob_name,
                            x_ms_blob_cache_control=None,
                            x_ms_blob_content_type=None,
                            x_ms_blob_content_md5=None,
                            x_ms_blob_content_encoding=None,
                            x_ms_blob_content_language=None,
                            x_ms_lease_id=None):
        '''
        Sets system properties on the blob.

        container_name: Name of existing container.
        blob_name: Name of existing blob.
        x_ms_blob_cache_control:
            Optional. Modifies the cache control string for the blob.
        x_ms_blob_content_type: Optional. Sets the blob's content type.
        x_ms_blob_content_md5: Optional. Sets the blob's MD5 hash.
        x_ms_blob_content_encoding:
            Optional.
            Sets the blob's content encoding.
        x_ms_blob_content_language:
            Optional. Sets the blob's content language.
        x_ms_lease_id: Required if the blob has an active lease.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(container_name) + '/' + _str(blob_name) + '?comp=properties'
        request.headers = [
            ('x-ms-blob-cache-control', _str_or_none(x_ms_blob_cache_control)),
            ('x-ms-blob-content-type', _str_or_none(x_ms_blob_content_type)),
            ('x-ms-blob-content-md5', _str_or_none(x_ms_blob_content_md5)),
            ('x-ms-blob-content-encoding',
             _str_or_none(x_ms_blob_content_encoding)),
            ('x-ms-blob-content-language',
             _str_or_none(x_ms_blob_content_language)),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def put_blob(self, container_name, blob_name, blob, x_ms_blob_type,
                 content_encoding=None, content_language=None,
                 content_md5=None, cache_control=None,
                 x_ms_blob_content_type=None, x_ms_blob_content_encoding=None,
                 x_ms_blob_content_language=None, x_ms_blob_content_md5=None,
                 x_ms_blob_cache_control=None, x_ms_meta_name_values=None,
                 x_ms_lease_id=None, x_ms_blob_content_length=None,
                 x_ms_blob_sequence_number=None):
        '''
        Creates a new block blob or page blob, or updates the content of an
        existing block blob.

        See put_block_blob_from_* and put_page_blob_from_* for high level
        functions that handle the creation and upload of large blobs with
        automatic chunking and progress notifications.

        container_name: Name of existing container.
        blob_name: Name of blob to create or update.
        blob:
            For BlockBlob: Content of blob as bytes (size < 64MB). For larger
            size, you must call put_block and put_block_list to set content
            of blob.
            For PageBlob: Use None and call put_page to set content of blob.
        x_ms_blob_type: Required. Could be BlockBlob or PageBlob.
        content_encoding:
            Optional. Specifies which content encodings have been applied to
            the blob. This value is returned to the client when the Get Blob
            (REST API) operation is performed on the blob resource. The
            client can use this value when returned to decode the blob
            content.
        content_language:
            Optional. Specifies the natural languages used by this resource.
        content_md5:
            Optional. An MD5 hash of the blob content. This hash is used to
            verify the integrity of the blob during transport. When this
            header is specified, the storage service checks the hash that has
            arrived with the one that was sent. If the two hashes do not
            match, the operation will fail with error code 400 (Bad Request).
        cache_control:
            Optional. The Blob service stores this value but does not use or
            modify it.
        x_ms_blob_content_type: Optional. Set the blob's content type.
        x_ms_blob_content_encoding: Optional. Set the blob's content encoding.
        x_ms_blob_content_language: Optional. Set the blob's content language.
        x_ms_blob_content_md5: Optional. Set the blob's MD5 hash.
        x_ms_blob_cache_control: Optional. Sets the blob's cache control.
        x_ms_meta_name_values: A dict containing name, value for metadata.
        x_ms_lease_id: Required if the blob has an active lease.
        x_ms_blob_content_length:
            Required for page blobs. This header specifies the maximum size
            for the page blob, up to 1 TB. The page blob size must be aligned
            to a 512-byte boundary.
        x_ms_blob_sequence_number:
            Optional. Set for page blobs only. The sequence number is a
            user-controlled value that you can use to track requests. The
            value of the sequence number must be between 0 and 2^63 - 1. The
            default value is 0.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('x_ms_blob_type', x_ms_blob_type)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + _str(container_name) + '/' + _str(blob_name) + ''
        request.headers = [
            ('x-ms-blob-type', _str_or_none(x_ms_blob_type)),
            ('Content-Encoding', _str_or_none(content_encoding)),
            ('Content-Language', _str_or_none(content_language)),
            ('Content-MD5', _str_or_none(content_md5)),
            ('Cache-Control', _str_or_none(cache_control)),
            ('x-ms-blob-content-type', _str_or_none(x_ms_blob_content_type)),
            ('x-ms-blob-content-encoding',
             _str_or_none(x_ms_blob_content_encoding)),
            ('x-ms-blob-content-language',
             _str_or_none(x_ms_blob_content_language)),
            ('x-ms-blob-content-md5', _str_or_none(x_ms_blob_content_md5)),
            ('x-ms-blob-cache-control', _str_or_none(x_ms_blob_cache_control)),
            ('x-ms-meta-name-values', x_ms_meta_name_values),
            ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
            ('x-ms-blob-content-length',
             _str_or_none(x_ms_blob_content_length)),
            ('x-ms-blob-sequence-number',
             _str_or_none(x_ms_blob_sequence_number))
        ]
        # The payload must already be bytes; _get_request_body_bytes_only
        # presumably validates that.
        request.body = _get_request_body_bytes_only('blob', blob)
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_blob_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def put_block_blob_from_path(self, container_name, blob_name, file_path,
                                 content_encoding=None,
                                 content_language=None,
                                 content_md5=None,
                                 cache_control=None,
                                 x_ms_blob_content_type=None,
                                 x_ms_blob_content_encoding=None,
                                 x_ms_blob_content_language=None,
                                 x_ms_blob_content_md5=None,
                                 x_ms_blob_cache_control=None,
                                 x_ms_meta_name_values=None,
                                 x_ms_lease_id=None,
                                 progress_callback=None):
        '''
        Creates a new block blob from a file path, or updates the content of
        an existing block blob, with automatic chunking and progress
        notifications.

        container_name: Name of existing container.
        blob_name: Name of blob to create or update.
        file_path: Path of the file to upload as the blob content.
        content_encoding:
            Optional. Specifies which content encodings have been applied to
            the blob.
        content_language:
            Optional. Specifies the natural languages used by this resource.
        content_md5:
            Optional. An MD5 hash of the blob content. This hash is used to
            verify the integrity of the blob during transport. If the two
            hashes do not match, the operation will fail with error code 400
            (Bad Request).
        cache_control:
            Optional. The Blob service stores this value but does not use or
            modify it.
        x_ms_blob_content_type: Optional. Set the blob's content type.
        x_ms_blob_content_encoding: Optional. Set the blob's content encoding.
        x_ms_blob_content_language: Optional. Set the blob's content language.
        x_ms_blob_content_md5: Optional. Set the blob's MD5 hash.
        x_ms_blob_cache_control: Optional. Sets the blob's cache control.
        x_ms_meta_name_values: A dict containing name, value for metadata.
        x_ms_lease_id: Required if the blob has an active lease.
        progress_callback:
            Callback for progress with signature function(current, total)
            where current is the number of bytes transfered so far, and total
            is the size of the blob, or None if the total size is unknown.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('file_path', file_path)

        # Knowing the size up front lets the chunked uploader report
        # progress against a real total.
        count = path.getsize(file_path)
        with open(file_path, 'rb') as stream:
            self.put_block_blob_from_file(container_name,
                                          blob_name,
                                          stream,
                                          count,
                                          content_encoding,
                                          content_language,
                                          content_md5,
                                          cache_control,
                                          x_ms_blob_content_type,
                                          x_ms_blob_content_encoding,
                                          x_ms_blob_content_language,
                                          x_ms_blob_content_md5,
                                          x_ms_blob_cache_control,
                                          x_ms_meta_name_values,
                                          x_ms_lease_id,
                                          progress_callback)

    def put_block_blob_from_file(self, container_name, blob_name, stream,
                                 count=None,
                                 content_encoding=None,
                                 content_language=None,
                                 content_md5=None,
                                 cache_control=None,
                                 x_ms_blob_content_type=None,
                                 x_ms_blob_content_encoding=None,
                                 x_ms_blob_content_language=None,
                                 x_ms_blob_content_md5=None,
                                 x_ms_blob_cache_control=None,
                                 x_ms_meta_name_values=None,
                                 x_ms_lease_id=None,
                                 progress_callback=None):
        '''
        Creates a new block blob from a file/stream, or updates the content
        of an existing block blob, with automatic chunking and progress
        notifications.

        container_name: Name of existing container.
        blob_name: Name of blob to create or update.
        stream: Opened file/stream to upload as the blob content.
        count:
            Number of bytes to read from the stream. This is optional, but
            should be supplied for optimal performance.
        content_encoding:
            Optional. Specifies which content encodings have been applied to
            the blob.
        content_language:
            Optional. Specifies the natural languages used by this resource.
        content_md5:
            Optional. An MD5 hash of the blob content. This hash is used to
            verify the integrity of the blob during transport. If the two
            hashes do not match, the operation will fail with error code 400
            (Bad Request).
        cache_control:
            Optional. The Blob service stores this value but does not use or
            modify it.
        x_ms_blob_content_type: Optional. Set the blob's content type.
        x_ms_blob_content_encoding: Optional. Set the blob's content encoding.
        x_ms_blob_content_language: Optional. Set the blob's content language.
        x_ms_blob_content_md5: Optional. Set the blob's MD5 hash.
        x_ms_blob_cache_control: Optional. Sets the blob's cache control.
        x_ms_meta_name_values: A dict containing name, value for metadata.
        x_ms_lease_id: Required if the blob has an active lease.
        progress_callback:
            Callback for progress with signature function(current, total)
            where current is the number of bytes transfered so far, and total
            is the size of the blob, or None if the total size is unknown.
        '''
        _validate_not_none('container_name', container_name)
        _validate_not_none('blob_name', blob_name)
        _validate_not_none('stream', stream)
        if count and count < self._BLOB_MAX_DATA_SIZE:
            # Small enough for a single Put Blob request.
            if progress_callback:
                progress_callback(0, count)

            data = stream.read(count)
            self.put_blob(container_name,
                          blob_name,
                          data,
                          'BlockBlob',
                          content_encoding,
                          content_language,
                          content_md5,
                          cache_control,
                          x_ms_blob_content_type,
                          x_ms_blob_content_encoding,
                          x_ms_blob_content_language,
                          x_ms_blob_content_md5,
                          x_ms_blob_cache_control,
                          x_ms_meta_name_values,
                          x_ms_lease_id)

            if progress_callback:
                progress_callback(count, count)
        else:
            # Large or unknown size: create an empty block blob, upload the
            # content in chunks as blocks, then commit with put_block_list.
            if progress_callback:
                progress_callback(0, count)

            self.put_blob(container_name,
                          blob_name,
                          None,
                          'BlockBlob',
                          content_encoding,
                          content_language,
                          content_md5,
                          cache_control,
                          x_ms_blob_content_type,
                          x_ms_blob_content_encoding,
                          x_ms_blob_content_language,
                          x_ms_blob_content_md5,
                          x_ms_blob_cache_control,
                          x_ms_meta_name_values,
                          x_ms_lease_id)

            remain_bytes = count
            block_ids = []
            block_index = 0
            index = 0
            while True:
                # With an unknown total, always request a full chunk and stop
                # when the stream returns no data.
                request_count = self._BLOB_MAX_CHUNK_DATA_SIZE\
                    if remain_bytes is None else min(
                        remain_bytes,
                        self._BLOB_MAX_CHUNK_DATA_SIZE)
                data = stream.read(request_count)
                if data:
                    length = len(data)
                    index += length
                    remain_bytes = remain_bytes - \
                        length if remain_bytes else None
                    # Block ids are zero-padded sequence numbers.
                    block_id = '{0:08d}'.format(block_index)
                    self.put_block(container_name, blob_name,
                                   data, block_id,
                                   x_ms_lease_id=x_ms_lease_id)
                    block_ids.append(block_id)
                    block_index += 1
                    if progress_callback:
                        progress_callback(index, count)
                else:
                    break

            self.put_block_list(container_name, blob_name, block_ids,
                                content_md5,
                                x_ms_blob_cache_control,
                                x_ms_blob_content_type,
                                x_ms_blob_content_encoding,
                                x_ms_blob_content_language,
                                x_ms_blob_content_md5,
                                x_ms_meta_name_values,
                                x_ms_lease_id)

    def put_block_blob_from_bytes(self, container_name, blob_name, blob,
                                  index=0, count=None,
                                  content_encoding=None,
                                  content_language=None,
                                  content_md5=None,
                                  cache_control=None,
                                  x_ms_blob_content_type=None,
                                  x_ms_blob_content_encoding=None,
                                  x_ms_blob_content_language=None,
                                  x_ms_blob_content_md5=None,
                                  x_ms_blob_cache_control=None,
                                  x_ms_meta_name_values=None,
                                  x_ms_lease_id=None,
                                  progress_callback=None):
        '''
        Creates a new block blob from an array of bytes, or updates the
        content of an existing block blob, with automatic chunking and
        progress notifications.

        container_name: Name of existing container.
        blob_name: Name of blob to create or update.
        blob: Content of blob as an array of bytes.
        index: Start index in the array of bytes.
        count:
            Number of bytes to upload. Set to None or negative value to
            upload all bytes starting from index.
        content_encoding:
            Optional. Specifies which content encodings have been applied to
            the blob. This value is returned to the client when the Get Blob
            (REST API) operation is performed on the blob resource. The
            client can use this value when returned to decode the blob
            content.
        content_language:
            Optional. Specifies the natural languages used by this resource.
        content_md5:
            Optional. An MD5 hash of the blob content. This hash is used to
            verify the integrity of the blob during transport.
When this header is specified, the storage service checks the hash that has arrived with the one that was sent. If the two hashes do not match, the operation will fail with error code 400 (Bad Request). cache_control: Optional. The Blob service stores this value but does not use or modify it. x_ms_blob_content_type: Optional. Set the blob's content type. x_ms_blob_content_encoding: Optional. Set the blob's content encoding. x_ms_blob_content_language: Optional. Set the blob's content language. x_ms_blob_content_md5: Optional. Set the blob's MD5 hash. x_ms_blob_cache_control: Optional. Sets the blob's cache control. x_ms_meta_name_values: A dict containing name, value for metadata. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob, or None if the total size is unknown. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('blob', blob) _validate_not_none('index', index) _validate_type_bytes('blob', blob) if index < 0: raise TypeError(_ERROR_VALUE_NEGATIVE.format('index')) if count is None or count < 0: count = len(blob) - index if count < self._BLOB_MAX_DATA_SIZE: if progress_callback: progress_callback(0, count) data = blob[index: index + count] self.put_blob(container_name, blob_name, data, 'BlockBlob', content_encoding, content_language, content_md5, cache_control, x_ms_blob_content_type, x_ms_blob_content_encoding, x_ms_blob_content_language, x_ms_blob_content_md5, x_ms_blob_cache_control, x_ms_meta_name_values, x_ms_lease_id) if progress_callback: progress_callback(count, count) else: stream = BytesIO(blob) stream.seek(index) self.put_block_blob_from_file(container_name, blob_name, stream, count, content_encoding, content_language, content_md5, cache_control, x_ms_blob_content_type, x_ms_blob_content_encoding, 
x_ms_blob_content_language, x_ms_blob_content_md5, x_ms_blob_cache_control, x_ms_meta_name_values, x_ms_lease_id, progress_callback) def put_block_blob_from_text(self, container_name, blob_name, text, text_encoding='utf-8', content_encoding=None, content_language=None, content_md5=None, cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_blob_content_md5=None, x_ms_blob_cache_control=None, x_ms_meta_name_values=None, x_ms_lease_id=None, progress_callback=None): ''' Creates a new block blob from str/unicode, or updates the content of an existing block blob, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of blob to create or update. text: Text to upload to the blob. text_encoding: Encoding to use to convert the text to bytes. content_encoding: Optional. Specifies which content encodings have been applied to the blob. This value is returned to the client when the Get Blob (REST API) operation is performed on the blob resource. The client can use this value when returned to decode the blob content. content_language: Optional. Specifies the natural languages used by this resource. content_md5: Optional. An MD5 hash of the blob content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. If the two hashes do not match, the operation will fail with error code 400 (Bad Request). cache_control: Optional. The Blob service stores this value but does not use or modify it. x_ms_blob_content_type: Optional. Set the blob's content type. x_ms_blob_content_encoding: Optional. Set the blob's content encoding. x_ms_blob_content_language: Optional. Set the blob's content language. x_ms_blob_content_md5: Optional. Set the blob's MD5 hash. x_ms_blob_cache_control: Optional. Sets the blob's cache control. 
x_ms_meta_name_values: A dict containing name, value for metadata. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob, or None if the total size is unknown. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('text', text) if not isinstance(text, bytes): _validate_not_none('text_encoding', text_encoding) text = text.encode(text_encoding) self.put_block_blob_from_bytes(container_name, blob_name, text, 0, len(text), content_encoding, content_language, content_md5, cache_control, x_ms_blob_content_type, x_ms_blob_content_encoding, x_ms_blob_content_language, x_ms_blob_content_md5, x_ms_blob_cache_control, x_ms_meta_name_values, x_ms_lease_id, progress_callback) def put_page_blob_from_path(self, container_name, blob_name, file_path, content_encoding=None, content_language=None, content_md5=None, cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_blob_content_md5=None, x_ms_blob_cache_control=None, x_ms_meta_name_values=None, x_ms_lease_id=None, x_ms_blob_sequence_number=None, progress_callback=None): ''' Creates a new page blob from a file path, or updates the content of an existing page blob, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of blob to create or update. file_path: Path of the file to upload as the blob content. content_encoding: Optional. Specifies which content encodings have been applied to the blob. This value is returned to the client when the Get Blob (REST API) operation is performed on the blob resource. The client can use this value when returned to decode the blob content. content_language: Optional. Specifies the natural languages used by this resource. 
content_md5: Optional. An MD5 hash of the blob content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. If the two hashes do not match, the operation will fail with error code 400 (Bad Request). cache_control: Optional. The Blob service stores this value but does not use or modify it. x_ms_blob_content_type: Optional. Set the blob's content type. x_ms_blob_content_encoding: Optional. Set the blob's content encoding. x_ms_blob_content_language: Optional. Set the blob's content language. x_ms_blob_content_md5: Optional. Set the blob's MD5 hash. x_ms_blob_cache_control: Optional. Sets the blob's cache control. x_ms_meta_name_values: A dict containing name, value for metadata. x_ms_lease_id: Required if the blob has an active lease. x_ms_blob_sequence_number: Optional. Set for page blobs only. The sequence number is a user-controlled value that you can use to track requests. The value of the sequence number must be between 0 and 2^63 - 1. The default value is 0. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob, or None if the total size is unknown. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('file_path', file_path) count = path.getsize(file_path) with open(file_path, 'rb') as stream: self.put_page_blob_from_file(container_name, blob_name, stream, count, content_encoding, content_language, content_md5, cache_control, x_ms_blob_content_type, x_ms_blob_content_encoding, x_ms_blob_content_language, x_ms_blob_content_md5, x_ms_blob_cache_control, x_ms_meta_name_values, x_ms_lease_id, x_ms_blob_sequence_number, progress_callback) def put_page_blob_from_file(self, container_name, blob_name, stream, count, content_encoding=None, content_language=None, content_md5=None, cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_blob_content_md5=None, x_ms_blob_cache_control=None, x_ms_meta_name_values=None, x_ms_lease_id=None, x_ms_blob_sequence_number=None, progress_callback=None): ''' Creates a new page blob from a file/stream, or updates the content of an existing page blob, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of blob to create or update. stream: Opened file/stream to upload as the blob content. count: Number of bytes to read from the stream. This is required, a page blob cannot be created if the count is unknown. content_encoding: Optional. Specifies which content encodings have been applied to the blob. This value is returned to the client when the Get Blob (REST API) operation is performed on the blob resource. The client can use this value when returned to decode the blob content. content_language: Optional. Specifies the natural languages used by this resource. content_md5: Optional. An MD5 hash of the blob content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. 
If the two hashes do not match, the operation will fail with error code 400 (Bad Request). cache_control: Optional. The Blob service stores this value but does not use or modify it. x_ms_blob_content_type: Optional. Set the blob's content type. x_ms_blob_content_encoding: Optional. Set the blob's content encoding. x_ms_blob_content_language: Optional. Set the blob's content language. x_ms_blob_content_md5: Optional. Set the blob's MD5 hash. x_ms_blob_cache_control: Optional. Sets the blob's cache control. x_ms_meta_name_values: A dict containing name, value for metadata. x_ms_lease_id: Required if the blob has an active lease. x_ms_blob_sequence_number: Optional. Set for page blobs only. The sequence number is a user-controlled value that you can use to track requests. The value of the sequence number must be between 0 and 2^63 - 1. The default value is 0. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob, or None if the total size is unknown. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('stream', stream) _validate_not_none('count', count) if count < 0: raise TypeError(_ERROR_VALUE_NEGATIVE.format('count')) if count % _PAGE_SIZE != 0: raise TypeError(_ERROR_PAGE_BLOB_SIZE_ALIGNMENT.format(count)) if progress_callback: progress_callback(0, count) self.put_blob(container_name, blob_name, b'', 'PageBlob', content_encoding, content_language, content_md5, cache_control, x_ms_blob_content_type, x_ms_blob_content_encoding, x_ms_blob_content_language, x_ms_blob_content_md5, x_ms_blob_cache_control, x_ms_meta_name_values, x_ms_lease_id, count, x_ms_blob_sequence_number) remain_bytes = count page_start = 0 while True: request_count = min(remain_bytes, self._BLOB_MAX_CHUNK_DATA_SIZE) data = stream.read(request_count) if data: length = len(data) remain_bytes = remain_bytes - length page_end = page_start + length - 1 self.put_page(container_name, blob_name, data, 'bytes={0}-{1}'.format(page_start, page_end), 'update', x_ms_lease_id=x_ms_lease_id) page_start = page_start + length if progress_callback: progress_callback(page_start, count) else: break def put_page_blob_from_bytes(self, container_name, blob_name, blob, index=0, count=None, content_encoding=None, content_language=None, content_md5=None, cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_blob_content_md5=None, x_ms_blob_cache_control=None, x_ms_meta_name_values=None, x_ms_lease_id=None, x_ms_blob_sequence_number=None, progress_callback=None): ''' Creates a new page blob from an array of bytes, or updates the content of an existing page blob, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of blob to create or update. blob: Content of blob as an array of bytes. index: Start index in the array of bytes. count: Number of bytes to upload. 
Set to None or negative value to upload all bytes starting from index. content_encoding: Optional. Specifies which content encodings have been applied to the blob. This value is returned to the client when the Get Blob (REST API) operation is performed on the blob resource. The client can use this value when returned to decode the blob content. content_language: Optional. Specifies the natural languages used by this resource. content_md5: Optional. An MD5 hash of the blob content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. If the two hashes do not match, the operation will fail with error code 400 (Bad Request). cache_control: Optional. The Blob service stores this value but does not use or modify it. x_ms_blob_content_type: Optional. Set the blob's content type. x_ms_blob_content_encoding: Optional. Set the blob's content encoding. x_ms_blob_content_language: Optional. Set the blob's content language. x_ms_blob_content_md5: Optional. Set the blob's MD5 hash. x_ms_blob_cache_control: Optional. Sets the blob's cache control. x_ms_meta_name_values: A dict containing name, value for metadata. x_ms_lease_id: Required if the blob has an active lease. x_ms_blob_sequence_number: Optional. Set for page blobs only. The sequence number is a user-controlled value that you can use to track requests. The value of the sequence number must be between 0 and 2^63 - 1. The default value is 0. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob, or None if the total size is unknown. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('blob', blob) _validate_type_bytes('blob', blob) if index < 0: raise TypeError(_ERROR_VALUE_NEGATIVE.format('index')) if count is None or count < 0: count = len(blob) - index stream = BytesIO(blob) stream.seek(index) self.put_page_blob_from_file(container_name, blob_name, stream, count, content_encoding, content_language, content_md5, cache_control, x_ms_blob_content_type, x_ms_blob_content_encoding, x_ms_blob_content_language, x_ms_blob_content_md5, x_ms_blob_cache_control, x_ms_meta_name_values, x_ms_lease_id, x_ms_blob_sequence_number, progress_callback) def get_blob(self, container_name, blob_name, snapshot=None, x_ms_range=None, x_ms_lease_id=None, x_ms_range_get_content_md5=None): ''' Reads or downloads a blob from the system, including its metadata and properties. See get_blob_to_* for high level functions that handle the download of large blobs with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of existing blob. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_range: Optional. Return only the bytes of the blob in the specified range. x_ms_lease_id: Required if the blob has an active lease. x_ms_range_get_content_md5: Optional. When this header is set to true and specified together with the Range header, the service returns the MD5 hash for the range, as long as the range is less than or equal to 4 MB in size. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(container_name) + '/' + _str(blob_name) + '' request.headers = [ ('x-ms-range', _str_or_none(x_ms_range)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-range-get-content-md5', _str_or_none(x_ms_range_get_content_md5)) ] request.query = [('snapshot', _str_or_none(snapshot))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request, None) return _create_blob_result(response) def get_blob_to_path(self, container_name, blob_name, file_path, open_mode='wb', snapshot=None, x_ms_lease_id=None, progress_callback=None): ''' Downloads a blob to a file path, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of existing blob. file_path: Path of file to write to. open_mode: Mode to use when opening the file. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('file_path', file_path) _validate_not_none('open_mode', open_mode) with open(file_path, open_mode) as stream: self.get_blob_to_file(container_name, blob_name, stream, snapshot, x_ms_lease_id, progress_callback) def get_blob_to_file(self, container_name, blob_name, stream, snapshot=None, x_ms_lease_id=None, progress_callback=None): ''' Downloads a blob to a file/stream, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of existing blob. stream: Opened file/stream to write to. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('stream', stream) props = self.get_blob_properties(container_name, blob_name) blob_size = int(props['content-length']) if blob_size < self._BLOB_MAX_DATA_SIZE: if progress_callback: progress_callback(0, blob_size) data = self.get_blob(container_name, blob_name, snapshot, x_ms_lease_id=x_ms_lease_id) stream.write(data) if progress_callback: progress_callback(blob_size, blob_size) else: if progress_callback: progress_callback(0, blob_size) index = 0 while index < blob_size: chunk_range = 'bytes={0}-{1}'.format( index, index + self._BLOB_MAX_CHUNK_DATA_SIZE - 1) data = self.get_blob( container_name, blob_name, x_ms_range=chunk_range) length = len(data) index += length if length > 0: stream.write(data) if progress_callback: progress_callback(index, blob_size) if length < self._BLOB_MAX_CHUNK_DATA_SIZE: break else: break def get_blob_to_bytes(self, container_name, blob_name, snapshot=None, x_ms_lease_id=None, progress_callback=None): ''' Downloads a blob as an array of bytes, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of existing blob. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) stream = BytesIO() self.get_blob_to_file(container_name, blob_name, stream, snapshot, x_ms_lease_id, progress_callback) return stream.getvalue() def get_blob_to_text(self, container_name, blob_name, text_encoding='utf-8', snapshot=None, x_ms_lease_id=None, progress_callback=None): ''' Downloads a blob as unicode text, with automatic chunking and progress notifications. container_name: Name of existing container. blob_name: Name of existing blob. text_encoding: Encoding to use when decoding the blob data. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_lease_id: Required if the blob has an active lease. progress_callback: Callback for progress with signature function(current, total) where current is the number of bytes transfered so far, and total is the size of the blob. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('text_encoding', text_encoding) result = self.get_blob_to_bytes(container_name, blob_name, snapshot, x_ms_lease_id, progress_callback) return result.decode(text_encoding) def get_blob_metadata(self, container_name, blob_name, snapshot=None, x_ms_lease_id=None): ''' Returns all user-defined metadata for the specified blob or snapshot. container_name: Name of existing container. blob_name: Name of existing blob. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve. x_ms_lease_id: Required if the blob has an active lease. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=metadata' request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))] request.query = [('snapshot', _str_or_none(snapshot))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict_prefix(response, prefixes=['x-ms-meta']) def set_blob_metadata(self, container_name, blob_name, x_ms_meta_name_values=None, x_ms_lease_id=None): ''' Sets user-defined metadata for the specified blob as one or more name-value pairs. container_name: Name of existing container. blob_name: Name of existing blob. x_ms_meta_name_values: Dict containing name and value pairs. x_ms_lease_id: Required if the blob has an active lease. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=metadata' request.headers = [ ('x-ms-meta-name-values', x_ms_meta_name_values), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def lease_blob(self, container_name, blob_name, x_ms_lease_action, x_ms_lease_id=None, x_ms_lease_duration=60, x_ms_lease_break_period=None, x_ms_proposed_lease_id=None): ''' Establishes and manages a one-minute lock on a blob for write operations. container_name: Name of existing container. 
blob_name: Name of existing blob. x_ms_lease_action: Required. Possible values: acquire|renew|release|break|change x_ms_lease_id: Required if the blob has an active lease. x_ms_lease_duration: Specifies the duration of the lease, in seconds, or negative one (-1) for a lease that never expires. A non-infinite lease can be between 15 and 60 seconds. A lease duration cannot be changed using renew or change. For backwards compatibility, the default is 60, and the value is only used on an acquire operation. x_ms_lease_break_period: Optional. For a break operation, this is the proposed duration of seconds that the lease should continue before it is broken, between 0 and 60 seconds. This break period is only used if it is shorter than the time remaining on the lease. If longer, the time remaining on the lease is used. A new lease will not be available before the break period has expired, but the lease may be held for longer than the break period. If this header does not appear with a break operation, a fixed-duration lease breaks after the remaining lease period elapses, and an infinite lease breaks immediately. x_ms_proposed_lease_id: Optional for acquire, required for change. Proposed lease ID, in a GUID string format. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('x_ms_lease_action', x_ms_lease_action) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=lease' request.headers = [ ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-lease-action', _str_or_none(x_ms_lease_action)), ('x-ms-lease-duration', _str_or_none(x_ms_lease_duration\ if x_ms_lease_action == 'acquire' else None)), ('x-ms-lease-break-period', _str_or_none(x_ms_lease_break_period)), ('x-ms-proposed-lease-id', _str_or_none(x_ms_proposed_lease_id)), ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict_filter( response, filter=['x-ms-lease-id', 'x-ms-lease-time']) def snapshot_blob(self, container_name, blob_name, x_ms_meta_name_values=None, if_modified_since=None, if_unmodified_since=None, if_match=None, if_none_match=None, x_ms_lease_id=None): ''' Creates a read-only snapshot of a blob. container_name: Name of existing container. blob_name: Name of existing blob. x_ms_meta_name_values: Optional. Dict containing name and value pairs. if_modified_since: Optional. Datetime string. if_unmodified_since: DateTime string. if_match: Optional. snapshot the blob only if its ETag value matches the value specified. if_none_match: Optional. An ETag value x_ms_lease_id: Required if the blob has an active lease. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=snapshot' request.headers = [ ('x-ms-meta-name-values', x_ms_meta_name_values), ('If-Modified-Since', _str_or_none(if_modified_since)), ('If-Unmodified-Since', _str_or_none(if_unmodified_since)), ('If-Match', _str_or_none(if_match)), ('If-None-Match', _str_or_none(if_none_match)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict_filter( response, filter=['x-ms-snapshot', 'etag', 'last-modified']) def copy_blob(self, container_name, blob_name, x_ms_copy_source, x_ms_meta_name_values=None, x_ms_source_if_modified_since=None, x_ms_source_if_unmodified_since=None, x_ms_source_if_match=None, x_ms_source_if_none_match=None, if_modified_since=None, if_unmodified_since=None, if_match=None, if_none_match=None, x_ms_lease_id=None, x_ms_source_lease_id=None): ''' Copies a blob to a destination within the storage account. container_name: Name of existing container. blob_name: Name of existing blob. x_ms_copy_source: URL up to 2 KB in length that specifies a blob. A source blob in the same account can be private, but a blob in another account must be public or accept credentials included in this URL, such as a Shared Access Signature. Examples: https://myaccount.blob.core.windows.net/mycontainer/myblob https://myaccount.blob.core.windows.net/mycontainer/myblob?snapshot=<DateTime> x_ms_meta_name_values: Optional. Dict containing name and value pairs. x_ms_source_if_modified_since: Optional. An ETag value. 
Specify this conditional header to copy the source blob only if its ETag matches the value specified. x_ms_source_if_unmodified_since: Optional. An ETag value. Specify this conditional header to copy the blob only if its ETag does not match the value specified. x_ms_source_if_match: Optional. A DateTime value. Specify this conditional header to copy the blob only if the source blob has been modified since the specified date/time. x_ms_source_if_none_match: Optional. An ETag value. Specify this conditional header to copy the source blob only if its ETag matches the value specified. if_modified_since: Optional. Datetime string. if_unmodified_since: DateTime string. if_match: Optional. Snapshot the blob only if its ETag value matches the value specified. if_none_match: Optional. An ETag value x_ms_lease_id: Required if the blob has an active lease. x_ms_source_lease_id: Optional. Specify this to perform the Copy Blob operation only if the lease ID given matches the active lease ID of the source blob. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('x_ms_copy_source', x_ms_copy_source) if x_ms_copy_source.startswith('/'): # Backwards compatibility for earlier versions of the SDK where # the copy source can be in the following formats: # - Blob in named container: # /accountName/containerName/blobName # - Snapshot in named container: # /accountName/containerName/blobName?snapshot=<DateTime> # - Blob in root container: # /accountName/blobName # - Snapshot in root container: # /accountName/blobName?snapshot=<DateTime> account, _, source =\ x_ms_copy_source.partition('/')[2].partition('/') x_ms_copy_source = self.protocol + '://' + \ account + self.host_base + '/' + source request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(container_name) + '/' + _str(blob_name) + '' request.headers = [ ('x-ms-copy-source', _str_or_none(x_ms_copy_source)), ('x-ms-meta-name-values', x_ms_meta_name_values), ('x-ms-source-if-modified-since', _str_or_none(x_ms_source_if_modified_since)), ('x-ms-source-if-unmodified-since', _str_or_none(x_ms_source_if_unmodified_since)), ('x-ms-source-if-match', _str_or_none(x_ms_source_if_match)), ('x-ms-source-if-none-match', _str_or_none(x_ms_source_if_none_match)), ('If-Modified-Since', _str_or_none(if_modified_since)), ('If-Unmodified-Since', _str_or_none(if_unmodified_since)), ('If-Match', _str_or_none(if_match)), ('If-None-Match', _str_or_none(if_none_match)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-source-lease-id', _str_or_none(x_ms_source_lease_id)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict(response) def abort_copy_blob(self, container_name, blob_name, x_ms_copy_id, 
x_ms_lease_id=None): ''' Aborts a pending copy_blob operation, and leaves a destination blob with zero length and full metadata. container_name: Name of destination container. blob_name: Name of destination blob. x_ms_copy_id: Copy identifier provided in the x-ms-copy-id of the original copy_blob operation. x_ms_lease_id: Required if the destination blob has an active infinite lease. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('x_ms_copy_id', x_ms_copy_id) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(container_name) + '/' + \ _str(blob_name) + '?comp=copy©id=' + \ _str(x_ms_copy_id) request.headers = [ ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-copy-action', 'abort'), ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def delete_blob(self, container_name, blob_name, snapshot=None, x_ms_lease_id=None): ''' Marks the specified blob or snapshot for deletion. The blob is later deleted during garbage collection. To mark a specific snapshot for deletion provide the date/time of the snapshot via the snapshot parameter. container_name: Name of existing container. blob_name: Name of existing blob. snapshot: Optional. The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to delete. x_ms_lease_id: Required if the blob has an active lease. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(container_name) + '/' + _str(blob_name) + '' request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))] request.query = [('snapshot', _str_or_none(snapshot))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def put_block(self, container_name, blob_name, block, blockid, content_md5=None, x_ms_lease_id=None): ''' Creates a new block to be committed as part of a blob. container_name: Name of existing container. blob_name: Name of existing blob. block: Content of the block. blockid: Required. A value that identifies the block. The string must be less than or equal to 64 bytes in size. content_md5: Optional. An MD5 hash of the block content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. x_ms_lease_id: Required if the blob has an active lease. 
''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('block', block) _validate_not_none('blockid', blockid) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=block' request.headers = [ ('Content-MD5', _str_or_none(content_md5)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)) ] request.query = [('blockid', _encode_base64(_str_or_none(blockid)))] request.body = _get_request_body_bytes_only('block', block) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def put_block_list(self, container_name, blob_name, block_list, content_md5=None, x_ms_blob_cache_control=None, x_ms_blob_content_type=None, x_ms_blob_content_encoding=None, x_ms_blob_content_language=None, x_ms_blob_content_md5=None, x_ms_meta_name_values=None, x_ms_lease_id=None): ''' Writes a blob by specifying the list of block IDs that make up the blob. In order to be written as part of a blob, a block must have been successfully written to the server in a prior Put Block (REST API) operation. container_name: Name of existing container. blob_name: Name of existing blob. block_list: A str list containing the block ids. content_md5: Optional. An MD5 hash of the block content. This hash is used to verify the integrity of the blob during transport. When this header is specified, the storage service checks the hash that has arrived with the one that was sent. x_ms_blob_cache_control: Optional. Sets the blob's cache control. If specified, this property is stored with the blob and returned with a read request. x_ms_blob_content_type: Optional. Sets the blob's content type. If specified, this property is stored with the blob and returned with a read request. 
x_ms_blob_content_encoding: Optional. Sets the blob's content encoding. If specified, this property is stored with the blob and returned with a read request. x_ms_blob_content_language: Optional. Set the blob's content language. If specified, this property is stored with the blob and returned with a read request. x_ms_blob_content_md5: Optional. An MD5 hash of the blob content. Note that this hash is not validated, as the hashes for the individual blocks were validated when each was uploaded. x_ms_meta_name_values: Optional. Dict containing name and value pairs. x_ms_lease_id: Required if the blob has an active lease. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('block_list', block_list) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=blocklist' request.headers = [ ('Content-MD5', _str_or_none(content_md5)), ('x-ms-blob-cache-control', _str_or_none(x_ms_blob_cache_control)), ('x-ms-blob-content-type', _str_or_none(x_ms_blob_content_type)), ('x-ms-blob-content-encoding', _str_or_none(x_ms_blob_content_encoding)), ('x-ms-blob-content-language', _str_or_none(x_ms_blob_content_language)), ('x-ms-blob-content-md5', _str_or_none(x_ms_blob_content_md5)), ('x-ms-meta-name-values', x_ms_meta_name_values), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)) ] request.body = _get_request_body( _convert_block_list_to_xml(block_list)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def get_block_list(self, container_name, blob_name, snapshot=None, blocklisttype=None, x_ms_lease_id=None): ''' Retrieves the list of blocks that have been uploaded as part of a block blob. container_name: Name of existing container. 
blob_name: Name of existing blob. snapshot: Optional. Datetime to determine the time to retrieve the blocks. blocklisttype: Specifies whether to return the list of committed blocks, the list of uncommitted blocks, or both lists together. Valid values are: committed, uncommitted, or all. x_ms_lease_id: Required if the blob has an active lease. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=blocklist' request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))] request.query = [ ('snapshot', _str_or_none(snapshot)), ('blocklisttype', _str_or_none(blocklisttype)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _convert_response_to_block_list(response) def put_page(self, container_name, blob_name, page, x_ms_range, x_ms_page_write, timeout=None, content_md5=None, x_ms_lease_id=None, x_ms_if_sequence_number_lte=None, x_ms_if_sequence_number_lt=None, x_ms_if_sequence_number_eq=None, if_modified_since=None, if_unmodified_since=None, if_match=None, if_none_match=None): ''' Writes a range of pages to a page blob. container_name: Name of existing container. blob_name: Name of existing blob. page: Content of the page. x_ms_range: Required. Specifies the range of bytes to be written as a page. Both the start and end of the range must be specified. Must be in format: bytes=startByte-endByte. Given that pages must be aligned with 512-byte boundaries, the start offset must be a modulus of 512 and the end offset must be a modulus of 512-1. Examples of valid byte ranges are 0-511, 512-1023, etc. x_ms_page_write: Required. 
You may specify one of the following options: update (lower case): Writes the bytes specified by the request body into the specified range. The Range and Content-Length headers must match to perform the update. clear (lower case): Clears the specified range and releases the space used in storage for that range. To clear a range, set the Content-Length header to zero, and the Range header to a value that indicates the range to clear, up to maximum blob size. timeout: the timeout parameter is expressed in seconds. content_md5: Optional. An MD5 hash of the page content. This hash is used to verify the integrity of the page during transport. When this header is specified, the storage service compares the hash of the content that has arrived with the header value that was sent. If the two hashes do not match, the operation will fail with error code 400 (Bad Request). x_ms_lease_id: Required if the blob has an active lease. x_ms_if_sequence_number_lte: Optional. If the blob's sequence number is less than or equal to the specified value, the request proceeds; otherwise it fails. x_ms_if_sequence_number_lt: Optional. If the blob's sequence number is less than the specified value, the request proceeds; otherwise it fails. x_ms_if_sequence_number_eq: Optional. If the blob's sequence number is equal to the specified value, the request proceeds; otherwise it fails. if_modified_since: Optional. A DateTime value. Specify this conditional header to write the page only if the blob has been modified since the specified date/time. If the blob has not been modified, the Blob service fails. if_unmodified_since: Optional. A DateTime value. Specify this conditional header to write the page only if the blob has not been modified since the specified date/time. If the blob has been modified, the Blob service fails. if_match: Optional. An ETag value. Specify an ETag value for this conditional header to write the page only if the blob's ETag value matches the value specified. 
If the values do not match, the Blob service fails. if_none_match: Optional. An ETag value. Specify an ETag value for this conditional header to write the page only if the blob's ETag value does not match the value specified. If the values are identical, the Blob service fails. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) _validate_not_none('page', page) _validate_not_none('x_ms_range', x_ms_range) _validate_not_none('x_ms_page_write', x_ms_page_write) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=page' request.headers = [ ('x-ms-range', _str_or_none(x_ms_range)), ('Content-MD5', _str_or_none(content_md5)), ('x-ms-page-write', _str_or_none(x_ms_page_write)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-if-sequence-number-le', _str_or_none(x_ms_if_sequence_number_lte)), ('x-ms-if-sequence-number-lt', _str_or_none(x_ms_if_sequence_number_lt)), ('x-ms-if-sequence-number-eq', _str_or_none(x_ms_if_sequence_number_eq)), ('If-Modified-Since', _str_or_none(if_modified_since)), ('If-Unmodified-Since', _str_or_none(if_unmodified_since)), ('If-Match', _str_or_none(if_match)), ('If-None-Match', _str_or_none(if_none_match)) ] request.query = [('timeout', _int_or_none(timeout))] request.body = _get_request_body_bytes_only('page', page) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) self._perform_request(request) def get_page_ranges(self, container_name, blob_name, snapshot=None, range=None, x_ms_range=None, x_ms_lease_id=None): ''' Retrieves the page ranges for a blob. container_name: Name of existing container. blob_name: Name of existing blob. snapshot: Optional. 
The snapshot parameter is an opaque DateTime value that, when present, specifies the blob snapshot to retrieve information from. range: Optional. Specifies the range of bytes over which to list ranges, inclusively. If omitted, then all ranges for the blob are returned. x_ms_range: Optional. Specifies the range of bytes to be written as a page. Both the start and end of the range must be specified. Must be in format: bytes=startByte-endByte. Given that pages must be aligned with 512-byte boundaries, the start offset must be a modulus of 512 and the end offset must be a modulus of 512-1. Examples of valid byte ranges are 0-511, 512-1023, etc. x_ms_lease_id: Required if the blob has an active lease. ''' _validate_not_none('container_name', container_name) _validate_not_none('blob_name', blob_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + \ _str(container_name) + '/' + _str(blob_name) + '?comp=pagelist' request.headers = [ ('Range', _str_or_none(range)), ('x-ms-range', _str_or_none(x_ms_range)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)) ] request.query = [('snapshot', _str_or_none(snapshot))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_blob_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_simple_list(response, PageList, PageRange, "page_ranges") ================================================ FILE: OSPatching/azure/storage/cloudstorageaccount.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- from azure.storage.blobservice import BlobService from azure.storage.tableservice import TableService from azure.storage.queueservice import QueueService class CloudStorageAccount(object): """ Provides a factory for creating the blob, queue, and table services with a common account name and account key. Users can either use the factory or can construct the appropriate service directly. """ def __init__(self, account_name=None, account_key=None): self.account_name = account_name self.account_key = account_key def create_blob_service(self): return BlobService(self.account_name, self.account_key) def create_table_service(self): return TableService(self.account_name, self.account_key) def create_queue_service(self): return QueueService(self.account_name, self.account_key) ================================================ FILE: OSPatching/azure/storage/queueservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- from azure import ( WindowsAzureConflictError, WindowsAzureError, DEV_QUEUE_HOST, QUEUE_SERVICE_HOST_BASE, xml_escape, _convert_class_to_xml, _dont_fail_not_exist, _dont_fail_on_exist, _get_request_body, _int_or_none, _parse_enum_results_list, _parse_response, _parse_response_for_dict_filter, _parse_response_for_dict_prefix, _str, _str_or_none, _update_request_uri_query_local_storage, _validate_not_none, _ERROR_CONFLICT, ) from azure.http import ( HTTPRequest, HTTP_RESPONSE_NO_CONTENT, ) from azure.storage import ( Queue, QueueEnumResults, QueueMessagesList, StorageServiceProperties, _update_storage_queue_header, ) from azure.storage.storageclient import _StorageClient class QueueService(_StorageClient): ''' This is the main class managing queue resources. ''' def __init__(self, account_name=None, account_key=None, protocol='https', host_base=QUEUE_SERVICE_HOST_BASE, dev_host=DEV_QUEUE_HOST): ''' account_name: your storage account name, required for all operations. account_key: your storage account key, required for all operations. protocol: Optional. Protocol. Defaults to http. host_base: Optional. Live host base url. Defaults to Azure url. Override this for on-premise. dev_host: Optional. Dev host url. Defaults to localhost. ''' super(QueueService, self).__init__( account_name, account_key, protocol, host_base, dev_host) def get_queue_service_properties(self, timeout=None): ''' Gets the properties of a storage account's Queue Service, including Windows Azure Storage Analytics. timeout: Optional. The timeout parameter is expressed in seconds. 
''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/?restype=service&comp=properties' request.query = [('timeout', _int_or_none(timeout))] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response(response, StorageServiceProperties) def list_queues(self, prefix=None, marker=None, maxresults=None, include=None): ''' Lists all of the queues in a given storage account. prefix: Filters the results to return only queues with names that begin with the specified prefix. marker: A string value that identifies the portion of the list to be returned with the next list operation. The operation returns a NextMarker element within the response body if the list returned was not complete. This value may then be used as a query parameter in a subsequent call to request the next portion of the list of queues. The marker value is opaque to the client. maxresults: Specifies the maximum number of queues to return. If maxresults is not specified, the server will return up to 5,000 items. include: Optional. Include this parameter to specify that the container's metadata be returned as part of the response body. 
''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/?comp=list' request.query = [ ('prefix', _str_or_none(prefix)), ('marker', _str_or_none(marker)), ('maxresults', _int_or_none(maxresults)), ('include', _str_or_none(include)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_enum_results_list( response, QueueEnumResults, "Queues", Queue) def create_queue(self, queue_name, x_ms_meta_name_values=None, fail_on_exist=False): ''' Creates a queue under the given account. queue_name: name of the queue. x_ms_meta_name_values: Optional. A dict containing name-value pairs to associate with the queue as metadata. fail_on_exist: Specify whether throw exception when queue exists. ''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(queue_name) + '' request.headers = [('x-ms-meta-name-values', x_ms_meta_name_values)] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) if not fail_on_exist: try: response = self._perform_request(request) if response.status == HTTP_RESPONSE_NO_CONTENT: return False return True except WindowsAzureError as ex: _dont_fail_on_exist(ex) return False else: response = self._perform_request(request) if response.status == HTTP_RESPONSE_NO_CONTENT: raise WindowsAzureConflictError( _ERROR_CONFLICT.format(response.message)) return True def delete_queue(self, queue_name, fail_not_exist=False): ''' Permanently deletes the specified queue. queue_name: Name of the queue. fail_not_exist: Specify whether throw exception when queue doesn't exist. 
''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + _str(queue_name) + '' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) if not fail_not_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_not_exist(ex) return False else: self._perform_request(request) return True def get_queue_metadata(self, queue_name): ''' Retrieves user-defined metadata and queue properties on the specified queue. Metadata is associated with the queue as name-values pairs. queue_name: Name of the queue. ''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(queue_name) + '?comp=metadata' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) response = self._perform_request(request) return _parse_response_for_dict_prefix( response, prefixes=['x-ms-meta', 'x-ms-approximate-messages-count']) def set_queue_metadata(self, queue_name, x_ms_meta_name_values=None): ''' Sets user-defined metadata on the specified queue. Metadata is associated with the queue as name-value pairs. queue_name: Name of the queue. x_ms_meta_name_values: Optional. A dict containing name-value pairs to associate with the queue as metadata. 
''' _validate_not_none('queue_name', queue_name) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + _str(queue_name) + '?comp=metadata' request.headers = [('x-ms-meta-name-values', x_ms_meta_name_values)] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) self._perform_request(request) def put_message(self, queue_name, message_text, visibilitytimeout=None, messagettl=None): ''' Adds a new message to the back of the message queue. A visibility timeout can also be specified to make the message invisible until the visibility timeout expires. A message must be in a format that can be included in an XML request with UTF-8 encoding. The encoded message can be up to 64KB in size for versions 2011-08-18 and newer, or 8KB in size for previous versions. queue_name: Name of the queue. message_text: Message content. visibilitytimeout: Optional. If not specified, the default value is 0. Specifies the new visibility timeout value, in seconds, relative to server time. The new value must be larger than or equal to 0, and cannot be larger than 7 days. The visibility timeout of a message cannot be set to a value later than the expiry time. visibilitytimeout should be set to a value smaller than the time-to-live value. messagettl: Optional. Specifies the time-to-live interval for the message, in seconds. The maximum time-to-live allowed is 7 days. If this parameter is omitted, the default time-to-live is 7 days. 
''' _validate_not_none('queue_name', queue_name) _validate_not_none('message_text', message_text) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/' + _str(queue_name) + '/messages' request.query = [ ('visibilitytimeout', _str_or_none(visibilitytimeout)), ('messagettl', _str_or_none(messagettl)) ] request.body = _get_request_body( '<?xml version="1.0" encoding="utf-8"?> \ <QueueMessage> \ <MessageText>' + xml_escape(_str(message_text)) + '</MessageText> \ </QueueMessage>') request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_queue_header( request, self.account_name, self.account_key) self._perform_request(request) def get_messages(self, queue_name, numofmessages=None, visibilitytimeout=None): ''' Retrieves one or more messages from the front of the queue. queue_name: Name of the queue. numofmessages: Optional. A nonzero integer value that specifies the number of messages to retrieve from the queue, up to a maximum of 32. If fewer are visible, the visible messages are returned. By default, a single message is retrieved from the queue with this operation. visibilitytimeout: Specifies the new visibility timeout value, in seconds, relative to server time. The new value must be larger than or equal to 1 second, and cannot be larger than 7 days, or larger than 2 hours on REST protocol versions prior to version 2011-08-18. The visibility timeout of a message can be set to a value later than the expiry time. 
        '''
        _validate_not_none('queue_name', queue_name)
        # Build the Get Messages request; optional tuning parameters are
        # passed straight through as query-string values (None entries are
        # dropped later by the URI helper).
        request = HTTPRequest()
        request.method = 'GET'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + '/messages'
        request.query = [
            ('numofmessages', _str_or_none(numofmessages)),
            ('visibilitytimeout', _str_or_none(visibilitytimeout))
        ]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        # Signs the request with the shared-key authorization header.
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        return _parse_response(response, QueueMessagesList)

    def peek_messages(self, queue_name, numofmessages=None):
        '''
        Retrieves one or more messages from the front of the queue, but does
        not alter the visibility of the message.

        queue_name: Name of the queue.
        numofmessages:
            Optional. A nonzero integer value that specifies the number of
            messages to peek from the queue, up to a maximum of 32. By
            default, a single message is peeked from the queue with this
            operation.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'GET'
        # peekonly=true distinguishes this from a dequeuing Get Messages call.
        request.path = '/' + _str(queue_name) + '/messages?peekonly=true'
        request.host = self._get_host()
        request.query = [('numofmessages', _str_or_none(numofmessages))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        return _parse_response(response, QueueMessagesList)

    def delete_message(self, queue_name, message_id, popreceipt):
        '''
        Deletes the specified message.

        queue_name: Name of the queue.
        message_id: Message to delete.
        popreceipt:
            Required. A valid pop receipt value returned from an earlier call
            to the Get Messages or Update Message operation.
        '''
        _validate_not_none('queue_name', queue_name)
        _validate_not_none('message_id', message_id)
        _validate_not_none('popreceipt', popreceipt)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        # NOTE: the trailing + '' is a no-op kept from the original code
        # generator; the path is simply /<queue>/messages/<id>.
        request.path = '/' + \
            _str(queue_name) + '/messages/' + _str(message_id) + ''
        request.query = [('popreceipt', _str_or_none(popreceipt))]
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        # DELETE returns no body; nothing to parse.
        self._perform_request(request)

    def clear_messages(self, queue_name):
        '''
        Deletes all messages from the specified queue.

        queue_name: Name of the queue.
        '''
        _validate_not_none('queue_name', queue_name)
        request = HTTPRequest()
        request.method = 'DELETE'
        request.host = self._get_host()
        request.path = '/' + _str(queue_name) + '/messages'
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)

    def update_message(self, queue_name, message_id, message_text, popreceipt,
                       visibilitytimeout):
        '''
        Updates the visibility timeout of a message. You can also use this
        operation to update the contents of a message.

        queue_name: Name of the queue.
        message_id: Message to update.
        message_text: Content of message.
        popreceipt:
            Required. A valid pop receipt value returned from an earlier call
            to the Get Messages or Update Message operation.
        visibilitytimeout:
            Required. Specifies the new visibility timeout value, in seconds,
            relative to server time. The new value must be larger than or
            equal to 0, and cannot be larger than 7 days. The visibility
            timeout of a message cannot be set to a value later than the
            expiry time. A message can be updated until it has been deleted or
            has expired.
        '''
        _validate_not_none('queue_name', queue_name)
        _validate_not_none('message_id', message_id)
        _validate_not_none('message_text', message_text)
        _validate_not_none('popreceipt', popreceipt)
        _validate_not_none('visibilitytimeout', visibilitytimeout)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        request.path = '/' + \
            _str(queue_name) + '/messages/' + _str(message_id) + ''
        request.query = [
            ('popreceipt', _str_or_none(popreceipt)),
            ('visibilitytimeout', _str_or_none(visibilitytimeout))
        ]
        # The message body is an XML envelope; the text is XML-escaped so
        # arbitrary user content cannot break the document. The backslash
        # continuations keep the literal a single string.
        request.body = _get_request_body(
            '<?xml version="1.0" encoding="utf-8"?> \
<QueueMessage> \
<MessageText>' + xml_escape(_str(message_text)) + '</MessageText> \
</QueueMessage>')
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        response = self._perform_request(request)

        # Only the new pop receipt and next-visible time are useful to the
        # caller; filter the response headers down to those two.
        return _parse_response_for_dict_filter(
            response,
            filter=['x-ms-popreceipt', 'x-ms-time-next-visible'])

    def set_queue_service_properties(self, storage_service_properties,
                                     timeout=None):
        '''
        Sets the properties of a storage account's Queue service, including
        Windows Azure Storage Analytics.

        storage_service_properties: StorageServiceProperties object.
        timeout: Optional. The timeout parameter is expressed in seconds.
        '''
        _validate_not_none('storage_service_properties',
                           storage_service_properties)
        request = HTTPRequest()
        request.method = 'PUT'
        request.host = self._get_host()
        # Account-level operation: addressed at the service root rather than
        # at a specific queue.
        request.path = '/?restype=service&comp=properties'
        request.query = [('timeout', _int_or_none(timeout))]
        request.body = _get_request_body(
            _convert_class_to_xml(storage_service_properties))
        request.path, request.query = _update_request_uri_query_local_storage(
            request, self.use_local_storage)
        request.headers = _update_storage_queue_header(
            request, self.account_name, self.account_key)
        self._perform_request(request)


================================================
FILE: OSPatching/azure/storage/sharedaccesssignature.py
================================================
#-------------------------------------------------------------------------
# Copyright (c) Microsoft.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-------------------------------------------------------------------------- from azure import _sign_string, url_quote from azure.storage import X_MS_VERSION #------------------------------------------------------------------------- # Constants for the share access signature SIGNED_START = 'st' SIGNED_EXPIRY = 'se' SIGNED_RESOURCE = 'sr' SIGNED_PERMISSION = 'sp' SIGNED_IDENTIFIER = 'si' SIGNED_SIGNATURE = 'sig' SIGNED_VERSION = 'sv' RESOURCE_BLOB = 'b' RESOURCE_CONTAINER = 'c' SIGNED_RESOURCE_TYPE = 'resource' SHARED_ACCESS_PERMISSION = 'permission' #-------------------------------------------------------------------------- class WebResource(object): ''' Class that stands for the resource to get the share access signature path: the resource path. properties: dict of name and values. Contains 2 item: resource type and permission request_url: the url of the webresource include all the queries. ''' def __init__(self, path=None, request_url=None, properties=None): self.path = path self.properties = properties or {} self.request_url = request_url class Permission(object): ''' Permission class. Contains the path and query_string for the path. path: the resource path query_string: dict of name, values. Contains SIGNED_START, SIGNED_EXPIRY SIGNED_RESOURCE, SIGNED_PERMISSION, SIGNED_IDENTIFIER, SIGNED_SIGNATURE name values. ''' def __init__(self, path=None, query_string=None): self.path = path self.query_string = query_string class SharedAccessPolicy(object): ''' SharedAccessPolicy class. ''' def __init__(self, access_policy, signed_identifier=None): self.id = signed_identifier self.access_policy = access_policy class SharedAccessSignature(object): ''' The main class used to do the signing and generating the signature. account_name: the storage account name used to generate shared access signature account_key: the access key to genenerate share access signature permission_set: the permission cache used to signed the request url. 
''' def __init__(self, account_name, account_key, permission_set=None): self.account_name = account_name self.account_key = account_key self.permission_set = permission_set def generate_signed_query_string(self, path, resource_type, shared_access_policy, version=X_MS_VERSION): ''' Generates the query string for path, resource type and shared access policy. path: the resource resource_type: could be blob or container shared_access_policy: shared access policy version: x-ms-version for storage service, or None to get a signed query string compatible with pre 2012-02-12 clients, where the version is not included in the query string. ''' query_string = {} if shared_access_policy.access_policy.start: query_string[ SIGNED_START] = shared_access_policy.access_policy.start if version: query_string[SIGNED_VERSION] = version query_string[SIGNED_EXPIRY] = shared_access_policy.access_policy.expiry query_string[SIGNED_RESOURCE] = resource_type query_string[ SIGNED_PERMISSION] = shared_access_policy.access_policy.permission if shared_access_policy.id: query_string[SIGNED_IDENTIFIER] = shared_access_policy.id query_string[SIGNED_SIGNATURE] = self._generate_signature( path, shared_access_policy, version) return query_string def sign_request(self, web_resource): ''' sign request to generate request_url with sharedaccesssignature info for web_resource.''' if self.permission_set: for shared_access_signature in self.permission_set: if self._permission_matches_request( shared_access_signature, web_resource, web_resource.properties[ SIGNED_RESOURCE_TYPE], web_resource.properties[SHARED_ACCESS_PERMISSION]): if web_resource.request_url.find('?') == -1: web_resource.request_url += '?' else: web_resource.request_url += '&' web_resource.request_url += self._convert_query_string( shared_access_signature.query_string) break return web_resource def _convert_query_string(self, query_string): ''' Converts query string to str. 
The order of name, values is very important and can't be wrong.''' convert_str = '' if SIGNED_START in query_string: convert_str += SIGNED_START + '=' + \ url_quote(query_string[SIGNED_START]) + '&' convert_str += SIGNED_EXPIRY + '=' + \ url_quote(query_string[SIGNED_EXPIRY]) + '&' convert_str += SIGNED_PERMISSION + '=' + \ query_string[SIGNED_PERMISSION] + '&' convert_str += SIGNED_RESOURCE + '=' + \ query_string[SIGNED_RESOURCE] + '&' if SIGNED_IDENTIFIER in query_string: convert_str += SIGNED_IDENTIFIER + '=' + \ query_string[SIGNED_IDENTIFIER] + '&' if SIGNED_VERSION in query_string: convert_str += SIGNED_VERSION + '=' + \ query_string[SIGNED_VERSION] + '&' convert_str += SIGNED_SIGNATURE + '=' + \ url_quote(query_string[SIGNED_SIGNATURE]) + '&' return convert_str def _generate_signature(self, path, shared_access_policy, version): ''' Generates signature for a given path and shared access policy. ''' def get_value_to_append(value, no_new_line=False): return_value = '' if value: return_value = value if not no_new_line: return_value += '\n' return return_value if path[0] != '/': path = '/' + path canonicalized_resource = '/' + self.account_name + path # Form the string to sign from shared_access_policy and canonicalized # resource. The order of values is important. 
string_to_sign = \ (get_value_to_append(shared_access_policy.access_policy.permission) + get_value_to_append(shared_access_policy.access_policy.start) + get_value_to_append(shared_access_policy.access_policy.expiry) + get_value_to_append(canonicalized_resource)) if version: string_to_sign += get_value_to_append(shared_access_policy.id) string_to_sign += get_value_to_append(version, True) else: string_to_sign += get_value_to_append(shared_access_policy.id, True) return self._sign(string_to_sign) def _permission_matches_request(self, shared_access_signature, web_resource, resource_type, required_permission): ''' Check whether requested permission matches given shared_access_signature, web_resource and resource type. ''' required_resource_type = resource_type if required_resource_type == RESOURCE_BLOB: required_resource_type += RESOURCE_CONTAINER for name, value in shared_access_signature.query_string.items(): if name == SIGNED_RESOURCE and \ required_resource_type.find(value) == -1: return False elif name == SIGNED_PERMISSION and \ required_permission.find(value) == -1: return False return web_resource.path.find(shared_access_signature.path) != -1 def _sign(self, string_to_sign): ''' use HMAC-SHA256 to sign the string and convert it as base64 encoded string. ''' return _sign_string(self.account_key, string_to_sign) ================================================ FILE: OSPatching/azure/storage/storageclient.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #-------------------------------------------------------------------------- import os import sys from azure import ( WindowsAzureError, DEV_ACCOUNT_NAME, DEV_ACCOUNT_KEY, _ERROR_STORAGE_MISSING_INFO, ) from azure.http import HTTPError from azure.http.httpclient import _HTTPClient from azure.storage import _storage_error_handler #-------------------------------------------------------------------------- # constants for azure app setting environment variables AZURE_STORAGE_ACCOUNT = 'AZURE_STORAGE_ACCOUNT' AZURE_STORAGE_ACCESS_KEY = 'AZURE_STORAGE_ACCESS_KEY' EMULATED = 'EMULATED' #-------------------------------------------------------------------------- class _StorageClient(object): ''' This is the base class for BlobManager, TableManager and QueueManager. ''' def __init__(self, account_name=None, account_key=None, protocol='https', host_base='', dev_host=''): ''' account_name: your storage account name, required for all operations. account_key: your storage account key, required for all operations. protocol: Optional. Protocol. Defaults to http. host_base: Optional. Live host base url. Defaults to Azure url. Override this for on-premise. dev_host: Optional. Dev host url. Defaults to localhost. ''' self.account_name = account_name self.account_key = account_key self.requestid = None self.protocol = protocol self.host_base = host_base self.dev_host = dev_host # the app is not run in azure emulator or use default development # storage account and key if app is run in emulator. self.use_local_storage = False # check whether it is run in emulator. 
if EMULATED in os.environ: self.is_emulated = os.environ[EMULATED].lower() != 'false' else: self.is_emulated = False # get account_name and account key. If they are not set when # constructing, get the account and key from environment variables if # the app is not run in azure emulator or use default development # storage account and key if app is run in emulator. if not self.account_name or not self.account_key: if self.is_emulated: self.account_name = DEV_ACCOUNT_NAME self.account_key = DEV_ACCOUNT_KEY self.protocol = 'http' self.use_local_storage = True else: self.account_name = os.environ.get(AZURE_STORAGE_ACCOUNT) self.account_key = os.environ.get(AZURE_STORAGE_ACCESS_KEY) if not self.account_name or not self.account_key: raise WindowsAzureError(_ERROR_STORAGE_MISSING_INFO) self._httpclient = _HTTPClient( service_instance=self, account_key=self.account_key, account_name=self.account_name, protocol=self.protocol) self._batchclient = None self._filter = self._perform_request_worker def with_filter(self, filter): ''' Returns a new service which will process requests with the specified filter. Filtering operations can include logging, automatic retrying, etc... The filter is a lambda which receives the HTTPRequest and another lambda. The filter can perform any pre-processing on the request, pass it off to the next lambda, and then perform any post-processing on the response. ''' res = type(self)(self.account_name, self.account_key, self.protocol) old_filter = self._filter def new_filter(request): return filter(request, old_filter) res._filter = new_filter return res def set_proxy(self, host, port, user=None, password=None): ''' Sets the proxy server host and port for the HTTP CONNECT Tunnelling. host: Address of the proxy. Ex: '192.168.0.100' port: Port of the proxy. Ex: 6000 user: User for proxy authorization. password: Password for proxy authorization. 
''' self._httpclient.set_proxy(host, port, user, password) def _get_host(self): if self.use_local_storage: return self.dev_host else: return self.account_name + self.host_base def _perform_request_worker(self, request): return self._httpclient.perform_request(request) def _perform_request(self, request, text_encoding='utf-8'): ''' Sends the request and return response. Catches HTTPError and hand it to error handler ''' try: if self._batchclient is not None: return self._batchclient.insert_request_to_batch(request) else: resp = self._filter(request) if sys.version_info >= (3,) and isinstance(resp, bytes) and \ text_encoding: resp = resp.decode(text_encoding) except HTTPError as ex: _storage_error_handler(ex) return resp ================================================ FILE: OSPatching/azure/storage/tableservice.py ================================================ #------------------------------------------------------------------------- # Copyright (c) Microsoft. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#-------------------------------------------------------------------------- from azure import ( WindowsAzureError, TABLE_SERVICE_HOST_BASE, DEV_TABLE_HOST, _convert_class_to_xml, _convert_response_to_feeds, _dont_fail_not_exist, _dont_fail_on_exist, _get_request_body, _int_or_none, _parse_response, _parse_response_for_dict, _parse_response_for_dict_filter, _str, _str_or_none, _update_request_uri_query_local_storage, _validate_not_none, ) from azure.http import HTTPRequest from azure.http.batchclient import _BatchClient from azure.storage import ( StorageServiceProperties, _convert_entity_to_xml, _convert_response_to_entity, _convert_table_to_xml, _convert_xml_to_entity, _convert_xml_to_table, _sign_storage_table_request, _update_storage_table_header, ) from azure.storage.storageclient import _StorageClient class TableService(_StorageClient): ''' This is the main class managing Table resources. ''' def __init__(self, account_name=None, account_key=None, protocol='https', host_base=TABLE_SERVICE_HOST_BASE, dev_host=DEV_TABLE_HOST): ''' account_name: your storage account name, required for all operations. account_key: your storage account key, required for all operations. protocol: Optional. Protocol. Defaults to http. host_base: Optional. Live host base url. Defaults to Azure url. Override this for on-premise. dev_host: Optional. Dev host url. Defaults to localhost. 
''' super(TableService, self).__init__( account_name, account_key, protocol, host_base, dev_host) def begin_batch(self): if self._batchclient is None: self._batchclient = _BatchClient( service_instance=self, account_key=self.account_key, account_name=self.account_name) return self._batchclient.begin_batch() def commit_batch(self): try: ret = self._batchclient.commit_batch() finally: self._batchclient = None return ret def cancel_batch(self): self._batchclient = None def get_table_service_properties(self): ''' Gets the properties of a storage account's Table service, including Windows Azure Storage Analytics. ''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/?restype=service&comp=properties' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response(response, StorageServiceProperties) def set_table_service_properties(self, storage_service_properties): ''' Sets the properties of a storage account's Table Service, including Windows Azure Storage Analytics. storage_service_properties: StorageServiceProperties object. ''' _validate_not_none('storage_service_properties', storage_service_properties) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/?restype=service&comp=properties' request.body = _get_request_body( _convert_class_to_xml(storage_service_properties)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict(response) def query_tables(self, table_name=None, top=None, next_table_name=None): ''' Returns a list of tables under the specified account. table_name: Optional. The specific table to query. top: Optional. 
Maximum number of tables to return. next_table_name: Optional. When top is used, the next table name is stored in result.x_ms_continuation['NextTableName'] ''' request = HTTPRequest() request.method = 'GET' request.host = self._get_host() if table_name is not None: uri_part_table_name = "('" + table_name + "')" else: uri_part_table_name = "" request.path = '/Tables' + uri_part_table_name + '' request.query = [ ('$top', _int_or_none(top)), ('NextTableName', _str_or_none(next_table_name)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _convert_response_to_feeds(response, _convert_xml_to_table) def create_table(self, table, fail_on_exist=False): ''' Creates a new table in the storage account. table: Name of the table to create. Table name may contain only alphanumeric characters and cannot begin with a numeric character. It is case-insensitive and must be from 3 to 63 characters long. fail_on_exist: Specify whether throw exception when table exists. ''' _validate_not_none('table', table) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/Tables' request.body = _get_request_body(_convert_table_to_xml(table)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) if not fail_on_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_on_exist(ex) return False else: self._perform_request(request) return True def delete_table(self, table_name, fail_not_exist=False): ''' table_name: Name of the table to delete. fail_not_exist: Specify whether throw exception when table doesn't exist. 
''' _validate_not_none('table_name', table_name) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/Tables(\'' + _str(table_name) + '\')' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) if not fail_not_exist: try: self._perform_request(request) return True except WindowsAzureError as ex: _dont_fail_not_exist(ex) return False else: self._perform_request(request) return True def get_entity(self, table_name, partition_key, row_key, select=''): ''' Get an entity in a table; includes the $select options. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. select: Property names to select. ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('select', select) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(table_name) + \ '(PartitionKey=\'' + _str(partition_key) + \ '\',RowKey=\'' + \ _str(row_key) + '\')?$select=' + \ _str(select) + '' request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _convert_response_to_entity(response) def query_entities(self, table_name, filter=None, select=None, top=None, next_partition_key=None, next_row_key=None): ''' Get entities in a table; includes the $filter and $select options. table_name: Table to query. filter: Optional. Filter as described at http://msdn.microsoft.com/en-us/library/windowsazure/dd894031.aspx select: Optional. Property names to select from the entities. top: Optional. Maximum number of entities to return. next_partition_key: Optional. 
When top is used, the next partition key is stored in result.x_ms_continuation['NextPartitionKey'] next_row_key: Optional. When top is used, the next partition key is stored in result.x_ms_continuation['NextRowKey'] ''' _validate_not_none('table_name', table_name) request = HTTPRequest() request.method = 'GET' request.host = self._get_host() request.path = '/' + _str(table_name) + '()' request.query = [ ('$filter', _str_or_none(filter)), ('$select', _str_or_none(select)), ('$top', _int_or_none(top)), ('NextPartitionKey', _str_or_none(next_partition_key)), ('NextRowKey', _str_or_none(next_row_key)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _convert_response_to_feeds(response, _convert_xml_to_entity) def insert_entity(self, table_name, entity, content_type='application/atom+xml'): ''' Inserts a new entity into a table. table_name: Table name. entity: Required. The entity object to insert. Could be a dict format or entity object. content_type: Required. Must be set to application/atom+xml ''' _validate_not_none('table_name', table_name) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'POST' request.host = self._get_host() request.path = '/' + _str(table_name) + '' request.headers = [('Content-Type', _str_or_none(content_type))] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _convert_response_to_entity(response) def update_entity(self, table_name, partition_key, row_key, entity, content_type='application/atom+xml', if_match='*'): ''' Updates an existing entity in a table. 
The Update Entity operation replaces the entire entity and can be used to remove properties. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. entity: Required. The entity object to insert. Could be a dict format or entity object. content_type: Required. Must be set to application/atom+xml if_match: Optional. Specifies the condition for which the merge should be performed. To force an unconditional merge, set to the wildcard character (*). ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [ ('Content-Type', _str_or_none(content_type)), ('If-Match', _str_or_none(if_match)) ] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict_filter(response, filter=['etag']) def merge_entity(self, table_name, partition_key, row_key, entity, content_type='application/atom+xml', if_match='*'): ''' Updates an existing entity by updating the entity's properties. This operation does not replace the existing entity as the Update Entity operation does. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. entity: Required. The entity object to insert. Can be a dict format or entity object. content_type: Required. Must be set to application/atom+xml if_match: Optional. Specifies the condition for which the merge should be performed. 
To force an unconditional merge, set to the wildcard character (*). ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'MERGE' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [ ('Content-Type', _str_or_none(content_type)), ('If-Match', _str_or_none(if_match)) ] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict_filter(response, filter=['etag']) def delete_entity(self, table_name, partition_key, row_key, content_type='application/atom+xml', if_match='*'): ''' Deletes an existing entity in a table. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. content_type: Required. Must be set to application/atom+xml if_match: Optional. Specifies the condition for which the delete should be performed. To force an unconditional delete, set to the wildcard character (*). 
''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('content_type', content_type) _validate_not_none('if_match', if_match) request = HTTPRequest() request.method = 'DELETE' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [ ('Content-Type', _str_or_none(content_type)), ('If-Match', _str_or_none(if_match)) ] request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) self._perform_request(request) def insert_or_replace_entity(self, table_name, partition_key, row_key, entity, content_type='application/atom+xml'): ''' Replaces an existing entity or inserts a new entity if it does not exist in the table. Because this operation can insert or update an entity, it is also known as an "upsert" operation. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. entity: Required. The entity object to insert. Could be a dict format or entity object. content_type: Required. 
Must be set to application/atom+xml ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'PUT' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [('Content-Type', _str_or_none(content_type))] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict_filter(response, filter=['etag']) def insert_or_merge_entity(self, table_name, partition_key, row_key, entity, content_type='application/atom+xml'): ''' Merges an existing entity or inserts a new entity if it does not exist in the table. Because this operation can insert or update an entity, it is also known as an "upsert" operation. table_name: Table name. partition_key: PartitionKey of the entity. row_key: RowKey of the entity. entity: Required. The entity object to insert. Could be a dict format or entity object. content_type: Required. 
Must be set to application/atom+xml ''' _validate_not_none('table_name', table_name) _validate_not_none('partition_key', partition_key) _validate_not_none('row_key', row_key) _validate_not_none('entity', entity) _validate_not_none('content_type', content_type) request = HTTPRequest() request.method = 'MERGE' request.host = self._get_host() request.path = '/' + \ _str(table_name) + '(PartitionKey=\'' + \ _str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')' request.headers = [('Content-Type', _str_or_none(content_type))] request.body = _get_request_body(_convert_entity_to_xml(entity)) request.path, request.query = _update_request_uri_query_local_storage( request, self.use_local_storage) request.headers = _update_storage_table_header(request) response = self._perform_request(request) return _parse_response_for_dict_filter(response, filter=['etag']) def _perform_request_worker(self, request): auth = _sign_storage_table_request(request, self.account_name, self.account_key) request.headers.append(('Authorization', auth)) return self._httpclient.perform_request(request) ================================================ FILE: OSPatching/check.py ================================================ #!/usr/bin/python # # OSPatching extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#!/usr/bin/python
#
# OSPatching extension: cron gate script.
# Exits 0 when the scheduled download/patch run should proceed this week,
# 1 when the last run is still within the configured week interval.
#
# Requires Python 2.4+

import os
import sys
import datetime


def find_last_scheduled(lines, day_abbr):
    """Return the date of the most recent history line ending with the
    weekday abbreviation *day_abbr* (e.g. 'Fri'), or None if absent.

    Each history line has the format '%Y-%m-%d %a'.
    """
    for line in reversed(lines):
        line = line.strip()
        if line.endswith(day_abbr):
            return datetime.datetime.strptime(line, '%Y-%m-%d %a').date()
    return None


def should_skip(last_date, interval_of_weeks, today):
    """True when the last scheduled run is still inside the interval window,
    i.e. the next run is not yet due."""
    if last_date is None:
        return False
    return last_date + datetime.timedelta(days=interval_of_weeks * 7) > today


def main():
    intervalOfWeeks = int(sys.argv[1])
    if intervalOfWeeks == 1:
        # Weekly schedule: every matching weekday qualifies.
        sys.exit(0)
    history_scheduled = os.path.join(os.path.dirname(sys.argv[0]),
                                     'scheduled/history')
    today = datetime.date.today()
    try:
        with open(history_scheduled) as f:
            lines = f.readlines()
    except IOError:
        # Robustness fix: no history yet (fresh install) must mean "run now",
        # not an unhandled traceback that permanently blocks the cron chain.
        lines = []
    last_scheduled_date = find_last_scheduled(lines, today.strftime('%a'))
    if should_skip(last_scheduled_date, intervalOfWeeks, today):
        sys.exit(1)
    else:
        sys.exit(0)


if __name__ == '__main__':
    main()
#
# Requires Python 2.4+

import os
import sys
import re
import time
import json
import tempfile
import urllib2
import urlparse
import platform
import shutil
import traceback
import logging

from azure.storage import BlobService
from Utils.WAAgentUtil import waagent
import Utils.HandlerUtil as Util
from patch import *

# Global variables definition
# FIX: was "DSCForLinux" — a copy/paste slip from another extension that made
# every log/telemetry line claim the wrong extension name.
ExtensionShortName = "OSPatching"
DownloadDirectory = 'download'
idleTestScriptName = "idleTest.py"
healthyTestScriptName = "healthyTest.py"


def get_merged_settings(hutil):
    """Return protected settings merged with public settings.

    Public settings take precedence on key collisions.  Factored out of the
    five handlers that previously repeated this block verbatim.
    """
    protected_settings = hutil.get_protected_settings()
    public_settings = hutil.get_public_settings()
    if protected_settings:
        settings = protected_settings.copy()
    else:
        settings = dict()
    if public_settings:
        settings.update(public_settings)
    return settings


def install():
    hutil.do_parse_context('Install')
    try:
        MyPatching.install()
        hutil.do_exit(0, 'Install', 'success', '0', 'Install Succeeded.')
    except Exception as e:
        hutil.error("Failed to install the extension with error: %s, "
                    "stack trace: %s" % (str(e), traceback.format_exc()))
        hutil.do_exit(1, 'Install', 'error', '0', 'Install Failed.')


def enable():
    hutil.log("WARNING: The OSPatching extension for Linux has been deprecated. "
              "Please see the GitHub project "
              "(https://github.com/Azure/azure-linux-extensions/tree/master/OSPatching) "
              "for more information.")
    hutil.do_parse_context('Enable')
    try:
        settings = get_merged_settings(hutil)
        MyPatching.parse_settings(settings)
        # Ensure the same configuration is executed only once
        hutil.exit_if_seq_smaller()
        oneoff = settings.get("oneoff")
        download_customized_vmstatustest()
        copy_vmstatustestscript(hutil.get_seq_no(), oneoff)
        MyPatching.enable()
        current_config = MyPatching.get_current_config()
        hutil.do_exit(0, 'Enable', 'warning', '0',
                      'Enable Succeeded. OSPatching is deprecated. See '
                      'https://github.com/Azure/azure-linux-extensions/tree/master/OSPatching '
                      'for more info. Current Configuration: ' + current_config)
    except Exception as e:
        current_config = MyPatching.get_current_config()
        hutil.error("Failed to enable the extension with error: %s, "
                    "stack trace: %s" % (str(e), traceback.format_exc()))
        # FIX: "Configuation" typo corrected in the user-facing status text.
        hutil.do_exit(1, 'Enable', 'error', '0',
                      'Enable Failed. Current Configuration: ' + current_config)


def uninstall():
    hutil.do_parse_context('Uninstall')
    hutil.do_exit(0, 'Uninstall', 'success', '0', 'Uninstall Succeeded.')


def disable():
    hutil.do_parse_context('Disable')
    try:
        MyPatching.disable()
        hutil.do_exit(0, 'Disable', 'success', '0', 'Disable Succeeded.')
    except Exception as e:
        hutil.error("Failed to disable the extension with error: %s, "
                    "stack trace: %s" % (str(e), traceback.format_exc()))
        hutil.do_exit(1, 'Disable', 'error', '0', 'Disable Failed.')


def update():
    # FIX: operation name was misspelled 'Upadate', corrupting the logged op.
    hutil.do_parse_context('Update')
    hutil.do_exit(0, 'Update', 'success', '0', 'Update Succeeded.')


def download():
    hutil.do_parse_context('Download')
    try:
        settings = get_merged_settings(hutil)
        MyPatching.parse_settings(settings)
        MyPatching.download()
        current_config = MyPatching.get_current_config()
        # NOTE: 'Enable' (not 'Download') is reported deliberately — the guest
        # agent only tracks the Enable operation for status.
        hutil.do_exit(0, 'Enable', 'success', '0',
                      'Download Succeeded. Current Configuration: ' + current_config)
    except Exception as e:
        current_config = MyPatching.get_current_config()
        hutil.error("Failed to download updates with error: %s, "
                    "stack trace: %s" % (str(e), traceback.format_exc()))
        hutil.do_exit(1, 'Enable', 'error', '0',
                      'Download Failed. Current Configuration: ' + current_config)


def patch():
    hutil.do_parse_context('Patch')
    try:
        settings = get_merged_settings(hutil)
        MyPatching.parse_settings(settings)
        MyPatching.patch()
        current_config = MyPatching.get_current_config()
        hutil.do_exit(0, 'Enable', 'success', '0',
                      'Patch Succeeded. Current Configuration: ' + current_config)
    except Exception as e:
        current_config = MyPatching.get_current_config()
        hutil.error("Failed to patch with error: %s, "
                    "stack trace: %s" % (str(e), traceback.format_exc()))
        hutil.do_exit(1, 'Enable', 'error', '0',
                      'Patch Failed. Current Configuration: ' + current_config)


def oneoff():
    hutil.do_parse_context('Oneoff')
    try:
        settings = get_merged_settings(hutil)
        MyPatching.parse_settings(settings)
        MyPatching.patch_one_off()
        current_config = MyPatching.get_current_config()
        hutil.do_exit(0, 'Enable', 'success', '0',
                      'Oneoff Patch Succeeded. Current Configuration: ' + current_config)
    except Exception as e:
        current_config = MyPatching.get_current_config()
        hutil.error("Failed to one-off patch with error: %s, "
                    "stack trace: %s" % (str(e), traceback.format_exc()))
        hutil.do_exit(1, 'Enable', 'error', '0',
                      'Oneoff Patch Failed. Current Configuration: ' + current_config)


def download_files(hutil):
    """Fetch the user-supplied idle/healthy VM-status-test scripts.

    Depending on settings, the scripts are saved from inline config
    ("local"), downloaded from an Azure storage account, or fetched from
    external links.  Returns silently when no scripts are configured or the
    configuration is inconsistent.
    """
    settings = get_merged_settings(hutil)

    local = settings.get("vmStatusTest", dict()).get("local", "")
    if str(local).lower() == "true":
        local = True
    elif str(local).lower() == "false":
        local = False
    else:
        hutil.log("WARNING: The parameter \"local\" "
                  "is empty or invalid. Set it as False. Continue...")
        local = False

    idle_test_script = settings.get("vmStatusTest", dict()).get('idleTestScript')
    healthy_test_script = settings.get("vmStatusTest", dict()).get('healthyTestScript')
    if (not idle_test_script and not healthy_test_script):
        hutil.log("WARNING: The parameter \"idleTestScript\" and \"healthyTestScript\" "
                  "are both empty. Exit downloading VMStatusTest scripts...")
        return
    elif local:
        # Inline scripts must not look like URIs.
        if (idle_test_script and idle_test_script.startswith("http")) or \
                (healthy_test_script and healthy_test_script.startswith("http")):
            hutil.log("WARNING: The parameter \"idleTestScript\" or \"healthyTestScript\" "
                      "should not be uri. Exit downloading VMStatusTest scripts...")
            return
    elif not local:
        # Remote scripts must be URIs.
        if (idle_test_script and not idle_test_script.startswith("http")) or \
                (healthy_test_script and not healthy_test_script.startswith("http")):
            hutil.log("WARNING: The parameter \"idleTestScript\" or \"healthyTestScript\" "
                      "should be uri. Exit downloading VMStatusTest scripts...")
            return

    hutil.do_status_report('Downloading', 'transitioning', '0',
                           'Downloading VMStatusTest scripts...')
    vmStatusTestScripts = dict()
    vmStatusTestScripts[idle_test_script] = idleTestScriptName
    vmStatusTestScripts[healthy_test_script] = healthyTestScriptName

    if local:
        hutil.log("Saving VMStatusTest scripts from user's configurations...")
        for src, dst in vmStatusTestScripts.items():
            if not src:
                continue
            file_path = save_local_file(src, dst, hutil)
            preprocess_files(file_path, hutil)
        return

    storage_account_name = None
    storage_account_key = None
    if settings:
        storage_account_name = settings.get("storageAccountName", "").strip()
        storage_account_key = settings.get("storageAccountKey", "").strip()
    if storage_account_name and storage_account_key:
        hutil.log("Downloading VMStatusTest scripts from azure storage...")
        for src, dst in vmStatusTestScripts.items():
            if not src:
                continue
            file_path = download_blob(storage_account_name,
                                      storage_account_key,
                                      src, dst, hutil)
            preprocess_files(file_path, hutil)
    elif not (storage_account_name or storage_account_key):
        hutil.log("No azure storage account and key specified in protected "
                  "settings. Downloading VMStatusTest scripts from external links...")
        for src, dst in vmStatusTestScripts.items():
            if not src:
                continue
            file_path = download_external_file(src, dst, hutil)
            preprocess_files(file_path, hutil)
    else:
        # Storage account and key should appear in pairs
        error_msg = "Azure storage account or storage key is not provided"
        hutil.error(error_msg)
        raise ValueError(error_msg)
Exit downloading VMStatusTest scripts...") return hutil.do_status_report('Downloading','transitioning', '0', 'Downloading VMStatusTest scripts...') vmStatusTestScripts = dict() vmStatusTestScripts[idle_test_script] = idleTestScriptName vmStatusTestScripts[healthy_test_script] = healthyTestScriptName if local: hutil.log("Saving VMStatusTest scripts from user's configurations...") for src,dst in vmStatusTestScripts.items(): if not src: continue file_path = save_local_file(src, dst, hutil) preprocess_files(file_path, hutil) return storage_account_name = None storage_account_key = None if settings: storage_account_name = settings.get("storageAccountName", "").strip() storage_account_key = settings.get("storageAccountKey", "").strip() if storage_account_name and storage_account_key: hutil.log("Downloading VMStatusTest scripts from azure storage...") for src,dst in vmStatusTestScripts.items(): if not src: continue file_path = download_blob(storage_account_name, storage_account_key, src, dst, hutil) preprocess_files(file_path, hutil) elif not(storage_account_name or storage_account_key): hutil.log("No azure storage account and key specified in protected " "settings. Downloading VMStatusTest scripts from external links...") for src,dst in vmStatusTestScripts.items(): if not src: continue file_path = download_external_file(src, dst, hutil) preprocess_files(file_path, hutil) else: #Storage account and key should appear in pairs error_msg = "Azure storage account or storage key is not provided" hutil.error(error_msg) raise ValueError(error_msg) def download_blob(storage_account_name, storage_account_key, blob_uri, dst, hutil): seqNo = hutil.get_seq_no() container_name = get_container_name_from_uri(blob_uri) blob_name = get_blob_name_from_uri(blob_uri) download_dir = prepare_download_dir(seqNo) download_path = os.path.join(download_dir, dst) #Guest agent already ensure the plugin is enabled one after another. #The blob download will not conflict. 
blob_service = BlobService(storage_account_name, storage_account_key) try: hutil.log("Downloading to {0}".format(download_path)) blob_service.get_blob_to_path(container_name, blob_name, download_path) except Exception, e: hutil.error(("Failed to download blob with uri:{0} " "with error {1}").format(blob_uri,e)) raise return download_path def download_external_file(uri, dst, hutil): seqNo = hutil.get_seq_no() download_dir = prepare_download_dir(seqNo) file_path = os.path.join(download_dir, dst) try: hutil.log("Downloading to {0}".format(file_path)) download_and_save_file(uri, file_path) except Exception, e: hutil.error(("Failed to download external file with uri:{0} " "with error {1}").format(uri, e)) raise return file_path def save_local_file(src, dst, hutil): seqNo = hutil.get_seq_no() download_dir = prepare_download_dir(seqNo) file_path = os.path.join(download_dir, dst) try: hutil.log("Downloading to {0}".format(file_path)) waagent.SetFileContents(file_path, src) except Exception, e: hutil.error(("Failed to save file from user's configuration " "with error {0}").format(e)) raise return file_path def preprocess_files(file_path, hutil): """ Preprocess the text file. If it is a binary file, skip it. 
""" is_text, code_type = is_text_file(file_path) if is_text: dos2unix(file_path) hutil.log("Converting text files from DOS to Unix formats: Done") if code_type in ['UTF-8', 'UTF-16LE', 'UTF-16BE']: remove_bom(file_path) hutil.log("Removing BOM: Done") def is_text_file(file_path): with open(file_path, 'rb') as f: contents = f.read(512) return is_text(contents) def is_text(contents): supported_encoding = ['ascii', 'UTF-8', 'UTF-16LE', 'UTF-16BE'] # Openlogic and Oracle distros don't have python-chardet waagent.Run('yum -y install python-chardet', False) import chardet code_type = chardet.detect(contents)['encoding'] if code_type in supported_encoding: return True, code_type else: return False, code_type def dos2unix(file_path): temp_file_path = tempfile.mkstemp()[1] f_temp = open(temp_file_path, 'wb') with open(file_path, 'rU') as f: contents = f.read() f_temp.write(contents) f_temp.close() shutil.move(temp_file_path, file_path) def remove_bom(file_path): temp_file_path = tempfile.mkstemp()[1] f_temp = open(temp_file_path, 'wb') with open(file_path, 'rb') as f: contents = f.read() for encoding in ["utf-8-sig", "utf-16"]: try: f_temp.write(contents.decode(encoding).encode('utf-8')) break except UnicodeDecodeError: continue f_temp.close() shutil.move(temp_file_path, file_path) def download_and_save_file(uri, file_path): src = urllib2.urlopen(uri) dest = open(file_path, 'wb') buf_size = 1024 buf = src.read(buf_size) while(buf): dest.write(buf) buf = src.read(buf_size) def prepare_download_dir(seqNo): download_dir_main = os.path.join(os.getcwd(), DownloadDirectory) create_directory_if_not_exists(download_dir_main) download_dir = os.path.join(download_dir_main, seqNo) create_directory_if_not_exists(download_dir) return download_dir def create_directory_if_not_exists(directory): """create directory if no exists""" if not os.path.exists(directory): os.makedirs(directory) def get_path_from_uri(uriStr): uri = urlparse.urlparse(uriStr) return uri.path def 
def get_blob_name_from_uri(uri):
    """Return the blob name component of an Azure blob *uri*."""
    return get_properties_from_uri(uri)['blob_name']


def get_container_name_from_uri(uri):
    """Return the container name component of an Azure blob *uri*."""
    return get_properties_from_uri(uri)['container_name']


def get_properties_from_uri(uri):
    """Split a blob *uri* path into {'container_name', 'blob_name'}.

    Raises ValueError when the path has no container/blob separator.
    FIX: the original only logged the error and then returned garbage
    (empty container, whole path as blob name), failing much later with a
    confusing storage error.
    """
    path = get_path_from_uri(uri)
    if path.endswith('/'):
        path = path[:-1]
    if path[0] == '/':
        path = path[1:]
    first_sep = path.find('/')
    if first_sep == -1:
        hutil.error("Failed to extract container, blob, from {}".format(path))
        raise ValueError("Invalid blob uri: no container/blob separator "
                         "in path '{0}'".format(path))
    blob_name = path[first_sep + 1:]
    container_name = path[:first_sep]
    return {'blob_name': blob_name, 'container_name': container_name}


def download_customized_vmstatustest():
    """Download VM-status-test scripts, retrying up to twice with a 10 s
    pause; re-raises the last error when all attempts fail."""
    download_dir = prepare_download_dir(hutil.get_seq_no())
    maxRetry = 2
    for retry in range(0, maxRetry + 1):
        try:
            download_files(hutil)
            break
        except Exception as e:
            hutil.error("Failed to download files, retry=" + str(retry) +
                        ", maxRetry=" + str(maxRetry))
            if retry != maxRetry:
                hutil.log("Sleep 10 seconds")
                time.sleep(10)
            else:
                raise


def copy_vmstatustestscript(seqNo, oneoff):
    """Install the downloaded idle/healthy scripts into the 'oneoff' or
    'scheduled' directory, clearing any stale .py/.pyc copies first."""
    src_dir = prepare_download_dir(seqNo)
    for filename in (idleTestScriptName, healthyTestScriptName):
        src = os.path.join(src_dir, filename)
        if oneoff is not None and str(oneoff).lower() == "true":
            dst = "oneoff"
        else:
            dst = "scheduled"
        dst = os.path.join(os.getcwd(), dst)
        current_vmstatustestscript = os.path.join(dst, filename)
        if os.path.isfile(current_vmstatustestscript):
            os.remove(current_vmstatustestscript)
        # Remove the stale compiled .pyc alongside the source
        if os.path.isfile(current_vmstatustestscript + 'c'):
            os.remove(current_vmstatustestscript + 'c')
        if os.path.isfile(src):
            shutil.copy(src, dst)


# Main function is the only entrance to this extension handler
def main():
    waagent.LoggerInit('/var/log/waagent.log', '/dev/stdout')
    waagent.Log("%s started to handle." % (ExtensionShortName))
    global hutil
    hutil = Util.HandlerUtility(waagent.Log, waagent.Error)
    global MyPatching
    MyPatching = GetMyPatching(hutil)
    if MyPatching is None:
        sys.exit(1)
    # The agent passes exactly one operation flag, e.g. "-enable".
    for a in sys.argv[1:]:
        if re.match("^([-/]*)(disable)", a):
            disable()
        elif re.match("^([-/]*)(uninstall)", a):
            uninstall()
        elif re.match("^([-/]*)(install)", a):
            install()
        elif re.match("^([-/]*)(enable)", a):
            enable()
        elif re.match("^([-/]*)(update)", a):
            update()
        elif re.match("^([-/]*)(download)", a):
            download()
        elif re.match("^([-/]*)(patch)", a):
            patch()
        elif re.match("^([-/]*)(oneoff)", a):
            oneoff()


if __name__ == '__main__':
    main()
# base patching class of all the linux distros
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import re
import json
import random
import shutil
import time
import datetime
import logging
import logging.handlers

from Utils.WAAgentUtil import waagent
from ConfigOptions import ConfigOptions

# The handler manifest ships next to the extension; its "version" field is
# attached to every telemetry event emitted below.
mfile = os.path.join(os.getcwd(), 'HandlerManifest.json')
with open(mfile,'r') as f:
    manifest = json.loads(f.read())[0]
Version = manifest['version']

# Registry of optional, user-supplied VM status probes.  Entries stay None
# unless the matching script was downloaded by handler.py into the
# scheduled/ or oneoff/ package directory.
StatusTest = {
    "Scheduled" : {
        "Idle" : None,
        "Healthy" : None
    },
    "Oneoff" : {
        "Idle" : None,
        "Healthy" : None
    }
}

# Probe modules are optional: a bare except here is deliberate, so a missing
# or broken user script never prevents the extension itself from loading.
try:
    from scheduled.idleTest import is_vm_idle
    StatusTest["Scheduled"]["Idle"] = is_vm_idle
except:
    pass
try:
    from oneoff.idleTest import is_vm_idle
    StatusTest["Oneoff"]["Idle"] = is_vm_idle
except:
    pass
try:
    from scheduled.healthyTest import is_vm_healthy
    StatusTest["Scheduled"]["Healthy"] = is_vm_healthy
except:
    pass
try:
    from oneoff.healthyTest import is_vm_healthy
    StatusTest["Oneoff"]["Healthy"] = is_vm_healthy
except:
    pass


class AbstractPatching(object):
    """
    AbstractPatching defines a skeleton necessary for a concrete Patching class.
    """
""" def __init__(self, hutil): self.hutil = hutil self.syslogger = None self.patched = [] self.to_patch = [] self.downloaded = [] self.download_retry_queue = [] # Patching Configuration self.disabled = None self.stop = None self.reboot_after_patch = None self.category = None self.install_duration = None self.oneoff = None self.interval_of_weeks = None self.day_of_week = None self.start_time = None self.download_time = None self.download_duration = 3600 self.gap_between_stage = 60 self.current_configs = dict() self.category_required = ConfigOptions.category["required"] self.category_all = ConfigOptions.category["all"] # Crontab Variables self.crontab = '/etc/crontab' self.cron_restart_cmd = 'service cron restart' self.cron_chkconfig_cmd = 'chkconfig cron on' # Path Variables self.cwd = os.getcwd() self.package_downloaded_path = os.path.join(self.cwd, 'package.downloaded') self.package_patched_path = os.path.join(self.cwd, 'package.patched') self.stop_flag_path = os.path.join(self.cwd, 'StopOSPatching') self.history_scheduled = os.path.join(self.cwd, 'scheduled/history') self.scheduled_configs_file = os.path.join(self.cwd, 'scheduled/configs') self.dist_upgrade_list = None self.dist_upgrade_list_key = 'distUpgradeList' self.dist_upgrade_all = False self.dist_upgrade_all_key = 'distUpgradeAll' # Reboot Requirements self.reboot_required = False self.open_deleted_files_before = list() self.open_deleted_files_after = list() self.needs_restart = list() def is_string_none_or_empty(self, str): if str is None or len(str) < 1: return True return False def parse_settings(self, settings): disabled = settings.get("disabled") if disabled is None or str(disabled).lower() not in ConfigOptions.disabled: msg = "The value of parameter \"disabled\" is empty or invalid. Set it False by default." 
self.log_and_syslog(logging.WARNING, msg) self.disabled = False else: if str(disabled).lower() == "true": self.disabled = True else: self.disabled = False self.current_configs["disabled"] = str(self.disabled) if self.disabled: msg = "The extension is disabled." self.log_and_syslog(logging.WARNING, msg) return stop = settings.get("stop") if stop is None or str(stop).lower() not in ConfigOptions.stop: msg = "The value of parameter \"stop\" is empty or invalid. Set it False by default." self.log_and_syslog(logging.WARNING, msg) self.stop = False else: if str(stop).lower() == 'true': self.stop = True else: self.stop = False self.current_configs["stop"] = str(self.stop) reboot_after_patch = settings.get("rebootAfterPatch") if reboot_after_patch is None or reboot_after_patch.lower() not in ConfigOptions.reboot_after_patch: msg = "The value of parameter \"rebootAfterPatch\" is empty or invalid. Set it \"rebootifneed\" by default." self.log_and_syslog(logging.WARNING, msg) self.reboot_after_patch = ConfigOptions.reboot_after_patch[0] else: self.reboot_after_patch = reboot_after_patch.lower() waagent.AddExtensionEvent(name=self.hutil.get_name(), op=waagent.WALAEventOperation.Enable, isSuccess=True, version=Version, message="rebootAfterPatch="+self.reboot_after_patch) self.current_configs["rebootAfterPatch"] = self.reboot_after_patch category = settings.get('category') if category is None or category.lower() not in ConfigOptions.category.values(): msg = "The value of parameter \"category\" is empty or invalid. Set it " + self.category_required + " by default." 
self.log_and_syslog(logging.WARNING, msg) self.category = self.category_required else: self.category = category.lower() waagent.AddExtensionEvent(name=self.hutil.get_name(), op=waagent.WALAEventOperation.Enable, isSuccess=True, version=Version, message="category="+self.category) self.current_configs["category"] = self.category self.dist_upgrade_list = settings.get(self.dist_upgrade_list_key) if not self.is_string_none_or_empty(self.dist_upgrade_list): self.current_configs[self.dist_upgrade_list_key] = self.dist_upgrade_list dist_upgrade_all = settings.get(self.dist_upgrade_all_key) if dist_upgrade_all is None: msg = "The value of parameter \"{0}\" is empty or invalid. Set it false by default.".format(self.dist_upgrade_all_key) self.log_and_syslog(logging.INFO, msg) self.dist_upgrade_all = False elif str(dist_upgrade_all).lower() == 'true': self.dist_upgrade_all = True else: self.dist_upgrade_all = False self.current_configs[self.dist_upgrade_all_key] = str(self.dist_upgrade_all) check_hrmin = re.compile(r'^[0-9]{1,2}:[0-9]{1,2}$') install_duration = settings.get('installDuration') if install_duration is None or not re.match(check_hrmin, install_duration): msg = "The value of parameter \"installDuration\" is empty or invalid. Set it 1 hour by default." self.log_and_syslog(logging.WARNING, msg) self.install_duration = 3600 self.current_configs["installDuration"] = "01:00" else: hr_min = install_duration.split(':') self.install_duration = int(hr_min[0]) * 3600 + int(hr_min[1]) * 60 self.current_configs["installDuration"] = install_duration if self.install_duration <= 300: msg = "The value of parameter \"installDuration\" is smaller than 5 minutes. The extension will not reserve 5 minutes for reboot. It is recommended to set \"installDuration\" more than 30 minutes." self.log_and_syslog(logging.WARNING, msg) else: msg = "The extension will reserve 5 minutes for reboot." 
# 5 min for reboot self.install_duration -= 300 self.log_and_syslog(logging.INFO, msg) # The parameter "downloadDuration" is not exposed to users. So there's no log. download_duration = settings.get('downloadDuration') if download_duration is not None and re.match(check_hrmin, download_duration): hr_min = download_duration.split(':') self.download_duration = int(hr_min[0]) * 3600 + int(hr_min[1]) * 60 oneoff = settings.get('oneoff') if oneoff is None or str(oneoff).lower() not in ConfigOptions.oneoff: msg = "The value of parameter \"oneoff\" is empty or invalid. Set it False by default." self.log_and_syslog(logging.WARNING, msg) self.oneoff = False else: if str(oneoff).lower() == "true": self.oneoff = True msg = "The extension will run in one-off mode." else: self.oneoff = False msg = "The extension will run in scheduled task mode." self.log_and_syslog(logging.INFO, msg) self.current_configs["oneoff"] = str(self.oneoff) if not self.oneoff: start_time = settings.get('startTime') if start_time is None or not re.match(check_hrmin, start_time): msg = "The parameter \"startTime\" is empty or invalid. It defaults to 03:00." self.log_and_syslog(logging.WARNING, msg) start_time = "03:00" try: start_time_dt = datetime.datetime.strptime(start_time, '%H:%M') self.start_time = datetime.time(start_time_dt.hour, start_time_dt.minute) except ValueError: msg = "The parameter \"startTime\" is invalid. It defaults to 03:00." self.log_and_syslog(logging.WARNING, msg) self.start_time = datetime.time(3) download_time_dt = start_time_dt - datetime.timedelta(seconds=self.download_duration) self.download_time = datetime.time(download_time_dt.hour, download_time_dt.minute) self.current_configs["startTime"] = start_time day_of_week = settings.get("dayOfWeek") if day_of_week is None or day_of_week == "": msg = "The parameter \"dayOfWeek\" is empty. dayOfWeek defaults to Everyday." 
self.log_and_syslog(logging.WARNING, msg) day_of_week = "everyday" self.day_of_week = ConfigOptions.day_of_week["everyday"] else: for day in day_of_week.split('|'): day = day.strip().lower() if day not in ConfigOptions.day_of_week: msg = "The parameter \"dayOfWeek\" is invalid. dayOfWeek defaults to Everyday." self.log_and_syslog(logging.WARNING, msg) day_of_week = "everyday" break if "everyday" in day_of_week: self.day_of_week = ConfigOptions.day_of_week["everyday"] else: self.day_of_week = [ConfigOptions.day_of_week[day.strip().lower()] for day in day_of_week.split('|')] waagent.AddExtensionEvent(name=self.hutil.get_name(), op=waagent.WALAEventOperation.Enable, isSuccess=True, version=Version, message="dayOfWeek=" + day_of_week) self.current_configs["dayOfWeek"] = day_of_week interval_of_weeks = settings.get('intervalOfWeeks') if interval_of_weeks is None or interval_of_weeks not in ConfigOptions.interval_of_weeks: msg = "The parameter \"intervalOfWeeks\" is empty or invalid. intervalOfWeeks defaults to 1." 
self.log_and_syslog(logging.WARNING, msg) self.interval_of_weeks = '1' else: self.interval_of_weeks = interval_of_weeks waagent.AddExtensionEvent(name=self.hutil.get_name(), op=waagent.WALAEventOperation.Enable, isSuccess=True, version=Version, message="intervalOfWeeks="+self.interval_of_weeks) self.current_configs["intervalOfWeeks"] = self.interval_of_weeks # Save the latest configuration for scheduled task to avoid one-off mode's affection waagent.SetFileContents(self.scheduled_configs_file, json.dumps(self.current_configs)) msg = "Current Configuration: " + self.get_current_config() self.log_and_syslog(logging.INFO, msg) def install(self): pass def enable(self): if self.stop: self.stop_download() self.create_stop_flag() return self.delete_stop_flag() if not self.disabled and self.oneoff: script_file_path = os.path.realpath(sys.argv[0]) os.system(' '.join(['python', script_file_path, '-oneoff', '>/dev/null 2>&1 &'])) else: waagent.SetFileContents(self.history_scheduled, '') self.set_download_cron() self.set_patch_cron() self.restart_cron() def disable(self): self.disabled = True self.enable() def stop_download(self): ''' kill the process of downloading and its subprocess. 
    def stop_download(self):
        '''
        kill the process of downloading and its subprocess.
        return code:
            100  - There are no downloading process to stop
            0    - The downloading process is stopped
        '''
        script_file_path = os.path.realpath(sys.argv[0])
        script_file = os.path.basename(script_file_path)
        # Find the PID of "handler.py -download"; the grep -v filters drop
        # the grep itself and the wrapping shell.
        retcode, output = waagent.RunGetOutput('ps -ef | grep "' + script_file + ' -download" | grep -v grep | grep -v sh | awk \'{print $2}\'')
        if retcode > 0:
            self.log_and_syslog(logging.ERROR, output)
        if output != '':
            # Kill the child (package-manager) process first, then the parent.
            retcode, output2 = waagent.RunGetOutput("ps -ef | awk '{if($3==" + output.strip() + ") {print $2}}'")
            if retcode > 0:
                self.log_and_syslog(logging.ERROR, output2)
            if output2 != '':
                waagent.Run('kill -9 ' + output2.strip())
            waagent.Run('kill -9 ' + output.strip())
            return 0
        return 100

    def set_download_cron(self):
        """Rewrite /etc/crontab so the download step runs before start_time
        on the configured weekdays (or remove the entry when disabled)."""
        script_file_path = os.path.realpath(sys.argv[0])
        script_dir = os.path.dirname(script_file_path)
        script_file = os.path.basename(script_file_path)
        old_line_end = ' '.join([script_file, '-download'])
        if self.disabled:
            new_line = '\n'
        else:
            # When the download window wraps past midnight the cron weekday
            # shifts back by one; crontab uses 0=Sunday, hence the % 7.
            if self.download_time > self.start_time:
                dow = ','.join([str((day - 1) % 7) for day in self.day_of_week])
            else:
                dow = ','.join([str(day % 7) for day in self.day_of_week])
            hr = str(self.download_time.hour)
            minute = str(self.download_time.minute)
            # check.py gates the run by intervalOfWeeks (exit 1 = skip).
            new_line = ' '.join(['\n' + minute, hr, '* *', dow, 'root cd', script_dir,
                                 '&& python check.py', self.interval_of_weeks,
                                 '&& python', script_file, '-download > /dev/null 2>&1\n'])
        # Drop any previous "-download" entry, then append the fresh one.
        waagent.ReplaceFileContentsAtomic(self.crontab,
            '\n'.join(filter(lambda a: a and (old_line_end not in a),
                             waagent.GetFileContents(self.crontab).split('\n'))) + new_line)

    def set_patch_cron(self):
        """Rewrite /etc/crontab with the patch entry plus a cleanup entry
        (one minute later) that removes the stop flag."""
        script_file_path = os.path.realpath(sys.argv[0])
        script_dir = os.path.dirname(script_file_path)
        script_file = os.path.basename(script_file_path)
        old_line_end = ' '.join([script_file, '-patch'])
        if self.disabled:
            new_line = '\n'
        else:
            # Dummy year/month: only hour/minute arithmetic matters here.
            start_time_dt = datetime.datetime(100, 1, 1, self.start_time.hour, self.start_time.minute)
            start_hr = str(self.start_time.hour)
            start_minute = str(self.start_time.minute)
            start_dow = ','.join([str(day % 7) for day in self.day_of_week])
            cleanup_time_dt = start_time_dt + datetime.timedelta(minutes=1)
            cleanup_hr = str(cleanup_time_dt.hour)
            cleanup_minute = str(cleanup_time_dt.minute)
            # The +1 minute may roll over midnight, shifting the weekday.
            if start_time_dt.day < cleanup_time_dt.day:
                cleanup_dow = ','.join([str((day + 1) % 7) for day in self.day_of_week])
            else:
                cleanup_dow = ','.join([str(day % 7) for day in self.day_of_week])
            new_line = ' '.join(['\n' + start_minute, start_hr, '* *', start_dow,
                                 'root cd', script_dir,
                                 '&& python check.py', self.interval_of_weeks,
                                 '&& python', script_file, '-patch >/dev/null 2>&1\n'])
            new_line += ' '.join([cleanup_minute, cleanup_hr, '* *', cleanup_dow,
                                  'root rm -f', self.stop_flag_path, '\n'])
        waagent.ReplaceFileContentsAtomic(self.crontab,
            "\n".join(filter(lambda a: a and (old_line_end not in a) and (self.stop_flag_path not in a),
                             waagent.GetFileContents(self.crontab).split('\n'))) + new_line)

    def restart_cron(self):
        # Apply the rewritten crontab by restarting the cron service.
        retcode,output = waagent.RunGetOutput(self.cron_restart_cmd)
        if retcode > 0:
            self.log_and_syslog(logging.ERROR, output)

    def download(self):
        """Cron entry point: download every applicable package ahead of the
        patch window, honoring the idle probe and the stop flag."""
        # Read the latest configuration for scheduled task
        settings = json.loads(waagent.GetFileContents(self.scheduled_configs_file))
        self.parse_settings(settings)
        self.provide_vm_status_test(StatusTest["Scheduled"])
        if not self.check_vm_idle(StatusTest["Scheduled"]):
            return
        if self.exists_stop_flag():
            self.log_and_syslog(logging.INFO, "Downloading patches is stopped/canceled")
            return
        # Start each cycle with empty bookkeeping files.
        waagent.SetFileContents(self.package_downloaded_path, '')
        waagent.SetFileContents(self.package_patched_path, '')
        start_download_time = time.time()
        # Installing security patches is mandatory
        self._download(self.category_required)
        if self.category == self.category_all:
            self._download(self.category_all)
        self.retry_download()
        end_download_time = time.time()
        waagent.AddExtensionEvent(name=self.hutil.get_name(),
                                  op=waagent.WALAEventOperation.Download,
                                  isSuccess=True,
                                  version=Version,
                                  message=" ".join(["Real downloading time is",
                                                    str(round(end_download_time-start_download_time,3)),
                                                    "s"]))

    def _download(self, category):
        """Download all pending packages for *category*; failed packages are
        queued for retry_download()."""
        self.log_and_syslog(logging.INFO, "Start to check&download patches (Category:" + category + ")")
        retcode, downloadlist = self.check(category)
        if retcode > 0:
            msg = "Failed to check valid upgrades"
            self.log_and_syslog(logging.ERROR, msg)
            self.hutil.do_exit(1, 'Enable', 'error', '0', msg)
        # The guest agent must never be upgraded from inside an extension.
        if 'walinuxagent' in downloadlist:
            downloadlist.remove('walinuxagent')
        if not downloadlist:
            self.log_and_syslog(logging.INFO, "No packages are available for update.")
            return
        self.log_and_syslog(logging.INFO, "There are " + str(len(downloadlist)) + " packages to upgrade.")
        self.log_and_syslog(logging.INFO, "Download list: " + ' '.join(downloadlist))
        for pkg_name in downloadlist:
            if pkg_name in self.downloaded:
                continue
            retcode = self.download_package(pkg_name)
            if retcode != 0:
                self.log_and_syslog(logging.ERROR, "Failed to download the package: " + pkg_name)
                self.log_and_syslog(logging.INFO, "Put {0} into a retry queue".format(pkg_name))
                self.download_retry_queue.append((pkg_name, category))
                continue
            self.downloaded.append(pkg_name)
            self.log_and_syslog(logging.INFO, "Package " + pkg_name + " is downloaded.")
            waagent.AppendFileContents(self.package_downloaded_path,
                                       pkg_name + ' ' + category + '\n')

    def retry_download(self):
        """Drain the retry queue with capped exponential backoff (max 12
        retries, sleep bounded by 2**10 seconds)."""
        retry_count = 0
        max_retry_count = 12
        self.log_and_syslog(logging.INFO, "Retry queue: {0}".format(
            " ".join([pkg_name for pkg_name,category in self.download_retry_queue])))
        while self.download_retry_queue:
            pkg_name, category = self.download_retry_queue[0]
            self.download_retry_queue = self.download_retry_queue[1:]
            retcode = self.download_package(pkg_name)
            if retcode == 0:
                self.downloaded.append(pkg_name)
                self.log_and_syslog(logging.INFO, "Package " + pkg_name + " is downloaded.")
                waagent.AppendFileContents(self.package_downloaded_path,
                                           pkg_name + ' ' + category + '\n')
            else:
                self.log_and_syslog(logging.ERROR, "Failed to download the package: " + pkg_name)
                self.log_and_syslog(logging.INFO, "Put {0} back into a retry queue".format(pkg_name))
                self.download_retry_queue.append((pkg_name,category))
                retry_count = retry_count + 1
                if retry_count > max_retry_count:
                    err_msg = ("Failed to download after {0} retries, "
                               "retry queue: {1}").format(max_retry_count,
                               " ".join([pkg_name for pkg_name,category in self.download_retry_queue]))
                    self.log_and_syslog(logging.ERROR, err_msg)
                    waagent.AddExtensionEvent(name=self.hutil.get_name(),
                                              op=waagent.WALAEventOperation.Download,
                                              isSuccess=False,
                                              version=Version,
                                              message=err_msg)
                    break
                # Randomized exponential backoff, capped at 2**10 seconds.
                k = retry_count if (retry_count < 10) else 10
                interval = int(random.uniform(0, 2 ** k))
                self.log_and_syslog(logging.INFO, ("Sleep {0}s before "
                    "the next retry, current retry_count = {1}").format(interval, retry_count))
                time.sleep(interval)
self.log_and_syslog(logging.INFO, "Put {0} back into a retry queue".format(pkg_name)) self.download_retry_queue.append((pkg_name,category)) retry_count = retry_count + 1 if retry_count > max_retry_count: err_msg = ("Failed to download after {0} retries, " "retry queue: {1}").format(max_retry_count, " ".join([pkg_name for pkg_name,category in self.download_retry_queue])) self.log_and_syslog(logging.ERROR, err_msg) waagent.AddExtensionEvent(name=self.hutil.get_name(), op=waagent.WALAEventOperation.Download, isSuccess=False, version=Version, message=err_msg) break k = retry_count if (retry_count < 10) else 10 interval = int(random.uniform(0, 2 ** k)) self.log_and_syslog(logging.INFO, ("Sleep {0}s before " "the next retry, current retry_count = {1}").format(interval, retry_count)) time.sleep(interval) def patch(self): # Read the latest configuration for scheduled task settings = json.loads(waagent.GetFileContents(self.scheduled_configs_file)) self.parse_settings(settings) if not self.check_vm_idle(StatusTest["Scheduled"]): return if self.exists_stop_flag(): self.log_and_syslog(logging.INFO, "Installing patches is stopped/canceled") self.delete_stop_flag() return # Record the scheduled time waagent.AppendFileContents(self.history_scheduled, time.strftime("%Y-%m-%d %a", time.localtime()) + '\n' ) # Record the open deleted files before patching self.open_deleted_files_before = self.check_open_deleted_files() retcode = self.stop_download() if retcode == 0: self.log_and_syslog(logging.WARNING, "Download time exceeded. 
The pending package will be downloaded in the next cycle") waagent.AddExtensionEvent(name=self.hutil.get_name(), op=waagent.WALAEventOperation.Download, isSuccess=False, version=Version, message="Downloading time out") global start_patch_time start_patch_time = time.time() pkg_failed = [] is_time_out = [False, False] patchlist = self.get_pkg_to_patch(self.category_required) is_time_out[0],failed = self._patch(self.category_required, patchlist) pkg_failed.extend(failed) if not self.exists_stop_flag(): if not is_time_out[0]: patchlist = self.get_pkg_to_patch(self.category_all) if len(patchlist) == 0: self.log_and_syslog(logging.INFO, "No packages are available for update. (Category:" + self.category_all + ")") else: self.log_and_syslog(logging.INFO, "Going to sleep for " + str(self.gap_between_stage) + "s") time.sleep(self.gap_between_stage) is_time_out[1],failed = self._patch(self.category_all, patchlist) pkg_failed.extend(failed) else: msg = "Installing patches (Category:" + self.category_all + ") is stopped/canceled" self.log_and_syslog(logging.INFO, msg) if is_time_out[0] or is_time_out[1]: msg = "Patching time out" self.log_and_syslog(logging.WARNING, msg) waagent.AddExtensionEvent(name=self.hutil.get_name(), op="Patch", isSuccess=False, version=Version, message=msg) self.open_deleted_files_after = self.check_open_deleted_files() self.delete_stop_flag() #self.report() if StatusTest["Scheduled"]["Healthy"]: is_healthy = StatusTest["Scheduled"]["Healthy"]() msg = "Checking the VM is healthy after patching: " + str(is_healthy) self.log_and_syslog(logging.INFO, msg) waagent.AddExtensionEvent(name=self.hutil.get_name(), op="Check healthy", isSuccess=is_healthy, version=Version, message=msg) if self.patched is not None and len(self.patched) > 0: self.reboot_if_required() def _patch(self, category, patchlist): if self.exists_stop_flag(): self.log_and_syslog(logging.INFO, "Installing patches (Category:" + category + ") is stopped/canceled") return False,list() if not 
patchlist: self.log_and_syslog(logging.INFO, "No packages are available for update.") return False,list() self.log_and_syslog(logging.INFO, "Start to install " + str(len(patchlist)) +" patches (Category:" + category + ")") self.log_and_syslog(logging.INFO, "Patch list: " + ' '.join(patchlist)) pkg_failed = [] for pkg_name in patchlist: if pkg_name == 'walinuxagent': continue current_patch_time = time.time() if current_patch_time - start_patch_time > self.install_duration: msg = "Patching time exceeded. The pending package will be patched in the next cycle" self.log_and_syslog(logging.WARNING, msg) return True,pkg_failed retcode = self.patch_package(pkg_name) if retcode != 0: self.log_and_syslog(logging.ERROR, "Failed to patch the package:" + pkg_name) pkg_failed.append(' '.join([pkg_name, category])) continue self.patched.append(pkg_name) self.log_and_syslog(logging.INFO, "Package " + pkg_name + " is patched.") waagent.AppendFileContents(self.package_patched_path, pkg_name + ' ' + category + '\n') return False,pkg_failed def patch_one_off(self): """ Called when startTime is empty string, which means a on-demand patch. """ self.provide_vm_status_test(StatusTest["Oneoff"]) if not self.check_vm_idle(StatusTest["Oneoff"]): return global start_patch_time start_patch_time = time.time() self.log_and_syslog(logging.INFO, "Going to patch one-off") waagent.SetFileContents(self.package_downloaded_path, '') waagent.SetFileContents(self.package_patched_path, '') # Record the open deleted files before patching self.open_deleted_files_before = self.check_open_deleted_files() pkg_failed = [] is_time_out = [False, False] retcode, patchlist_required = self.check(self.category_required) if retcode > 0: msg = "Failed to check valid upgrades" self.log_and_syslog(logging.ERROR, msg) self.hutil.do_exit(1, 'Enable', 'error', '0', msg) if not patchlist_required: self.log_and_syslog(logging.INFO, "No packages are available for update. 
(Category:" + self.category_required + ")") else: is_time_out[0],failed = self._patch(self.category_required, patchlist_required) pkg_failed.extend(failed) if self.category == self.category_all: if not self.exists_stop_flag(): if not is_time_out[0]: retcode, patchlist_other = self.check(self.category_all) if retcode > 0: msg = "Failed to check valid upgrades" self.log_and_syslog(logging.ERROR, msg) self.hutil.do_exit(1, 'Enable', 'error', '0', msg) patchlist_other = [pkg for pkg in patchlist_other if pkg not in patchlist_required] if len(patchlist_other) == 0: self.log_and_syslog(logging.INFO, "No packages are available for update. (Category:" + self.category_all + ")") else: self.log_and_syslog(logging.INFO, "Going to sleep for " + str(self.gap_between_stage) + "s") time.sleep(self.gap_between_stage) self.log_and_syslog(logging.INFO, "Going to patch one-off (Category:" + self.category_all + ")") is_time_out[1],failed = self._patch(self.category_all, patchlist_other) pkg_failed.extend(failed) else: self.log_and_syslog(logging.INFO, "Installing patches (Category:" + self.category_all + ") is stopped/canceled") if is_time_out[0] or is_time_out[1]: waagent.AddExtensionEvent(name=self.hutil.get_name(), op="Oneoff Patch", isSuccess=False, version=Version, message="Patching time out") shutil.copy2(self.package_patched_path, self.package_downloaded_path) for pkg in pkg_failed: waagent.AppendFileContents(self.package_downloaded_path, pkg + '\n') self.open_deleted_files_after = self.check_open_deleted_files() self.delete_stop_flag() #self.report() if StatusTest["Oneoff"]["Healthy"]: is_healthy = StatusTest["Oneoff"]["Healthy"]() msg = "Checking the VM is healthy after patching: " + str(is_healthy) self.log_and_syslog(logging.INFO, msg) waagent.AddExtensionEvent(name=self.hutil.get_name(), op="Check healthy", isSuccess=is_healthy, version=Version, message=msg) if self.patched is not None and len(self.patched) > 0: self.reboot_if_required() def reboot_if_required(self): 
self.check_reboot() self.check_needs_restart() msg = '' if self.reboot_after_patch == 'notrequired' and self.reboot_required: msg += 'Pending Reboot' if self.needs_restart: msg += ': ' + ' '.join(self.needs_restart) waagent.AddExtensionEvent(name=self.hutil.get_name(), op="Reboot", isSuccess=False, version=Version, message=" ".join([self.reboot_after_patch, msg, str(len(self.needs_restart)), "packages need to restart"])) self.hutil.do_exit(0, 'Enable', 'success', '0', msg) if self.reboot_after_patch == 'required': msg += "System going to reboot(Required)" elif self.reboot_after_patch == 'auto' and self.reboot_required: msg += "System going to reboot(Auto)" elif self.reboot_after_patch == 'rebootifneed': if (self.reboot_required or self.needs_restart): msg += "System going to reboot(RebootIfNeed)" if msg: if self.needs_restart: msg += ': ' + ' '.join(self.needs_restart) self.log_and_syslog(logging.INFO, msg) waagent.AddExtensionEvent(name=self.hutil.get_name(), op="Reboot", isSuccess=True, version=Version, message="Reboot") retcode = waagent.Run('reboot') if retcode != 0: self.log_and_syslog(logging.ERROR, "Failed to reboot") waagent.AddExtensionEvent(name=self.hutil.get_name(), op="Reboot", isSuccess=False, version=Version, message="Failed to reboot") else: waagent.AddExtensionEvent(name=self.hutil.get_name(), op="Reboot", isSuccess=False, version=Version, message="Not reboot") def check_needs_restart(self): self.needs_restart.extend(self.get_pkg_needs_restart()) patched_files = dict() for pkg in self.get_pkg_patched(): cmd = ' '.join([self.pkg_query_cmd, pkg]) try: retcode, output = waagent.RunGetOutput(cmd) patched_files[os.path.basename(pkg)] = [filename for filename in output.split("\n") if os.path.isfile(filename)] except Exception: self.log_and_syslog(logging.ERROR, "Failed to " + cmd) # for k,v in patched_files.items(): # self.log_and_syslog(logging.INFO, k + ": " + " ".join(v)) open_deleted_files = list() for filename in self.open_deleted_files_after: if 
filename not in self.open_deleted_files_before: open_deleted_files.append(filename) # self.log_and_syslog(logging.INFO, "Open deleted files: " + " ".join(open_deleted_files)) for pkg,files in patched_files.items(): for filename in files: realpath = os.path.realpath(filename) if realpath in open_deleted_files and pkg not in self.needs_restart: self.needs_restart.append(pkg) msg = "Packages needs to restart: " pkgs = " ".join(self.needs_restart) if pkgs: msg += pkgs else: msg = "There is no package which needs to restart" self.log_and_syslog(logging.INFO, msg) def get_pkg_needs_restart(self): return [] def check_open_deleted_files(self): ret = list() retcode,output = waagent.RunGetOutput('lsof | grep "DEL"') if retcode == 0: for line in output.split('\n'): if line: filename = line.split()[-1] if filename not in ret: ret.append(filename) return ret def create_stop_flag(self): waagent.SetFileContents(self.stop_flag_path, '') def delete_stop_flag(self): if self.exists_stop_flag(): os.remove(self.stop_flag_path) def exists_stop_flag(self): if os.path.isfile(self.stop_flag_path): return True else: return False def get_pkg_to_patch(self, category): if not os.path.isfile(self.package_downloaded_path): return [] pkg_to_patch = waagent.GetFileContents(self.package_downloaded_path) if not pkg_to_patch: return [] patchlist = [line.split()[0] for line in pkg_to_patch.split('\n') if line.endswith(category)] if patchlist is None: return [] return patchlist def get_pkg_patched(self): if not os.path.isfile(self.package_patched_path): return [] pkg_patched = waagent.GetFileContents(self.package_patched_path) if not pkg_patched: return [] patchedlist = [line.split()[0] for line in pkg_patched.split('\n') if line] return patchedlist def get_current_config(self): current_configs = [] for k,v in self.current_configs.items(): current_configs.append(k + "=" + v) return ",".join(current_configs) def provide_vm_status_test(self, status_test): for status,provided in status_test.items(): if 
provided is None: provided = "False" level = logging.WARNING else: provided = "True" level = logging.INFO msg = "The VM %s test script is provided: %s" % (status, provided) self.log_and_syslog(level, msg) waagent.AddExtensionEvent(name=self.hutil.get_name(), op="provides %s test script" % (status,), isSuccess=provided, version=Version, message=msg) def check_vm_idle(self, status_test): is_idle = True if status_test["Idle"]: is_idle = status_test["Idle"]() msg = "Checking the VM is idle: " + str(is_idle) self.log_and_syslog(logging.INFO, msg) waagent.AddExtensionEvent(name=self.hutil.get_name(), op="Check idle", isSuccess=is_idle, version=Version, message=msg) if not is_idle: self.log_and_syslog(logging.WARNING, "Current Operation is skipped.") return is_idle def log_and_syslog(self, level, message): if level == logging.INFO: self.hutil.log(message) elif level == logging.WARNING: self.hutil.log(" ".join(["Warning:", message])) elif level == logging.ERROR: self.hutil.error(message) if self.syslogger is None: self.init_syslog() self.syslog(level, message) def init_syslog(self): self.syslogger = logging.getLogger(self.hutil.get_name()) self.syslogger.setLevel(logging.INFO) formatter = logging.Formatter('%(name)s: %(levelname)s %(message)s') try: handler = logging.handlers.SysLogHandler(address='/dev/log') handler.setFormatter(formatter) self.syslogger.addHandler(handler) except: self.syslogger = None self.hutil.error("Syslog is not ready.") def syslog(self, level, message): if self.syslogger is None: return if level == logging.INFO: self.syslogger.info(message) elif level == logging.WARNING: self.syslogger.warning(message) elif level == logging.ERROR: self.syslogger.error(message) ================================================ FILE: OSPatching/patch/ConfigOptions.py ================================================ #!/usr/bin/python # # AbstractPatching is the base patching class of all the linux distros # # Copyright 2014 Microsoft Corporation # # Licensed under the 
Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. class ConfigOptions(object): disabled = ["true", "false"] # Default value is "false" stop = ["true", "false"] # Default value is "false" reboot_after_patch = ["rebootifneed", # Default value is "rebootifneed" "auto", "required", "notrequired"] category = {"required" : "important", # Default value is "important" "all" : "importantandrecommended"} oneoff = ["true", "false"] # Default value is "false" interval_of_weeks = [str(i) for i in range(1, 53)] # Default value is "1" day_of_week = {"everyday" : range(1,8), # Default value is "everyday" "monday" : 1, "tuesday" : 2, "wednesday": 3, "thursday" : 4, "friday" : 5, "saturday" : 6, "sunday" : 7} ================================================ FILE: OSPatching/patch/OraclePatching.py ================================================ #!/usr/bin/python # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from redhatPatching import redhatPatching class OraclePatching(redhatPatching): def __init__(self, hutil): super(OraclePatching,self).__init__(hutil) ================================================ FILE: OSPatching/patch/SuSEPatching.py ================================================ #!/usr/bin/python # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import sys import shutil from Utils.WAAgentUtil import waagent from AbstractPatching import AbstractPatching class SuSEPatching(AbstractPatching): def __init__(self, hutil): super(SuSEPatching,self).__init__(hutil) self.patched_pkgs = None self.cache_dir = os.path.join(os.path.dirname(sys.argv[0]), 'packages') if not os.path.isdir(self.cache_dir): os.mkdir(self.cache_dir) self.clean_cmd = 'zypper clean' self.check_cmd = 'zypper -q --gpg-auto-import-keys --non-interactive list-patches' self.check_security_cmd = self.check_cmd + ' --category security' self.download_cmd = 'zypper --non-interactive --pkg-cache-dir ' + self.cache_dir + ' install -d --auto-agree-with-licenses -t patch ' self.patch_cmd = 'zypper --non-interactive --pkg-cache-dir ' + self.cache_dir + ' install --auto-agree-with-licenses -t patch ' self.pkg_query_cmd = 'rpm -qlp' waagent.Run('zypper -q --gpg-auto-import-keys --non-interactive refresh', False) def check(self, category): """ Check valid upgrades, Return the package list to upgrade """ if category == self.category_all: check_cmd = self.check_cmd elif category == 
self.category_required: check_cmd = self.check_security_cmd retcode, output = waagent.RunGetOutput(check_cmd) output_lines = output.split('\n') patch_list = [] name_position = 1 for line in output_lines: properties = [elem.strip() for elem in line.split('|')] if len(properties) > 1: if 'Name' in properties: name_position = properties.index('Name') elif not properties[name_position] in self.to_patch: patch_list.append(properties[name_position]) return retcode, patch_list def download_package(self, package): retcode = waagent.Run(self.download_cmd + package, False) if 0 < retcode and retcode < 100: return 1 else: return 0 def patch_package(self, package): if self.patched_pkgs == None: self.patched_pkgs = list() for root,dirs,files in os.walk(self.cache_dir): for filename in files: if filename.endswith('rpm'): shutil.copy(os.path.join(root, filename), "/tmp/") self.patched_pkgs.append("/tmp/"+filename) retcode = waagent.Run(self.patch_cmd + package, False) if 0 < retcode and retcode < 100: return 1 else: if retcode == 102: self.reboot_required = True return 0 def check_reboot(self): pass def get_pkg_patched(self): return self.patched_pkgs ================================================ FILE: OSPatching/patch/UbuntuPatching.py ================================================ #!/usr/bin/python # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os
import logging

from Utils.WAAgentUtil import waagent
from AbstractPatching import AbstractPatching


class UbuntuPatching(AbstractPatching):
    """apt-based implementation of AbstractPatching for Ubuntu/Debian."""

    def __init__(self, hutil):
        super(UbuntuPatching,self).__init__(hutil)
        self.update_cmd = 'apt-get update'
        # '-qq -s' = quiet simulation (dry run) used by check().
        self.check_cmd = 'apt-get -qq -s upgrade'
        self.check_cmd_distupgrade = 'apt-get -qq -s dist-upgrade'
        # Restricts apt to the security-only source list generated below.
        self.check_security_suffix = ' -o Dir::Etc::SourceList=/etc/apt/security.sources.list'
        # Build /etc/apt/security.sources.list from the "-security"
        # entries of sources.list, with commented lines filtered out.
        waagent.Run('grep "-security" /etc/apt/sources.list | sudo grep -v "#" > /etc/apt/security.sources.list')
        # '-d' = download only.
        self.download_cmd = 'apt-get -d -y install'
        self.patch_cmd = 'apt-get -y -q --force-yes -o Dpkg::Options::="--force-confdef" install'
        # Recovery command used by try_package_with_autofix().
        self.fix_cmd = 'dpkg --configure -a --force-confdef'
        self.status_cmd = 'apt-cache show'
        self.pkg_query_cmd = 'dpkg-query -L'
        # Avoid a config prompt
        os.environ['DEBIAN_FRONTEND']='noninteractive'

    def install(self):
        """
        Install for dependencies.
        """
        # Update source.list
        waagent.Run(self.update_cmd, False)
        # /var/run/reboot-required is not created unless the update-notifier-common package is installed
        retcode = waagent.Run('apt-get -y install update-notifier-common')
        if retcode > 0:
            self.hutil.error("Failed to install update-notifier-common")

    def try_package_with_autofix(self, cmd):
        """Run an apt/dpkg command; on failure, attempt one recovery pass
        with self.fix_cmd and re-run.  Returns (retcode, output)."""
        retcode, output = waagent.RunGetOutput(cmd)
        if retcode == 0:
            return retcode, output
        # An error occurred while running the command. Try to recover.
        # Unfortunately apt-get returns code 100 regardless of the error encountered,
        # so we can't smartly detect the cause of failure
        self.log_and_syslog(logging.WARNING, "Error running command ({0}). Will try to correct package state ({1}). Error was {2}".format(cmd, self.fix_cmd, output))
        retcode, output = waagent.RunGetOutput(self.fix_cmd)
        if retcode != 0:
            self.log_and_syslog(logging.WARNING, "Error correcting package state ({0}). Error was {1}".format(self.fix_cmd, output))
        retcode, output = waagent.RunGetOutput(cmd)
        if retcode != 0:
            self.log_and_syslog(logging.WARNING, "Unable to run ({0}) on second attempt. Giving up. Error was {1}".format(cmd, output))
        return retcode, output

    def check(self, category):
        """
        Check valid upgrades,
        Return the package list to download & upgrade
        """
        # Perform upgrade or dist-upgrade as appropriate
        if self.dist_upgrade_all:
            self.log_and_syslog(logging.INFO, "Performing dist-upgrade for ALL packages")
            check_cmd = self.check_cmd_distupgrade
        else:
            check_cmd = self.check_cmd
        # If upgrading only required/security patches, append the command suffix
        # Otherwise, assume all packages will be upgraded
        if category == self.category_required:
            check_cmd = check_cmd + self.check_security_suffix
        retcode, output = self.try_package_with_autofix(check_cmd)
        # Dry-run output lines look like "Inst <pkg> ..."; the second
        # field is the package name.
        to_download = [line.split()[1] for line in output.split('\n') if line.startswith('Inst')]
        # Azure repo assumes upgrade may have dependency changes
        if retcode != 0:
            self.log_and_syslog(logging.WARNING, "Failed to get list of upgradeable packages")
        elif self.is_string_none_or_empty(self.dist_upgrade_list):
            self.log_and_syslog(logging.INFO, "Dist upgrade list not specified, will perform normal patch")
        elif not os.path.isfile(self.dist_upgrade_list):
            self.log_and_syslog(logging.WARNING, "Dist upgrade list was specified but file [{0}] does not exist".format(self.dist_upgrade_list))
        else:
            self.log_and_syslog(logging.INFO, "Running dist-upgrade using {0}".format(self.dist_upgrade_list))
            self.check_azure_cmd = 'apt-get -qq -s dist-upgrade -o Dir::Etc::SourceList={0}'.format(self.dist_upgrade_list)
            retcode, azoutput = self.try_package_with_autofix(self.check_azure_cmd)
            azure_to_download = [line.split()[1] for line in azoutput.split('\n') if line.startswith('Inst')]
            # Merge, keeping only packages not already in to_download.
            to_download += list(set(azure_to_download) - set(to_download))
        return retcode, to_download

    def download_package(self, package):
        """Download (only) one package; returns apt-get's exit code."""
        return waagent.Run(self.download_cmd + ' ' + package)

    def patch_package(self, package):
        """Install one package (with autofix retry); returns exit code."""
        retcode, output = self.try_package_with_autofix(self.patch_cmd + ' ' + package)
        return retcode

    def check_reboot(self):
        # update-notifier-common creates this flag file when a reboot is
        # needed (see install()).
        self.reboot_required = os.path.isfile('/var/run/reboot-required')

    def get_pkg_needs_restart(self):
        """Return the packages listed in /var/run/reboot-required.pkgs."""
        fd = '/var/run/reboot-required.pkgs'
        if not os.path.isfile(fd):
            return []
        return waagent.GetFileContents(fd).split('\n')

    def report(self):
        """
        TODO: Report the detail status of patching
        """
        for package_patched in self.patched:
            retcode,output = waagent.RunGetOutput(self.status_cmd + ' ' + package_patched)
            # Keep only the first stanza of `apt-cache show` output.
            output = output.split('\n\n')[0]
            self.hutil.log(output)

================================================
FILE: OSPatching/patch/__init__.py
================================================
#!/usr/bin/python
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# # Requires Python 2.4+ import os import re import platform from UbuntuPatching import UbuntuPatching from redhatPatching import redhatPatching from centosPatching import centosPatching from OraclePatching import OraclePatching from SuSEPatching import SuSEPatching # Define the function in case waagent(<2.0.4) doesn't have DistInfo() def DistInfo(fullname=0): if 'FreeBSD' in platform.system(): release = re.sub('\-.*\Z', '', str(platform.release())) distinfo = ['FreeBSD', release] return distinfo if os.path.isfile('/etc/oracle-release'): release = re.sub('\-.*\Z', '', str(platform.release())) distinfo = ['Oracle', release] return distinfo if 'linux_distribution' in dir(platform): distinfo = list(platform.linux_distribution(\ full_distribution_name=fullname)) # remove trailing whitespace in distro name distinfo[0] = distinfo[0].strip() return distinfo else: return platform.dist() def GetMyPatching(hutil, patching_class_name=''): """ Return MyPatching object. NOTE: Logging is not initialized at this point. """ if patching_class_name == '': if 'Linux' in platform.system(): Distro = DistInfo()[0] else: # I know this is not Linux! if 'FreeBSD' in platform.system(): Distro = platform.system() Distro = Distro.strip('"') Distro = Distro.strip(' ') patching_class_name = Distro + 'Patching' else: Distro = patching_class_name if not globals().has_key(patching_class_name): hutil.log_and_syslog(Distro + ' is not a supported distribution.') return None return globals()[patching_class_name](hutil) ================================================ FILE: OSPatching/patch/centosPatching.py ================================================ #!/usr/bin/python # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from redhatPatching import redhatPatching


# CentOS is yum-based, so it reuses the Red Hat implementation unchanged.
class centosPatching(redhatPatching):
    def __init__(self, hutil):
        super(centosPatching,self).__init__(hutil)

================================================
FILE: OSPatching/patch/redhatPatching.py
================================================
#!/usr/bin/python
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from Utils.WAAgentUtil import waagent
from AbstractPatching import AbstractPatching


class redhatPatching(AbstractPatching):
    """yum-based implementation of AbstractPatching (RHEL); also the base
    class for CentOS and Oracle Linux."""

    def __init__(self, hutil):
        super(redhatPatching,self).__init__(hutil)
        self.cron_restart_cmd = 'service crond restart'
        self.check_cmd = 'yum -q check-update'
        self.check_security_cmd = 'yum -q --security check-update'
        self.clean_cmd = 'yum clean packages'
        self.download_cmd = 'yum -q -y --downloadonly update'
        self.patch_cmd = 'yum -y update'
        self.status_cmd = 'yum -q info'
        self.pkg_query_cmd = 'repoquery -l'
        # Where yum keeps downloaded rpms; used by check_download().
        self.cache_dir = '/var/cache/yum/'

    def install(self):
        """
        Install for dependencies.
        """
        # For yum --downloadonly option
        waagent.Run('yum -y install yum-downloadonly', False)
        # For yum --security option
        retcode = waagent.Run('yum -y install yum-plugin-security')
        if retcode > 0:
            self.hutil.error("Failed to install yum-plugin-security")
        # For package-cleanup, needs-restarting, repoquery
        retcode = waagent.Run('yum -y install yum-utils')
        if retcode > 0:
            self.hutil.error("Failed to install yum-utils")
        # For lsof
        retcode = waagent.Run('yum -y install lsof')
        if retcode > 0:
            self.hutil.error("Failed to install lsof")
        # Install missing dependencies
        missing_dependency_list = self.check_missing_dependencies()
        for pkg in missing_dependency_list:
            retcode = waagent.Run('yum -y install ' + pkg)
            if retcode > 0:
                self.hutil.error("Failed to install missing dependency: " + pkg)

    def check(self, category):
        """
        Check valid upgrades,
        Return the package list to download & upgrade
        """
        if category == self.category_all:
            check_cmd = self.check_cmd
        elif category == self.category_required:
            check_cmd = self.check_security_cmd
        to_download = []
        # yum check-update exit codes: 0 = nothing to update,
        # 100 = updates available (listed on stdout), 1 = error.
        retcode,output = waagent.RunGetOutput(check_cmd, chk_err=False)
        if retcode == 0:
            return 0, to_download
        elif retcode == 100:
            lines = output.strip().split('\n')
            for line in lines:
                line = re.split(r'\s+', line.strip())
                # Data rows are "<name> <version> <repo>"; stop at the
                # first row of a different shape (e.g. the "Obsoleting
                # Packages" section).
                if len(line) != 3:
                    break
                to_download.append(line[0])
            return 0, to_download
        elif retcode == 1:
            return 1, to_download
        # NOTE(review): any other exit code falls through and implicitly
        # returns None, while callers appear to expect a (retcode, list)
        # tuple -- confirm whether that case can occur.

    def download_package(self, package):
        retcode = waagent.Run(self.download_cmd + ' ' + package, chk_err=False)
        # Yum exit code is not 0 even if succeed, so check if the package
        # rpm exists to verify that downloading succeeds.
        return self.check_download(package)

    def patch_package(self, package):
        # yum update a single package; returns yum's exit code.
        return waagent.Run(self.patch_cmd + ' ' + package)

    def check_reboot(self):
        # Compare the newest installed kernel against the running one.
        # "rpm -q --last kernel" lists e.g. "kernel-3.10.0-..." first;
        # [7:] strips the leading "kernel-" prefix.
        retcode,last_kernel = waagent.RunGetOutput("rpm -q --last kernel")
        last_kernel = last_kernel.split()[0][7:]
        retcode,current_kernel = waagent.RunGetOutput('uname -r')
        current_kernel = current_kernel.strip()
        self.reboot_required = (last_kernel != current_kernel)

    def report(self):
        """
        TODO: Report the detail status of patching
        """
        for package_patched in self.patched:
            self.info_pkg(package_patched)

    def info_pkg(self, pkg_name):
        """
        Return details about a package
        """
        retcode,output = waagent.RunGetOutput(self.status_cmd + ' ' + pkg_name)
        if retcode != 0:
            self.hutil.error(output)
            return None
        # "yum info" output has an "Installed Packages" section optionally
        # followed by "Available Packages"; rpartition splits on the latter.
        installed_pkg_info_list = output.rpartition('Available Packages')[0].strip().split('\n')
        available_pkg_info_list = output.rpartition('Available Packages')[-1].strip().split('\n')
        pkg_info = dict()
        pkg_info['installed'] = dict()
        pkg_info['available'] = dict()
        for item in installed_pkg_info_list:
            if item.startswith('Name'):
                pkg_info['installed']['name'] = item.split(':')[-1].strip()
            elif item.startswith('Arch'):
                pkg_info['installed']['arch'] = item.split(':')[-1].strip()
            elif item.startswith('Version'):
                pkg_info['installed']['version'] = item.split(':')[-1].strip()
            elif item.startswith('Release'):
                pkg_info['installed']['release'] = item.split(':')[-1].strip()
        for item in available_pkg_info_list:
            if item.startswith('Name'):
                pkg_info['available']['name'] = item.split(':')[-1].strip()
            elif item.startswith('Arch'):
                pkg_info['available']['arch'] = item.split(':')[-1].strip()
            elif item.startswith('Version'):
                pkg_info['available']['version'] = item.split(':')[-1].strip()
            elif item.startswith('Release'):
                pkg_info['available']['release'] = item.split(':')[-1].strip()
        return pkg_info

    def check_download(self, pkg_name):
        # Verify the rpm landed in the yum cache: build the expected
        # "<name>-<version>-<release>.<arch>.rpm" file name and search
        # for it.  Returns 0 if found, 1 if missing.
        pkg_info = self.info_pkg(pkg_name)
        name = pkg_info['available']['name']
        arch = pkg_info['available']['arch']
        version = pkg_info['available']['version']
        release = pkg_info['available']['release']
        package = '.'.join(['-'.join([name, version, release]), arch, 'rpm'])
        retcode,output = waagent.RunGetOutput('cd ' + self.cache_dir + ';find . -name "'+ package + '"')
        if retcode != 0:
            self.hutil.error("Unable to check whether the downloading secceeds")
        else:
            if output == '':
                return 1
            else:
                return 0

    def check_missing_dependencies(self):
        # "package-cleanup --problems" prints lines like
        # "... requires <pkg> ..."; collect the unique package names.
        retcode, output = waagent.RunGetOutput('package-cleanup --problems', chk_err=False)
        missing_dependency_list = []
        for line in output.split('\n'):
            if 'requires' not in line:
                continue
            words = line.split()
            missing_dependency = words[words.index('requires') + 1]
            if missing_dependency not in missing_dependency_list:
                missing_dependency_list.append(missing_dependency)
        return missing_dependency_list

================================================
FILE: OSPatching/references
================================================
Utils/

================================================
FILE: OSPatching/scheduled/__init__.py
================================================

================================================
FILE: OSPatching/scheduled/history
================================================

================================================
FILE: OSPatching/test/FakePatching.py
================================================
#!/usr/bin/python
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

# BUGFIX: the patch package directory must be on sys.path *before*
# AbstractPatching is imported.  The original appended it after the import,
# which only worked when the caller had already extended sys.path.
sys.path.append('../patch')

from AbstractPatching import AbstractPatching


class FakePatching(AbstractPatching):
    """Test double for the distro patching classes: deterministic package
    lists, no real package-manager calls, no reboot required."""

    def __init__(self, hutil=None):
        super(FakePatching, self).__init__(hutil)
        self.pkg_query_cmd = 'dpkg-query -L'
        self.gap_between_stage = 1
        self.download_duration = 3600
        self.security_download_list = ['a', 'b', 'c', 'd', 'e']
        self.all_download_list = ['1', '2', '3', '4', 'a', 'b', 'c', 'd', 'e']

    def install(self):
        """
        Install for dependencies.
        """
        pass

    def check(self, category):
        """
        Check valid upgrades,
        Return the package list to download & upgrade
        """
        if category == 'important':
            return 0, self.security_download_list
        else:
            return 0, self.all_download_list

    def download_package(self, package):
        # Pretend every download succeeds instantly.
        return 0

    def patch_package(self, package):
        # Pretend every patch succeeds.
        return 0

    def check_reboot(self):
        # A fake patch never requires a reboot.
        return False
import sys
import time

# BUGFIX: the patch package directory must be on sys.path *before*
# AbstractPatching is imported.  The original appended it after the import,
# which only worked when the caller had already extended sys.path.
sys.path.append('../patch')

from AbstractPatching import AbstractPatching


class FakePatching(AbstractPatching):
    """Test double whose download step is deliberately slow (11 s per
    package against a 60 s download budget) so that the download-timeout
    handling can be exercised."""

    def __init__(self, hutil=None):
        super(FakePatching, self).__init__(hutil)
        self.pkg_query_cmd = 'dpkg-query -L'
        self.gap_between_stage = 1
        self.download_duration = 60
        self.security_download_list = ['a', 'b', 'c', 'd', 'e']
        self.all_download_list = ['1', '2', '3', '4', 'a', 'b', 'c', 'd', 'e']

    def install(self):
        """
        Install for dependencies.
        """
        pass

    def check(self, category):
        """
        Check valid upgrades,
        Return the package list to download & upgrade
        """
        if category == 'important':
            return 0, self.security_download_list
        else:
            return 0, self.all_download_list

    def download_package(self, package):
        # Simulate a slow download so the caller's time budget is exceeded.
        time.sleep(11)
        return 0

    def patch_package(self, package):
        # Pretend every patch succeeds.
        return 0

    def check_reboot(self):
        # A fake patch never requires a reboot.
        return False
import sys

# BUGFIX: the patch package directory must be on sys.path *before*
# AbstractPatching is imported.  The original appended it after the import,
# which only worked when the caller had already extended sys.path.
sys.path.append('../patch')

from AbstractPatching import AbstractPatching


class FakePatching(AbstractPatching):
    """Test double with a long gap between stages (20 s) so that the
    stage-sequencing logic can be exercised; package operations themselves
    are instant no-ops."""

    def __init__(self, hutil=None):
        super(FakePatching, self).__init__(hutil)
        self.pkg_query_cmd = 'dpkg-query -L'
        self.gap_between_stage = 20
        self.download_duration = 60
        self.security_download_list = ['a', 'b', 'c', 'd', 'e']
        self.all_download_list = ['1', '2', '3', '4', 'a', 'b', 'c', 'd', 'e']

    def install(self):
        """
        Install for dependencies.
        """
        pass

    def check(self, category):
        """
        Check valid upgrades,
        Return the package list to download & upgrade
        """
        if category == 'important':
            return 0, self.security_download_list
        else:
            return 0, self.all_download_list

    def download_package(self, package):
        # Pretend every download succeeds instantly.
        return 0

    def patch_package(self, package):
        # Pretend every patch succeeds.
        return 0

    def check_reboot(self):
        # A fake patch never requires a reboot.
        return False
def main():
    """
    Decide whether a scheduled patch run is due.

    Exits 0 when patching should proceed, 1 when it should be skipped
    because fewer than intervalOfWeeks weeks (argv[1]) have elapsed since
    the last recorded run on today's weekday in scheduled/history.
    """
    interval_weeks = int(sys.argv[1])
    # A weekly schedule always runs; no history lookup needed.
    if interval_weeks == 1:
        sys.exit(0)
    history_path = os.path.join(os.getcwd(), 'scheduled/history')
    today = datetime.date.today()
    weekday = today.strftime("%a")
    last_run = None
    # History lines look like '2014-07-01 Tue'; scan newest-first for the
    # most recent entry matching today's weekday.
    with open(history_path) as history:
        for entry in reversed(history.readlines()):
            entry = entry.strip()
            if entry.endswith(weekday):
                last_run = datetime.datetime.strptime(entry, '%Y-%m-%d %a')
                break
    if last_run is not None and \
            last_run.date() + datetime.timedelta(days=interval_weeks * 7) > today:
        sys.exit(1)
    sys.exit(0)
{"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"test","protectedSettings":"MIICmwYJKoZIhvcNAQcDoIICjDCCAogCAQAxggFxMIIBbQIBADBVMEExPzA9BgoJkiaJk/IsZAEZFi9XaW5kb3dzIEF6dXJlIFNlcnZpY2UgTWFuYWdlbWVudCBmb3IgRXh0ZW5zaW9ucwIQHIAxlZWZBI1AXqEZ5v5FPjANBgkqhkiG9w0BAQEFAASCAQAEBPkZsa4VN2rr5SBkMDGD8r/Rbp4W4l0cOV7gN96cQi2oWk7tnAGmz/Yr38OJGv+r7ilG4DP7EJAs2gNmnld8SvQsjI4TMAF6Rt6Xbc9yQiE8PblDXTLqIr/IenK8xIvItsWwDQHiJMLB1EDfyOnwYgnUxBpQYSR3PqySEmBMtQMy7BH6egfOhrd/eifSUew6kv/Zl2wP5DTsU8A8BufiCbuG9rwEhIdDVVDmL1jLQK52OobQaS2IkYa+v+d5bBfDEmJMvVjRqeiwfkXcraWHsHcJBmBLeb/AIxzS4oCx24K5025VbGv3SEHsKx1LIA5EA6+PEhYsT3Vi7JFKAa0VMIIBDAYJKoZIhvcNAQcBMBQGCCqGSIb3DQMHBAj4oY4VX4QKoYCB6GH7cWNvfJCjaNAB5uVXgMWMFbqc9c+CX4k7zqm+fdti9j3mYPpgT/Qs2Z8vrXHFU815T8erezXNijPVyG7C6g6foyzXa3pduB16/4GlMIWYTmfzmSEZZ8Qq0MkgKuq0xQQK5GnfZkCj1hZM5m9WU+2RQZKtAjU8BS8n/os/nCcv9IwOKJ7wyql9qe+j1ZFKrar8bT+evei900g0bNpPba3R1u5yx70e/JLRF5sYBju1PDOua+gV/PqtGY7UTUTWq2r3fLg+ziJMUShYRbtIVUKmxGSc6kDCGmuNNPQsnmh7+wqlBtN60Sw=","publicSettings":{}}}]} ================================================ FILE: OSPatching/test/default.settings ================================================ {"dayOfWeek": "everyday", "rebootAfterPatch": "rebootifneed", "storageAccountKey": "<TOCHANGE>", "stop": "false", "vmStatusTest": {"local": "true", "idleTestScript": "#!/usr/bin/python\n # Locally.\n def is_vm_idle():\n return True\n ", "healthyTestScript": "#!/usr/bin/python\n # Locally.\n def is_vm_healthy():\n return True\n "}, "disabled": "false", "startTime": "03:00", "category": "important", "intervalOfWeeks": "1", "installDuration": "00:30", "storageAccountName": "<TOCHANGE>", "oneoff": "false"} ================================================ FILE: OSPatching/test/handler.py ================================================ #!/usr/bin/python # # OSPatching extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in 
compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import sys import re import time import chardet import tempfile import urllib2 import urlparse import shutil import traceback import logging from azure.storage import BlobService from Utils.WAAgentUtil import waagent import Utils.HandlerUtil as Util from patch import * # Global variables definition ExtensionShortName = 'OSPatching' DownloadDirectory = 'download' idleTestScriptName = "idleTest.py" healthyTestScriptName = "healthyTest.py" idleTestScriptLocal = """ #!/usr/bin/python # Locally. def is_vm_idle(): return True """ healthyTestScriptLocal = """ #!/usr/bin/python # Locally. 
def install():
    """
    Handle the 'install' operation: let the distro-specific patcher set up
    its dependencies, then report the outcome through the handler utility
    (do_exit terminates the process).
    """
    hutil.do_parse_context('Install')
    try:
        MyPatching.install()
        hutil.do_exit(0, 'Install', 'success', '0', 'Install Succeeded.')
    except Exception as e:
        detail = ("Failed to install the extension with error: %s, "
                  "stack trace: %s" % (str(e), traceback.format_exc()))
        hutil.log_and_syslog(logging.ERROR, detail)
        hutil.do_exit(1, 'Install', 'error', '0', 'Install Failed.')
def update():
    """
    Handle the 'update' operation: nothing to migrate between extension
    versions, so just acknowledge and report success.
    """
    # BUGFIX: operation label was misspelled 'Upadate'.
    hutil.do_parse_context('Update')
    hutil.do_exit(0, 'Update', 'success', '0', 'Update Succeeded.')
def download_files(hutil):
    """
    Fetch (or save) the user-supplied VM status test scripts.

    Three mutually exclusive sources, selected by the settings:
      - vmStatusTest.local == "true": scripts are inline in the settings
        and are written to the download directory;
      - storage account name+key present: scripts are blobs, downloaded
        from Azure storage;
      - neither name nor key present: scripts are plain URLs, downloaded
        directly.
    Raises ValueError when only one of account name/key is given.
    Returns early (after a warning) on inconsistent script/locality
    combinations.
    """
    # protected_settings = hutil.get_protected_settings()
    # public_settings = hutil.get_public_settings()
    settings = protected_settings.copy()
    settings.update(public_settings)
    # "local" arrives as a string; anything other than "true"/"false"
    # falls back to False with a warning.
    local = settings.get("vmStatusTest", dict()).get("local", "")
    if local.lower() == "true":
        local = True
    elif local.lower() == "false":
        local = False
    else:
        hutil.log_and_syslog(logging.WARNING, "The parameter \"local\" "
            "is empty or invalid. Set it as False. Continue...")
        local = False
    idle_test_script = settings.get("vmStatusTest", dict()).get('idleTestScript')
    healthy_test_script = settings.get("vmStatusTest", dict()).get('healthyTestScript')
    # Consistency checks: local scripts must be inline text, remote ones
    # must be URIs.  Any mismatch aborts the download with a warning.
    if (not idle_test_script and not healthy_test_script):
        hutil.log_and_syslog(logging.WARNING,
            "The parameter \"idleTestScript\" and \"healthyTestScript\" "
            "are both empty. Exit downloading VMStatusTest scripts...")
        return
    elif local:
        if (idle_test_script and idle_test_script.startswith("http")) or \
                (healthy_test_script and healthy_test_script.startswith("http")):
            hutil.log_and_syslog(logging.WARNING,
                "The parameter \"idleTestScript\" or \"healthyTestScript\" "
                "should not be uri. Exit downloading VMStatusTest scripts...")
            return
    elif not local:
        if (idle_test_script and not idle_test_script.startswith("http")) or \
                (healthy_test_script and not healthy_test_script.startswith("http")):
            hutil.log_and_syslog(logging.WARNING,
                "The parameter \"idleTestScript\" or \"healthyTestScript\" "
                "should be uri. Exit downloading VMStatusTest scripts...")
            return
    hutil.do_status_report('Downloading','transitioning', '0',
                           'Downloading VMStatusTest scripts...')
    # Map source (inline text or URI) -> destination file name.
    vmStatusTestScripts = dict()
    vmStatusTestScripts[idle_test_script] = idleTestScriptName
    vmStatusTestScripts[healthy_test_script] = healthyTestScriptName
    if local:
        hutil.log_and_syslog(logging.INFO,
            "Saving VMStatusTest scripts from user's configurations...")
        for src,dst in vmStatusTestScripts.items():
            if not src:
                continue
            file_path = save_local_file(src, dst, hutil)
            preprocess_files(file_path, hutil)
        return
    storage_account_name = None
    storage_account_key = None
    if settings:
        storage_account_name = settings.get("storageAccountName", "").strip()
        storage_account_key = settings.get("storageAccountKey", "").strip()
    if storage_account_name and storage_account_key:
        hutil.log_and_syslog(logging.INFO,
            "Downloading VMStatusTest scripts from azure storage...")
        for src,dst in vmStatusTestScripts.items():
            if not src:
                continue
            file_path = download_blob(storage_account_name,
                                      storage_account_key,
                                      src,
                                      dst,
                                      hutil)
            preprocess_files(file_path, hutil)
    elif not(storage_account_name or storage_account_key):
        hutil.log_and_syslog(logging.INFO,
            "No azure storage account and key specified in protected "
            "settings. Downloading VMStatusTest scripts from external links...")
        for src,dst in vmStatusTestScripts.items():
            if not src:
                continue
            file_path = download_external_file(src, dst, hutil)
            preprocess_files(file_path, hutil)
    else:
        #Storage account and key should appear in pairs
        error_msg = "Azure storage account or storage key is not provided"
        hutil.log_and_syslog(logging.ERROR, error_msg)
        raise ValueError(error_msg)
def save_local_file(src, dst, hutil):
    """
    Persist an inline script from the user's settings (src is the script
    text, dst the target file name) into the per-sequence-number download
    directory.  Returns the path written; logs and re-raises on failure.
    """
    target = os.path.join(prepare_download_dir(hutil.get_seq_no()), dst)
    try:
        waagent.SetFileContents(target, src)
    except Exception as e:
        hutil.log_and_syslog(logging.ERROR,
                             ("Failed to save file from user's configuration "
                              "with error {0}").format(e))
        raise
    return target
""" is_text, code_type = is_text_file(file_path) if is_text: dos2unix(file_path) hutil.log_and_syslog(logging.INFO, "Converting text files from DOS to Unix formats: Done") if code_type in ['UTF-8', 'UTF-16LE', 'UTF-16BE']: remove_bom(file_path) hutil.log_and_syslog(logging.INFO, "Removing BOM: Done") def is_text_file(file_path): with open(file_path, 'rb') as f: contents = f.read(512) return is_text(contents) def is_text(contents): supported_encoding = ['ascii', 'UTF-8', 'UTF-16LE', 'UTF-16BE'] code_type = chardet.detect(contents)['encoding'] if code_type in supported_encoding: return True, code_type else: return False, code_type def dos2unix(file_path): temp_file_path = tempfile.mkstemp()[1] f_temp = open(temp_file_path, 'wb') with open(file_path, 'rU') as f: contents = f.read() f_temp.write(contents) f_temp.close() shutil.move(temp_file_path, file_path) def remove_bom(file_path): temp_file_path = tempfile.mkstemp()[1] f_temp = open(temp_file_path, 'wb') with open(file_path, 'rb') as f: contents = f.read() for encoding in ["utf-8-sig", "utf-16"]: try: f_temp.write(contents.decode(encoding).encode('utf-8')) break except UnicodeDecodeError: continue f_temp.close() shutil.move(temp_file_path, file_path) def download_and_save_file(uri, file_path): src = urllib2.urlopen(uri) dest = open(file_path, 'wb') buf_size = 1024 buf = src.read(buf_size) while(buf): dest.write(buf) buf = src.read(buf_size) def prepare_download_dir(seqNo): download_dir_main = os.path.join(os.getcwd(), DownloadDirectory) create_directory_if_not_exists(download_dir_main) download_dir = os.path.join(download_dir_main, seqNo) create_directory_if_not_exists(download_dir) return download_dir def create_directory_if_not_exists(directory): """create directory if no exists""" if not os.path.exists(directory): os.makedirs(directory) def get_path_from_uri(uriStr): uri = urlparse.urlparse(uriStr) return uri.path def get_blob_name_from_uri(uri): return get_properties_from_uri(uri)['blob_name'] def 
def copy_vmstatustestscript(seqNo, oneoff):
    """
    Copy the downloaded idle/healthy VM status test scripts from the
    download directory for this sequence number into the directory the
    matching run mode reads them from ("oneoff" or "scheduled").

    BUGFIX: the original condition sent the scripts to "oneoff" when
    oneoff == "false".  A one-off patch run is requested with
    oneoff == "true", so the test was inverted; "false"/None now select
    "scheduled".
    """
    src_dir = prepare_download_dir(seqNo)
    if oneoff is not None and oneoff.lower() == "true":
        subdir = "oneoff"
    else:
        subdir = "scheduled"
    dst_dir = os.path.join(os.getcwd(), subdir)
    for filename in (idleTestScriptName, healthyTestScriptName):
        src = os.path.join(src_dir, filename)
        if os.path.isfile(src):
            shutil.copy(src, dst_dir)
%(ExtensionShortName)) global hutil hutil = Util.HandlerUtility(waagent.Log, waagent.Error, ExtensionShortName) global MyPatching MyPatching = GetMyPatching(hutil) if MyPatching is None: sys.exit(1) for a in sys.argv[1:]: if re.match("^([-/]*)(disable)", a): disable() elif re.match("^([-/]*)(uninstall)", a): uninstall() elif re.match("^([-/]*)(install)", a): install() elif re.match("^([-/]*)(enable)", a): enable() elif re.match("^([-/]*)(update)", a): update() elif re.match("^([-/]*)(download)", a): download() elif re.match("^([-/]*)(patch)", a): patch() elif re.match("^([-/]*)(oneoff)", a): oneoff() if __name__ == '__main__': main() ================================================ FILE: OSPatching/test/oneoff/__init__.py ================================================ ================================================ FILE: OSPatching/test/prepare_settings.py ================================================ #!/usr/bin/python import json idleTestScriptLocal = """ #!/usr/bin/python # Locally. def is_vm_idle(): return True """ healthyTestScriptLocal = """ #!/usr/bin/python # Locally. 
def is_vm_healthy(): return True """ idleTestScriptGithub = "https://raw.githubusercontent.com/bingosummer/scripts/master/idleTest.py" healthyTestScriptGithub = "https://raw.githubusercontent.com/bingosummer/scripts/master/healthyTest.py" idleTestScriptStorage = "https://binxia.blob.core.windows.net/ospatching-v2/idleTest.py" healthyTestScriptStorage = "https://binxia.blob.core.windows.net/ospatching-v2/healthyTest.py" settings = { "disabled" : "false", "stop" : "false", "rebootAfterPatch" : "rebootifneed", "category" : "important", "installDuration" : "00:30", "oneoff" : "false", "intervalOfWeeks" : "1", "dayOfWeek" : "everyday", "startTime" : "03:00", "vmStatusTest" : { "local" : "true", "idleTestScript" : idleTestScriptLocal, #idleTestScriptStorage, "healthyTestScript" : healthyTestScriptLocal, #healthyTestScriptStorage }, "storageAccountName" : "<TOCHANGE>", "storageAccountKey" : "<TOCHANGE>" } settings_string = json.dumps(settings) settings_file = "default.settings" with open(settings_file, "w") as f: f.write(settings_string) ================================================ FILE: OSPatching/test/scheduled/__init__.py ================================================ ================================================ FILE: OSPatching/test/scheduled/history ================================================ ================================================ FILE: OSPatching/test/test.crt ================================================ Bag Attributes localKeyID: 01 00 00 00 friendlyName: ospatch-s131 subject=/DC=Windows Azure Service Management for Extensions issuer=/DC=Windows Azure Service Management for Extensions -----BEGIN CERTIFICATE----- MIIDCjCCAfKgAwIBAgIQHIAxlZWZBI1AXqEZ5v5FPjANBgkqhkiG9w0BAQUFADBB MT8wPQYKCZImiZPyLGQBGRYvV2luZG93cyBBenVyZSBTZXJ2aWNlIE1hbmFnZW1l bnQgZm9yIEV4dGVuc2lvbnMwHhcNMTUwNTE0MDMxMDU2WhcNMjAwNTE0MDMxMDU2 WjBBMT8wPQYKCZImiZPyLGQBGRYvV2luZG93cyBBenVyZSBTZXJ2aWNlIE1hbmFn ZW1lbnQgZm9yIEV4dGVuc2lvbnMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK 
AoIBAQCjtXmqCuqgRZ1nkEzQowNJbtWGqWg+1lTqaS3w/SsQ6K0fjuu1do8jNSuP NLPmY1o/96OA+7HoO4MyE2QfCzb7pGKIH0UPj/0u5HkR9NfRKG+LcZ6saoJQQDbP mdMqN8rTAyiH/Ks95rx5LzlSVX5QL9QtV11fSB9B/ILO5ebQIVAehAchFnSnUGqy HkhQPW8XOAmR4WarW3itaFhKmsbuXwCwbePwcBBhOxqyqqYwGG85zhOSj6xHKDep qF+UTACBd7Ei4SNme6DMDndNNplSLZOswyp+9ElmE01Eu98CtJN6FbrJ1qZU22EV 85Dz4l1UF4zD7JOb5d1XM/l56YEnAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAD5x XZrheNS+n2pCav+VuGrB5gVs9NrH8hZAXxIFQ8bMNRE7HTrUIpSQ04dZBlpo2kVI v1Fx0XPcV9pm22ySzQdxGOVPQqUWzhIVBYqz4gdH2zPSijysJstFPtGK+Z7ygnWA u0NCfpYJhy7hNv8/No7+J5M+BwKrBJUoIHCvrvE1gP97ZrcUD1XsIvOe4yvGEkp4 lydb1Djc1E+BzmI+MwL4BbPnGyBgBqAhSiNAa47Pp9OQhIyvCiifGC3QAkT5NMmq C+fY3AG2SdHY+39zYtehYyhUP9wKo2d/ecpx79ruE4HYJME6AuLVTRXqzQijFfPz M9ouI2lVvsL6DpRRby0= -----END CERTIFICATE----- ================================================ FILE: OSPatching/test/test.prv ================================================ Bag Attributes localKeyID: 01 00 00 00 Microsoft CSP Name: Microsoft Enhanced Cryptographic Provider v1.0 Key Attributes X509v3 Key Usage: 10 -----BEGIN PRIVATE KEY----- MIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCjtXmqCuqgRZ1n kEzQowNJbtWGqWg+1lTqaS3w/SsQ6K0fjuu1do8jNSuPNLPmY1o/96OA+7HoO4My E2QfCzb7pGKIH0UPj/0u5HkR9NfRKG+LcZ6saoJQQDbPmdMqN8rTAyiH/Ks95rx5 LzlSVX5QL9QtV11fSB9B/ILO5ebQIVAehAchFnSnUGqyHkhQPW8XOAmR4WarW3it aFhKmsbuXwCwbePwcBBhOxqyqqYwGG85zhOSj6xHKDepqF+UTACBd7Ei4SNme6DM DndNNplSLZOswyp+9ElmE01Eu98CtJN6FbrJ1qZU22EV85Dz4l1UF4zD7JOb5d1X M/l56YEnAgMBAAECgf9NUVCuRdhtvTDX0HnMW8jOEHLk35j45Rt4Mj5CxzwsNsGN IVaZ5x2pylGwoY2YDeKgNw4Gguw8QmP7Pc54ohDyOjqa1q6mGAErH7zyGDE9+w8l TKVdC2J2/7cJQnwe1+WGBc0s8WY62taRSRaCaLhzof1MryqB7XZ3BF5kfwpixhIg qJ9eS9CYNVdAzHYEsHG3EvqBQm4JojtRMdMpME1SbCoSoZB4NT4T4bzGfyYYzXZo 0LSgRyPwJFBC0TdbjpF9bvJaNT3jVuAk2g0rRdR/Zio5GmqhzQe8x2Vg2NH2DzZ8 ArM1ZtmvW46etx9umDkKZZRLEron+sZ0QhdNL1UCgYEA38rmR9zV98pCmV6/xMCX RDXtmOKD6cM8bWAHE7Dkb10vPuz8WtTpfpjriBF1W3dwCyRClibelmOItjYr0uli 84w2IWCYA9z5T8mT2ymDNNYl3cmLM5gk1Prnm2uCAtQbN3kS+NHaGSIF7eV7xjTo 
QyV3qYNf+R3z2FO47fNoYD0CgYEAu0Tvtvmwv9Rq/8DIUZp3bjWhldhiXtTGImwe ldXKxTbNpOA5pVxPxp7WnEGXAY3TeWWIWEuxRCt8J6GiRzW48LTKsbygKloM6dhb YJ1FIwUXrW8jofwwliLhTlCxde/2MUBA6BZGsJ4GJT+nRUcpkjDrgS2uKbW7+iLR Id2/+TMCgYAyDM29qq0L0udcJ62Z0jzCW5E8zQQVhr1/9KcAh2I/acbEOvohUla6 Inciokzt3ONpCn392MmVNsN/hNP+QoYH1AbTJig5TPVRG9L+g+U9LtufI5EHQ/KQ 02BzCPM1sLw5htFwZnZxgoNy9gzdgj2jrsB5X9FaBJHhgq/sP7DLPQKBgQCKkdIH RO+Cor2iDZasu23QQSMV7A2uOid6ZSKkoJPwJkM40yoUsB/fyrzm1qnUXouy8mxX WXsMBFlUQggARUJZ6o1pwzeI3yVbC9thvD3iUexZSznErQWOsrSg7JjDuhIkE3Vz xrf8DJJjkZxGaQfbwxMgfRq4hl9YEddKBfn9fQKBgGpCHiQ2R9EFm4KR2zgB05Me 0IGyD+cC6MHPiL7sYcDLmG7Y56AZExmR4tcXs/gG21/3kTJTMANJyh3349/f1/Ma xKqG2df6sY/JUDPGiY5X35QM3oFM5razS5M5+4aRlRDTp4gLO+PQ00JQAD3DpwJf NSEn3hd8Qa3aFEqyL/n6 -----END PRIVATE KEY----- ================================================ FILE: OSPatching/test/test_handler_1.py ================================================ #!/usr/bin/python # # OSPatching extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# # Requires Python 2.4+ import os import sys import re import time import chardet import tempfile import urllib2 import urlparse import platform import shutil import traceback import logging from azure.storage import BlobService from Utils.WAAgentUtil import waagent import Utils.HandlerUtil as Util import json import unittest sys.path.append('..') from patch import * from FakePatching import FakePatching # Global variables definition ExtensionShortName = 'OSPatching' DownloadDirectory = 'download' idleTestScriptName = "idleTest.py" healthyTestScriptName = "healthyTest.py" handlerName = os.path.basename(sys.argv[0]) status_file = './status/0.status' log_file = './extension.log' settings_file = "default.settings" with open(settings_file, "r") as f: settings_string = f.read() settings = json.loads(settings_string) idleTestScriptLocal = """ #!/usr/bin/python # Locally. def is_vm_idle(): return True """ healthyTestScriptLocal = """ #!/usr/bin/python # Locally. def is_vm_healthy(): return True """ def install(): hutil.do_parse_context('Install') try: MyPatching.install() hutil.do_exit(0, 'Install', 'success', '0', 'Install Succeeded.') except Exception, e: hutil.log_and_syslog(logging.ERROR, "Failed to install the extension with error: %s, stack trace: %s" %(str(e), traceback.format_exc())) hutil.do_exit(1, 'Install', 'error', '0', 'Install Failed.') def enable(): hutil.do_parse_context('Enable') try: MyPatching.parse_settings(settings) # Ensure the same configuration is executed only once hutil.exit_if_seq_smaller() oneoff = settings.get("oneoff") download_customized_vmstatustest() copy_vmstatustestscript(hutil.get_seq_no(), oneoff) MyPatching.enable() current_config = MyPatching.get_current_config() hutil.do_exit(0, 'Enable', 'success', '0', 'Enable Succeeded. 
def update():
    """
    Handle the 'update' operation: nothing to migrate between extension
    versions, so just acknowledge and report success.
    """
    # BUGFIX: operation label was misspelled 'Upadate'.
    hutil.do_parse_context('Update')
    hutil.do_exit(0, 'Update', 'success', '0', 'Update Succeeded.')
Current Configuation: ' + current_config) except Exception, e: current_config = MyPatching.get_current_config() hutil.log_and_syslog(logging.ERROR, "Failed to patch with error: %s, stack trace: %s" %(str(e), traceback.format_exc())) hutil.do_exit(1, 'Enable','error','0', 'Patch Failed. Current Configuation: ' + current_config) def oneoff(): hutil.do_parse_context('Oneoff') try: MyPatching.parse_settings(settings) MyPatching.patch_one_off() current_config = MyPatching.get_current_config() hutil.do_exit(0,'Enable','success','0', 'Oneoff Patch Succeeded. Current Configuation: ' + current_config) except Exception, e: current_config = MyPatching.get_current_config() hutil.log_and_syslog(logging.ERROR, "Failed to one-off patch with error: %s, stack trace: %s" %(str(e), traceback.format_exc())) hutil.do_exit(1, 'Enable','error','0', 'Oneoff Patch Failed. Current Configuation: ' + current_config) def download_files(hutil): local = settings.get("vmStatusTest", dict()).get("local", "") if local.lower() == "true": local = True elif local.lower() == "false": local = False else: hutil.log_and_syslog(logging.WARNING, "The parameter \"local\" " "is empty or invalid. Set it as False. Continue...") local = False idle_test_script = settings.get("vmStatusTest", dict()).get('idleTestScript') healthy_test_script = settings.get("vmStatusTest", dict()).get('healthyTestScript') if (not idle_test_script and not healthy_test_script): hutil.log_and_syslog(logging.WARNING, "The parameter \"idleTestScript\" and \"healthyTestScript\" " "are both empty. Exit downloading VMStatusTest scripts...") return elif local: if (idle_test_script and idle_test_script.startswith("http")) or \ (healthy_test_script and healthy_test_script.startswith("http")): hutil.log_and_syslog(logging.WARNING, "The parameter \"idleTestScript\" or \"healthyTestScript\" " "should not be uri. 
Exit downloading VMStatusTest scripts...") return elif not local: if (idle_test_script and not idle_test_script.startswith("http")) or \ (healthy_test_script and not healthy_test_script.startswith("http")): hutil.log_and_syslog(logging.WARNING, "The parameter \"idleTestScript\" or \"healthyTestScript\" " "should be uri. Exit downloading VMStatusTest scripts...") return hutil.do_status_report('Downloading','transitioning', '0', 'Downloading VMStatusTest scripts...') vmStatusTestScripts = dict() vmStatusTestScripts[idle_test_script] = idleTestScriptName vmStatusTestScripts[healthy_test_script] = healthyTestScriptName if local: hutil.log_and_syslog(logging.INFO, "Saving VMStatusTest scripts from user's configurations...") for src,dst in vmStatusTestScripts.items(): if not src: continue file_path = save_local_file(src, dst, hutil) preprocess_files(file_path, hutil) return storage_account_name = None storage_account_key = None if settings: storage_account_name = settings.get("storageAccountName", "").strip() storage_account_key = settings.get("storageAccountKey", "").strip() if storage_account_name and storage_account_key: hutil.log_and_syslog(logging.INFO, "Downloading VMStatusTest scripts from azure storage...") for src,dst in vmStatusTestScripts.items(): if not src: continue file_path = download_blob(storage_account_name, storage_account_key, src, dst, hutil) preprocess_files(file_path, hutil) elif not(storage_account_name or storage_account_key): hutil.log_and_syslog(logging.INFO, "No azure storage account and key specified in protected " "settings. 
Downloading VMStatusTest scripts from external links...") for src,dst in vmStatusTestScripts.items(): if not src: continue file_path = download_external_file(src, dst, hutil) preprocess_files(file_path, hutil) else: #Storage account and key should appear in pairs error_msg = "Azure storage account or storage key is not provided" hutil.log_and_syslog(logging.ERROR, error_msg) raise ValueError(error_msg) def download_blob(storage_account_name, storage_account_key, blob_uri, dst, hutil): seqNo = hutil.get_seq_no() container_name = get_container_name_from_uri(blob_uri) blob_name = get_blob_name_from_uri(blob_uri) download_dir = prepare_download_dir(seqNo) download_path = os.path.join(download_dir, dst) #Guest agent already ensure the plugin is enabled one after another. #The blob download will not conflict. blob_service = BlobService(storage_account_name, storage_account_key) try: blob_service.get_blob_to_path(container_name, blob_name, download_path) except Exception, e: hutil.log_and_syslog(logging.ERROR, ("Failed to download blob with uri:{0} " "with error {1}").format(blob_uri,e)) raise return download_path def download_external_file(uri, dst, hutil): seqNo = hutil.get_seq_no() download_dir = prepare_download_dir(seqNo) file_path = os.path.join(download_dir, dst) try: download_and_save_file(uri, file_path) except Exception, e: hutil.log_and_syslog(logging.ERROR, ("Failed to download external file with uri:{0} " "with error {1}").format(uri, e)) raise return file_path def save_local_file(src, dst, hutil): seqNo = hutil.get_seq_no() download_dir = prepare_download_dir(seqNo) file_path = os.path.join(download_dir, dst) try: waagent.SetFileContents(file_path, src) except Exception, e: hutil.log_and_syslog(logging.ERROR, ("Failed to save file from user's configuration " "with error {0}").format(e)) raise return file_path def preprocess_files(file_path, hutil): """ Preprocess the text file. If it is a binary file, skip it. 
""" is_text, code_type = is_text_file(file_path) if is_text: dos2unix(file_path) hutil.log_and_syslog(logging.INFO, "Converting text files from DOS to Unix formats: Done") if code_type in ['UTF-8', 'UTF-16LE', 'UTF-16BE']: remove_bom(file_path) hutil.log_and_syslog(logging.INFO, "Removing BOM: Done") def is_text_file(file_path): with open(file_path, 'rb') as f: contents = f.read(512) return is_text(contents) def is_text(contents): supported_encoding = ['ascii', 'UTF-8', 'UTF-16LE', 'UTF-16BE'] code_type = chardet.detect(contents)['encoding'] if code_type in supported_encoding: return True, code_type else: return False, code_type def dos2unix(file_path): temp_file_path = tempfile.mkstemp()[1] f_temp = open(temp_file_path, 'wb') with open(file_path, 'rU') as f: contents = f.read() f_temp.write(contents) f_temp.close() shutil.move(temp_file_path, file_path) def remove_bom(file_path): temp_file_path = tempfile.mkstemp()[1] f_temp = open(temp_file_path, 'wb') with open(file_path, 'rb') as f: contents = f.read() for encoding in ["utf-8-sig", "utf-16"]: try: f_temp.write(contents.decode(encoding).encode('utf-8')) break except UnicodeDecodeError: continue f_temp.close() shutil.move(temp_file_path, file_path) def download_and_save_file(uri, file_path): src = urllib2.urlopen(uri) dest = open(file_path, 'wb') buf_size = 1024 buf = src.read(buf_size) while(buf): dest.write(buf) buf = src.read(buf_size) def prepare_download_dir(seqNo): download_dir_main = os.path.join(os.getcwd(), DownloadDirectory) create_directory_if_not_exists(download_dir_main) download_dir = os.path.join(download_dir_main, seqNo) create_directory_if_not_exists(download_dir) return download_dir def create_directory_if_not_exists(directory): """create directory if no exists""" if not os.path.exists(directory): os.makedirs(directory) def get_path_from_uri(uriStr): uri = urlparse.urlparse(uriStr) return uri.path def get_blob_name_from_uri(uri): return get_properties_from_uri(uri)['blob_name'] def 
get_container_name_from_uri(uri): return get_properties_from_uri(uri)['container_name'] def get_properties_from_uri(uri): path = get_path_from_uri(uri) if path.endswith('/'): path = path[:-1] if path[0] == '/': path = path[1:] first_sep = path.find('/') if first_sep == -1: hutil.log_and_syslog(logging.ERROR, "Failed to extract container, blob, from {}".format(path)) blob_name = path[first_sep+1:] container_name = path[:first_sep] return {'blob_name': blob_name, 'container_name': container_name} def download_customized_vmstatustest(): download_dir = prepare_download_dir(hutil.get_seq_no()) maxRetry = 2 for retry in range(0, maxRetry + 1): try: download_files(hutil) break except Exception, e: hutil.log_and_syslog(logging.ERROR, "Failed to download files, retry=" + str(retry) + ", maxRetry=" + str(maxRetry)) if retry != maxRetry: hutil.log_and_syslog(logging.INFO, "Sleep 10 seconds") time.sleep(10) else: raise def copy_vmstatustestscript(seqNo, oneoff): src_dir = prepare_download_dir(seqNo) for filename in (idleTestScriptName, healthyTestScriptName): src = os.path.join(src_dir, filename) if os.path.isfile(src): if oneoff is not None and oneoff.lower() == "true": dst = "oneoff" else: dst = "scheduled" dst = os.path.join(os.getcwd(), dst) shutil.copy(src, dst) def delete_current_vmstatustestscript(): for filename in (idleTestScriptName, healthyTestScriptName): current_vmstatustestscript = os.path.join(os.getcwd(), "patch/"+filename) if os.path.isfile(current_vmstatustestscript): os.remove(current_vmstatustestscript) class Test(unittest.TestCase): def setUp(self): print '\n\n============================================================================================' waagent.LoggerInit('/var/log/waagent.log', '/dev/stdout') waagent.Log("%s started to handle." 
%(ExtensionShortName)) global hutil hutil = Util.HandlerUtility(waagent.Log, waagent.Error) hutil.do_parse_context('TEST') global MyPatching MyPatching = FakePatching(hutil) if MyPatching is None: sys.exit(1) distro = DistInfo()[0] if 'centos' in distro or 'Oracle' in distro or 'redhat' in distro: MyPatching.cron_restart_cmd = 'service crond restart' try: os.remove('mrseq') except: pass waagent.SetFileContents(MyPatching.package_downloaded_path, '') waagent.SetFileContents(MyPatching.package_patched_path, '') def test_case_insensitive_parameters(self): print 'test_case_insensitive_parameters' global settings settings = { "disabled" : "False", "stop" : "false", "rebootAfterPatch" : "rEbOoTiFnEeD", "category" : "imPortant", "installDuration" : "01:00", "oneoff" : "falSe", "dayOfWeek" : "Sunday|Monday|Tuesday|wednesday|Thursday|Friday|Saturday", "startTime" : "02:00" } MyPatching.parse_settings(settings) self.assertFalse(MyPatching.disabled) self.assertFalse(MyPatching.stop) self.assertEqual(MyPatching.reboot_after_patch, "rebootifneed") self.assertFalse(MyPatching.oneoff) self.assertEqual(MyPatching.day_of_week, [7, 1, 2, 3, 4, 5, 6]) self.assertEqual(MyPatching.category, "important") import datetime self.assertEqual(MyPatching.start_time, datetime.datetime.strptime("02:00", '%H:%M')) def test_illegal_parameters(self): print 'test_illegal_parameters' global settings settings = { "disabled" : "illegal", "stop" : "false", "rebootAfterPatch" : "illegal", "category" : "illegal", "installDuration" : "1 hour", "oneoff" : "illegal", "dayOfWeek" : "Sunday|Moy|Tday|wednesday|Thursday|Friday|Srday", "startTime" : "02:00" } MyPatching.parse_settings(settings) self.assertFalse(MyPatching.disabled) self.assertFalse(MyPatching.stop) self.assertEqual(MyPatching.reboot_after_patch, "rebootifneed") self.assertFalse(MyPatching.oneoff) self.assertEqual(MyPatching.day_of_week, range(1,8)) self.assertEqual(MyPatching.category, "important") import datetime 
self.assertEqual(MyPatching.start_time, datetime.datetime.strptime("02:00", '%H:%M')) def test_conflict_parameters_1(self): print 'test_conflict_parameters_1' global settings settings = { "disabled" : "false", "stop" : "false", "rebootAfterPatch" : "rebootifneed", "category" : "important", "installDuration" : "01:01", "oneoff" : "false", "vmStatusTest" : { "local" : "true", "healthyTestScript" : "http://test.com/test.py" } } MyPatching.parse_settings(settings) old_log_len = len(waagent.GetFileContents(log_file)) download_customized_vmstatustest() log_contents = waagent.GetFileContents(log_file)[old_log_len:] self.assertTrue('The parameter "idleTestScript" or "healthyTestScript" should not be uri' in log_contents) def test_conflict_parameters_2(self): print 'test_conflict_parameters_2' global settings settings = { "disabled" : "false", "stop" : "false", "rebootAfterPatch" : "rebootifneed", "category" : "important", "installDuration" : "01:01", "oneoff" : "false", "vmStatusTest" : { "local" : "false", "healthyTestScript" : idleTestScriptLocal } } MyPatching.parse_settings(settings) old_log_len = len(waagent.GetFileContents(log_file)) download_customized_vmstatustest() log_contents = waagent.GetFileContents(log_file)[old_log_len:] self.assertTrue('The parameter "idleTestScript" or "healthyTestScript" should be uri' in log_contents) def test_install(self): """ Each Distro has different dependencies for OSPatching Extension. It is MANUAL to check whether they are installed or not. 
Ubuntu : update-notifier-common CentOS/Oracle : yum-downloadonly yum-plugin-security SuSE : None """ print 'test_install' with self.assertRaises(SystemExit) as cm: install() self.assertEqual(cm.exception.code, 0) self.assertEqual(get_status("Install"), 'success') def test_enable(self): print 'test_enable' global settings settings = { "disabled" : "false", "stop" : "false", "rebootAfterPatch" : "rebootifneed", "category" : "important", "installDuration" : "01:01", "oneoff" : "false", } with self.assertRaises(SystemExit) as cm: enable() self.assertEqual(cm.exception.code, 0) self.assertEqual(get_status("Enable"), 'success') download_cmd = 'python test_handler_1.py -download' patch_cmd = 'python test_handler_1.py -patch' crontab_content = waagent.GetFileContents('/etc/crontab') self.assertTrue(download_cmd in crontab_content) self.assertTrue(patch_cmd in crontab_content) def test_disable(self): print 'test_disable' global settings settings = {} with self.assertRaises(SystemExit) as cm: disable() self.assertEqual(cm.exception.code, 0) self.assertEqual(get_status("Disable"), 'success') download_cmd = 'python test_handler_1.py -download' patch_cmd = 'python test_handler_1.py -patch' crontab_content = waagent.GetFileContents('/etc/crontab') self.assertTrue(download_cmd not in crontab_content) self.assertTrue(patch_cmd not in crontab_content) def test_cron(self): print 'test_cron' global settings settings = {} enable_time = time.time() settings['startTime'] = time.strftime('%H:%M', time.localtime(enable_time + 180)) delta_time = int(time.strftime('%S', time.localtime(enable_time + 120))) MyPatching.download_duration = 60 with self.assertRaises(SystemExit) as cm: enable() self.assertEqual(cm.exception.code, 0) self.assertEqual(get_status("Enable"), 'success') download_cmd = " ".join(["python", handlerName, "-download"]) patch_cmd = " ".join(["python", handlerName, "-patch"]) crontab_content = waagent.GetFileContents('/etc/crontab') self.assertTrue(download_cmd in 
crontab_content) self.assertTrue(patch_cmd in crontab_content) time.sleep(180 + 5) distro = DistInfo()[0] if 'SuSE' in distro: find_cron = 'grep CRON /var/log/messages' elif 'Ubuntu' in distro: find_cron = 'grep CRON /var/log/syslog' else: find_cron = 'cat /var/log/cron' day = int(time.strftime('%d', time.localtime(enable_time))) find_download_time = "grep '" + str(day) + time.strftime(' %H:%M', time.localtime(enable_time + 120)) + "'" find_patch_time = "grep '" + str(day) + time.strftime(' %H:%M', time.localtime(enable_time + 180)) + "'" find_download = "grep '" + download_cmd + "'" find_patch = "grep '" + patch_cmd + "'" retcode, output = waagent.RunGetOutput(find_cron + ' | ' + find_download_time + ' | ' + find_download) self.assertTrue(output) retcode, output = waagent.RunGetOutput(find_cron + ' | ' + find_patch_time + ' | ' + find_patch) self.assertTrue(output) def test_download(self): """ Check file package.downloaded after download """ print 'test_download' global settings settings = { "category" : "importantandrecommended", } with self.assertRaises(SystemExit) as cm: download() self.assertEqual(cm.exception.code, 0) download_content = waagent.GetFileContents(MyPatching.package_downloaded_path) security_download_list = get_patch_list(MyPatching.package_downloaded_path, 'important') self.assertTrue(set(security_download_list) == set(MyPatching.security_download_list)) all_download_list = get_patch_list(MyPatching.package_downloaded_path) self.assertTrue(set(all_download_list) == set(MyPatching.all_download_list)) def test_download_security(self): """ check file package.downloaded after download """ print 'test_download_security' global settings settings = { "category" : "important", } with self.assertRaises(SystemExit) as cm: download() self.assertEqual(cm.exception.code, 0) security_download_list = get_patch_list(MyPatching.package_downloaded_path, 'important') self.assertTrue(set(security_download_list) == set(MyPatching.security_download_list)) 
all_download_list = get_patch_list(MyPatching.package_downloaded_path) self.assertTrue(set(all_download_list) == set(MyPatching.security_download_list)) def test_patch(self): ''' check file package.patched when patch successful ''' print 'test_patch' global settings settings = {} with self.assertRaises(SystemExit) as cm: download() self.assertEqual(cm.exception.code, 0) with self.assertRaises(SystemExit) as cm: patch() self.assertEqual(cm.exception.code, 0) download_content = waagent.GetFileContents(MyPatching.package_downloaded_path) patch_content = waagent.GetFileContents(MyPatching.package_patched_path) self.assertEqual(download_content, patch_content) def test_patch_failed(self): ''' check file package.patched when patch fail ''' print 'test_patch_failed' global settings settings = {} def patch_package(self): return 1 MyPatching.patch_package = patch_package old_log_len = len(waagent.GetFileContents(log_file)) with self.assertRaises(SystemExit) as cm: download() self.assertEqual(cm.exception.code, 0) with self.assertRaises(SystemExit) as cm: patch() log_contents = waagent.GetFileContents(log_file)[old_log_len:] self.assertEqual(cm.exception.code, 0) patch_content = waagent.GetFileContents(MyPatching.package_patched_path) self.assertFalse(patch_content) self.assertTrue('Failed to patch the package' in log_contents) def test_patch_one_off(self): ''' check package.downloaded and package.patched when patch_one_off successful ''' print 'test_patch_one_off' global settings settings = { "oneoff" : "true", "category" : "importantandrecommended" } with self.assertRaises(SystemExit) as cm: oneoff() self.assertEqual(cm.exception.code, 0) self.assertEqual(get_status("Enable"), 'success') time.sleep(3) security_download_list = get_patch_list(MyPatching.package_downloaded_path, 'important') self.assertTrue(set(security_download_list) == set(MyPatching.security_download_list)) all_download_list = get_patch_list(MyPatching.package_patched_path) 
self.assertTrue(set(all_download_list) == set(MyPatching.all_download_list)) download_content = waagent.GetFileContents(MyPatching.package_downloaded_path) patch_content = waagent.GetFileContents(MyPatching.package_patched_path) self.assertEqual(patch_content, download_content) def test_patch_time_exceed(self): ''' check package.patched when patch time exceed ''' print 'test_patch_time_exceed' global settings settings = { "category" : "importantandrecommended", "installDuration" : "00:06" # 5 minutes reserved for reboot } old_log_len = len(waagent.GetFileContents(log_file)) def patch_package(self): time.sleep(11) return 0 MyPatching.patch_package = patch_package with self.assertRaises(SystemExit) as cm: download() self.assertEqual(cm.exception.code, 0) with self.assertRaises(SystemExit) as cm: patch() self.assertEqual(cm.exception.code, 0) patch_list = get_patch_list(MyPatching.package_patched_path) self.assertEqual(patch_list, ['a', 'b', 'c', 'd', 'e', '1']) log_contents = waagent.GetFileContents(log_file)[old_log_len:] self.assertTrue('Patching time exceeded' in log_contents) def get_patch_list(file_path, category = None): content = waagent.GetFileContents(file_path) if category != None: result = [line.split()[0] for line in content.split('\n') if line.endswith(category)] else: result = [line.split()[0] for line in content.split('\n') if ' ' in line] return result def get_status(operation, retkey='status'): contents = waagent.GetFileContents(status_file) status = json.loads(contents)[0]['status'] if status['operation'] == operation: return status[retkey] return '' def change_settings(key, value): with open(settings_file, "r") as f: settings_string = f.read() settings = json.loads(settings_string) with open(settings_file, "w") as f: settings[key] = value settings_string = json.dumps(settings) f.write(settings_string) return settings def main(): if len(sys.argv) == 1: unittest.main() return waagent.LoggerInit('/var/log/waagent.log', '/dev/stdout') waagent.Log("%s 
started to handle." %(ExtensionShortName)) global hutil hutil = Util.HandlerUtility(waagent.Log, waagent.Error) hutil.do_parse_context('TEST') global MyPatching MyPatching = FakePatching(hutil) if MyPatching == None: sys.exit(1) for a in sys.argv[1:]: if re.match("^([-/]*)(disable)", a): disable() elif re.match("^([-/]*)(uninstall)", a): uninstall() elif re.match("^([-/]*)(install)", a): install() elif re.match("^([-/]*)(enable)", a): enable() elif re.match("^([-/]*)(update)", a): update() elif re.match("^([-/]*)(download)", a): download() elif re.match("^([-/]*)(patch)", a): patch() elif re.match("^([-/]*)(oneoff)", a): oneoff() if __name__ == '__main__': main() ================================================ FILE: OSPatching/test/test_handler_2.py ================================================ #!/usr/bin/python # # OSPatching extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import chardet import json import logging import os import re import shutil import sys import tempfile import time import traceback import urllib2 import urlparse import unittest from azure.storage import BlobService import Utils.HandlerUtil as Util from patch import * from FakePatching2 import FakePatching from Utils.WAAgentUtil import waagent sys.path.append('..') # Global variables definition ExtensionShortName = 'OSPatching' DownloadDirectory = 'download' idleTestScriptName = "idleTest.py" healthyTestScriptName = "healthyTest.py" handlerName = os.path.basename(sys.argv[0]) status_file = './status/0.status' log_file = './extension.log' settings_file = "default.settings" with open(settings_file, "r") as f: settings_string = f.read() settings = json.loads(settings_string) def install(): hutil.do_parse_context('Install') try: MyPatching.install() hutil.do_exit(0, 'Install', 'success', '0', 'Install Succeeded.') except Exception as e: hutil.log_and_syslog(logging.ERROR, "Failed to install the extension with error: %s, stack trace: %s" %(str(e), traceback.format_exc())) hutil.do_exit(1, 'Install', 'error', '0', 'Install Failed.') def enable(): hutil.do_parse_context('Enable') try: MyPatching.parse_settings(settings) # Ensure the same configuration is executed only once hutil.exit_if_seq_smaller() oneoff = settings.get("oneoff") download_customized_vmstatustest() copy_vmstatustestscript(hutil.get_seq_no(), oneoff) MyPatching.enable() current_config = MyPatching.get_current_config() hutil.do_exit(0, 'Enable', 'success', '0', 'Enable Succeeded. Current Configuration: ' + current_config) except Exception as e: current_config = MyPatching.get_current_config() hutil.log_and_syslog(logging.ERROR, "Failed to enable the extension with error: %s, stack trace: %s" %(str(e), traceback.format_exc())) hutil.do_exit(1, 'Enable', 'error', '0', 'Enable Failed. 
Current Configuation: ' + current_config) def uninstall(): hutil.do_parse_context('Uninstall') hutil.do_exit(0, 'Uninstall', 'success', '0', 'Uninstall Succeeded.') def disable(): hutil.do_parse_context('Disable') try: # Ensure the same configuration is executed only once hutil.exit_if_seq_smaller() MyPatching.disable() hutil.do_exit(0, 'Disable', 'success', '0', 'Disable Succeeded.') except Exception as e: hutil.log_and_syslog(logging.ERROR, "Failed to disable the extension with error: %s, stack trace: %s" %(str(e), traceback.format_exc())) hutil.do_exit(1, 'Disable', 'error', '0', 'Disable Failed.') def update(): hutil.do_parse_context('Upadate') hutil.do_exit(0, 'Update', 'success', '0', 'Update Succeeded.') def download(): hutil.do_parse_context('Download') try: MyPatching.parse_settings(settings) MyPatching.download() current_config = MyPatching.get_current_config() hutil.do_exit(0,'Enable','success','0', 'Download Succeeded. Current Configuation: ' + current_config) except Exception as e: current_config = MyPatching.get_current_config() hutil.log_and_syslog(logging.ERROR, "Failed to download updates with error: %s, stack trace: %s" %(str(e), traceback.format_exc())) hutil.do_exit(1, 'Enable','error','0', 'Download Failed. Current Configuation: ' + current_config) def patch(): hutil.do_parse_context('Patch') try: MyPatching.parse_settings(settings) MyPatching.patch() current_config = MyPatching.get_current_config() hutil.do_exit(0,'Enable','success','0', 'Patch Succeeded. Current Configuation: ' + current_config) except Exception as e: current_config = MyPatching.get_current_config() hutil.log_and_syslog(logging.ERROR, "Failed to patch with error: %s, stack trace: %s" %(str(e), traceback.format_exc())) hutil.do_exit(1, 'Enable','error','0', 'Patch Failed. 
Current Configuation: ' + current_config) def oneoff(): hutil.do_parse_context('Oneoff') try: MyPatching.parse_settings(settings) MyPatching.patch_one_off() current_config = MyPatching.get_current_config() hutil.do_exit(0,'Enable','success','0', 'Oneoff Patch Succeeded. Current Configuation: ' + current_config) except Exception as e: current_config = MyPatching.get_current_config() hutil.log_and_syslog(logging.ERROR, "Failed to one-off patch with error: %s, stack trace: %s" %(str(e), traceback.format_exc())) hutil.do_exit(1, 'Enable','error','0', 'Oneoff Patch Failed. Current Configuation: ' + current_config) def download_files(hutil): local = settings.get("vmStatusTest", dict()).get("local", "") if local.lower() == "true": local = True elif local.lower() == "false": local = False else: hutil.log_and_syslog(logging.WARNING, "The parameter \"local\" " "is empty or invalid. Set it as False. Continue...") local = False idle_test_script = settings.get("vmStatusTest", dict()).get('idleTestScript') healthy_test_script = settings.get("vmStatusTest", dict()).get('healthyTestScript') if (not idle_test_script and not healthy_test_script): hutil.log_and_syslog(logging.WARNING, "The parameter \"idleTestScript\" and \"healthyTestScript\" " "are both empty. Exit downloading VMStatusTest scripts...") return elif local: if (idle_test_script and idle_test_script.startswith("http")) or \ (healthy_test_script and healthy_test_script.startswith("http")): hutil.log_and_syslog(logging.WARNING, "The parameter \"idleTestScript\" or \"healthyTestScript\" " "should not be uri. Exit downloading VMStatusTest scripts...") return elif not local: if (idle_test_script and not idle_test_script.startswith("http")) or \ (healthy_test_script and not healthy_test_script.startswith("http")): hutil.log_and_syslog(logging.WARNING, "The parameter \"idleTestScript\" or \"healthyTestScript\" " "should be uri. 
Exit downloading VMStatusTest scripts...") return hutil.do_status_report('Downloading','transitioning', '0', 'Downloading VMStatusTest scripts...') vmStatusTestScripts = dict() vmStatusTestScripts[idle_test_script] = idleTestScriptName vmStatusTestScripts[healthy_test_script] = healthyTestScriptName if local: hutil.log_and_syslog(logging.INFO, "Saving VMStatusTest scripts from user's configurations...") for src,dst in vmStatusTestScripts.items(): if not src: continue file_path = save_local_file(src, dst, hutil) preprocess_files(file_path, hutil) return storage_account_name = None storage_account_key = None if settings: storage_account_name = settings.get("storageAccountName", "").strip() storage_account_key = settings.get("storageAccountKey", "").strip() if storage_account_name and storage_account_key: hutil.log_and_syslog(logging.INFO, "Downloading VMStatusTest scripts from azure storage...") for src,dst in vmStatusTestScripts.items(): if not src: continue file_path = download_blob(storage_account_name, storage_account_key, src, dst, hutil) preprocess_files(file_path, hutil) elif not(storage_account_name or storage_account_key): hutil.log_and_syslog(logging.INFO, "No azure storage account and key specified in protected " "settings. 
Downloading VMStatusTest scripts from external links...") for src,dst in vmStatusTestScripts.items(): if not src: continue file_path = download_external_file(src, dst, hutil) preprocess_files(file_path, hutil) else: #Storage account and key should appear in pairs error_msg = "Azure storage account or storage key is not provided" hutil.log_and_syslog(logging.ERROR, error_msg) raise ValueError(error_msg) def download_blob(storage_account_name, storage_account_key, blob_uri, dst, hutil): seqNo = hutil.get_seq_no() container_name = get_container_name_from_uri(blob_uri) blob_name = get_blob_name_from_uri(blob_uri) download_dir = prepare_download_dir(seqNo) download_path = os.path.join(download_dir, dst) #Guest agent already ensure the plugin is enabled one after another. #The blob download will not conflict. blob_service = BlobService(storage_account_name, storage_account_key) try: blob_service.get_blob_to_path(container_name, blob_name, download_path) except Exception as e: hutil.log_and_syslog(logging.ERROR, ("Failed to download blob with uri:{0} " "with error {1}").format(blob_uri,e)) raise return download_path def download_external_file(uri, dst, hutil): seqNo = hutil.get_seq_no() download_dir = prepare_download_dir(seqNo) file_path = os.path.join(download_dir, dst) try: download_and_save_file(uri, file_path) except Exception as e: hutil.log_and_syslog(logging.ERROR, ("Failed to download external file with uri:{0} " "with error {1}").format(uri, e)) raise return file_path def save_local_file(src, dst, hutil): seqNo = hutil.get_seq_no() download_dir = prepare_download_dir(seqNo) file_path = os.path.join(download_dir, dst) try: waagent.SetFileContents(file_path, src) except Exception as e: hutil.log_and_syslog(logging.ERROR, ("Failed to save file from user's configuration " "with error {0}").format(e)) raise return file_path def preprocess_files(file_path, hutil): """ Preprocess the text file. If it is a binary file, skip it. 
""" is_text, code_type = is_text_file(file_path) if is_text: dos2unix(file_path) hutil.log_and_syslog(logging.INFO, "Converting text files from DOS to Unix formats: Done") if code_type in ['UTF-8', 'UTF-16LE', 'UTF-16BE']: remove_bom(file_path) hutil.log_and_syslog(logging.INFO, "Removing BOM: Done") def is_text_file(file_path): with open(file_path, 'rb') as f: contents = f.read(512) return is_text(contents) def is_text(contents): supported_encoding = ['ascii', 'UTF-8', 'UTF-16LE', 'UTF-16BE'] code_type = chardet.detect(contents)['encoding'] if code_type in supported_encoding: return True, code_type else: return False, code_type def dos2unix(file_path): temp_file_path = tempfile.mkstemp()[1] f_temp = open(temp_file_path, 'wb') with open(file_path, 'rU') as f: contents = f.read() f_temp.write(contents) f_temp.close() shutil.move(temp_file_path, file_path) def remove_bom(file_path): temp_file_path = tempfile.mkstemp()[1] f_temp = open(temp_file_path, 'wb') with open(file_path, 'rb') as f: contents = f.read() for encoding in ["utf-8-sig", "utf-16"]: try: f_temp.write(contents.decode(encoding).encode('utf-8')) break except UnicodeDecodeError: continue f_temp.close() shutil.move(temp_file_path, file_path) def download_and_save_file(uri, file_path): src = urllib2.urlopen(uri) dest = open(file_path, 'wb') buf_size = 1024 buf = src.read(buf_size) while(buf): dest.write(buf) buf = src.read(buf_size) def prepare_download_dir(seqNo): download_dir_main = os.path.join(os.getcwd(), DownloadDirectory) create_directory_if_not_exists(download_dir_main) download_dir = os.path.join(download_dir_main, seqNo) create_directory_if_not_exists(download_dir) return download_dir def create_directory_if_not_exists(directory): """create directory if no exists""" if not os.path.exists(directory): os.makedirs(directory) def get_path_from_uri(uriStr): uri = urlparse.urlparse(uriStr) return uri.path def get_blob_name_from_uri(uri): return get_properties_from_uri(uri)['blob_name'] def 
get_container_name_from_uri(uri): return get_properties_from_uri(uri)['container_name'] def get_properties_from_uri(uri): path = get_path_from_uri(uri) if path.endswith('/'): path = path[:-1] if path[0] == '/': path = path[1:] first_sep = path.find('/') if first_sep == -1: hutil.log_and_syslog(logging.ERROR, "Failed to extract container, blob, from {}".format(path)) blob_name = path[first_sep+1:] container_name = path[:first_sep] return {'blob_name': blob_name, 'container_name': container_name} def download_customized_vmstatustest(): download_dir = prepare_download_dir(hutil.get_seq_no()) maxRetry = 2 for retry in range(0, maxRetry + 1): try: download_files(hutil) break except Exception: hutil.log_and_syslog(logging.ERROR, "Failed to download files, retry=" + str(retry) + ", maxRetry=" + str(maxRetry)) if retry != maxRetry: hutil.log_and_syslog(logging.INFO, "Sleep 10 seconds") time.sleep(10) else: raise def copy_vmstatustestscript(seqNo, oneoff): src_dir = prepare_download_dir(seqNo) for filename in (idleTestScriptName, healthyTestScriptName): src = os.path.join(src_dir, filename) if os.path.isfile(src): if oneoff is not None and oneoff.lower() == "true": dst = "oneoff" else: dst = "scheduled" dst = os.path.join(os.getcwd(), dst) shutil.copy(src, dst) def delete_current_vmstatustestscript(): for filename in (idleTestScriptName, healthyTestScriptName): current_vmstatustestscript = os.path.join(os.getcwd(), "patch/"+filename) if os.path.isfile(current_vmstatustestscript): os.remove(current_vmstatustestscript) class Test(unittest.TestCase): def setUp(self): print('\n\n============================================================================================') waagent.LoggerInit('/var/log/waagent.log', '/dev/stdout') waagent.Log("%s started to handle." 
%(ExtensionShortName)) global hutil hutil = Util.HandlerUtility(waagent.Log, waagent.Error) hutil.do_parse_context('TEST') global MyPatching MyPatching = FakePatching(hutil) if MyPatching is None: sys.exit(1) distro = DistInfo()[0] if 'centos' in distro or 'Oracle' in distro or 'redhat' in distro: MyPatching.cron_restart_cmd = 'service crond restart' try: os.remove('mrseq') except: pass waagent.SetFileContents(MyPatching.package_downloaded_path, '') waagent.SetFileContents(MyPatching.package_patched_path, '') def test_download_time_exceed(self): ''' check package.downloaded and package.patched ''' print('test_download_time_exceed') global settings current_time = time.time() settings = change_settings("startTime", time.strftime('%H:%M', time.localtime(current_time + 180))) settings = change_settings("category", "importantandrecommended") old_log_len = len(waagent.GetFileContents(log_file)) with self.assertRaises(SystemExit) as cm: enable() self.assertEqual(cm.exception.code, 0) time.sleep(180 + 10) all_download_list = get_patch_list(MyPatching.package_downloaded_path) self.assertTrue(set(all_download_list) == set(['a', 'b', 'c', 'd', 'e'])) # Check extension.log log_contents = waagent.GetFileContents(log_file)[old_log_len:] self.assertTrue('Download time exceeded' in log_contents) restore_settings() def test_stop_before_download(self): ''' check stop flag before download and after patch ''' print('test_stop_before_download') global settings current_time = time.time() settings = change_settings("startTime", time.strftime('%H:%M', time.localtime(current_time + 180))) settings = change_settings("category", "importantandrecommended") old_log_len = len(waagent.GetFileContents(log_file)) with self.assertRaises(SystemExit) as cm: enable() self.assertEqual(cm.exception.code, 0) os.remove('mrseq') settings = change_settings("stop", "true") with self.assertRaises(SystemExit) as cm: enable() self.assertEqual(cm.exception.code, 0) self.assertTrue(MyPatching.exists_stop_flag()) 
time.sleep(180 + 5 + 60) self.assertFalse(MyPatching.exists_stop_flag()) self.assertFalse(waagent.GetFileContents(MyPatching.package_downloaded_path)) self.assertFalse(waagent.GetFileContents(MyPatching.package_patched_path)) log_contents = waagent.GetFileContents(log_file)[old_log_len:] self.assertTrue('Downloading patches is stopped/canceled' in log_contents) restore_settings() def test_stop_while_download(self): print('test_stop_while_download') global settings current_time = time.time() settings = change_settings("startTime", time.strftime('%H:%M', time.localtime(current_time + 180))) settings = change_settings("category", "importantandrecommended") old_log_len = len(waagent.GetFileContents(log_file)) delta_time = int(time.strftime('%S', time.localtime(current_time + 120))) with self.assertRaises(SystemExit) as cm: enable() self.assertEqual(cm.exception.code, 0) # set stop flag after downloaded 40 seconds time.sleep(160 - delta_time) os.remove('mrseq') settings = change_settings("stop", "true") with self.assertRaises(SystemExit) as cm: enable() self.assertEqual(cm.exception.code, 0) self.assertTrue(MyPatching.exists_stop_flag()) # Make sure the total sleep time is greater than 180s time.sleep(20 + delta_time + 5) self.assertFalse(MyPatching.exists_stop_flag()) download_list = get_patch_list(MyPatching.package_downloaded_path) self.assertEqual(download_list, ['a', 'b', 'c']) self.assertFalse(waagent.GetFileContents(MyPatching.package_patched_path)) # Check extension.log log_contents = waagent.GetFileContents(log_file)[old_log_len:] self.assertTrue('Installing patches is stopped/canceled' in log_contents) restore_settings() def get_patch_list(file_path, category = None): content = waagent.GetFileContents(file_path) if category != None: result = [line.split()[0] for line in content.split('\n') if line.endswith(category)] else: result = [line.split()[0] for line in content.split('\n') if ' ' in line] return result def get_status(operation, retkey='status'): 
contents = waagent.GetFileContents(status_file) status = json.loads(contents)[0]['status'] if status['operation'] == operation: return status[retkey] return '' def change_settings(key, value): with open(settings_file, "r") as f: settings_string = f.read() settings = json.loads(settings_string) with open(settings_file, "w") as f: settings[key] = value settings_string = json.dumps(settings) f.write(settings_string) return settings def restore_settings(): idleTestScriptLocal = """#!/usr/bin/python # Locally. def is_vm_idle(): return True """ healthyTestScriptLocal = """#!/usr/bin/python # Locally. def is_vm_healthy(): return True """ settings = { "disabled" : "false", "stop" : "false", "rebootAfterPatch" : "rebootifneed", "category" : "important", "installDuration" : "00:30", "oneoff" : "false", "intervalOfWeeks" : "1", "dayOfWeek" : "everyday", "startTime" : "03:00", "vmStatusTest" : { "local" : "true", "idleTestScript" : idleTestScriptLocal, #idleTestScriptStorage, "healthyTestScript" : healthyTestScriptLocal, #healthyTestScriptStorage }, "storageAccountName" : "<TOCHANGE>", "storageAccountKey" : "<TOCHANGE>" } settings_string = json.dumps(settings) settings_file = "default.settings" with open(settings_file, "w") as f: f.write(settings_string) def main(): if len(sys.argv) == 1: unittest.main() return waagent.LoggerInit('/var/log/waagent.log', '/dev/stdout') waagent.Log("%s started to handle." 
% ExtensionShortName) global hutil hutil = Util.HandlerUtility(waagent.Log, waagent.Error, ExtensionShortName) hutil.do_parse_context('TEST') global MyPatching MyPatching = FakePatching(hutil) if MyPatching is None: sys.exit(1) for a in sys.argv[1:]: if re.match("^([-/]*)(disable)", a): disable() elif re.match("^([-/]*)(uninstall)", a): uninstall() elif re.match("^([-/]*)(install)", a): install() elif re.match("^([-/]*)(enable)", a): enable() elif re.match("^([-/]*)(update)", a): update() elif re.match("^([-/]*)(download)", a): download() elif re.match("^([-/]*)(patch)", a): patch() elif re.match("^([-/]*)(oneoff)", a): oneoff() if __name__ == '__main__': main() ================================================ FILE: OSPatching/test/test_handler_3.py ================================================ #!/usr/bin/python # # OSPatching extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os
import sys
import re
import time
import chardet
import tempfile
import urllib2
import urlparse
import shutil
import traceback
import logging

from azure.storage import BlobService
from Utils.WAAgentUtil import waagent
import Utils.HandlerUtil as Util
import json
import unittest

from patch import *
from FakePatching3 import FakePatching

# NOTE(review): sys.path is extended *after* the imports above —
# presumably the tests are launched from a directory where the
# project modules already resolve; confirm before reordering.
sys.path.append('..')

# Global variables definition
ExtensionShortName = 'OSPatching'
DownloadDirectory = 'download'
idleTestScriptName = "idleTest.py"
healthyTestScriptName = "healthyTest.py"
handlerName = os.path.basename(sys.argv[0])
status_file = './status/0.status'
log_file = './extension.log'
settings_file = "default.settings"

# Settings are read once at import time; tests rewrite the file and
# reassign the module-level `settings` via change_settings().
with open(settings_file, "r") as f:
    settings_string = f.read()
settings = json.loads(settings_string)


def install():
    """Handler entry point: install the patching extension."""
    hutil.do_parse_context('Install')
    try:
        MyPatching.install()
        hutil.do_exit(0, 'Install', 'success', '0', 'Install Succeeded.')
    except Exception as e:
        hutil.log_and_syslog(logging.ERROR,
                             "Failed to install the extension with error: %s, "
                             "stack trace: %s" % (str(e), traceback.format_exc()))
        hutil.do_exit(1, 'Install', 'error', '0', 'Install Failed.')


def enable():
    """Handler entry point: parse settings, stage the VM status-test
    scripts and enable scheduled patching."""
    hutil.do_parse_context('Enable')
    try:
        MyPatching.parse_settings(settings)
        # Ensure the same configuration is executed only once
        hutil.exit_if_seq_smaller()
        oneoff = settings.get("oneoff")
        download_customized_vmstatustest()
        copy_vmstatustestscript(hutil.get_seq_no(), oneoff)
        MyPatching.enable()
        current_config = MyPatching.get_current_config()
        hutil.do_exit(0, 'Enable', 'success', '0',
                      'Enable Succeeded. Current Configuration: ' + current_config)
    except Exception as e:
        current_config = MyPatching.get_current_config()
        hutil.log_and_syslog(logging.ERROR,
                             "Failed to enable the extension with error: %s, "
                             "stack trace: %s" % (str(e), traceback.format_exc()))
        # Fix: "Configuation" misspelling corrected to match the
        # success message above.
        hutil.do_exit(1, 'Enable', 'error', '0',
                      'Enable Failed. Current Configuration: ' + current_config)


def uninstall():
    """Handler entry point: uninstall (no-op beyond status report)."""
    hutil.do_parse_context('Uninstall')
    hutil.do_exit(0, 'Uninstall', 'success', '0', 'Uninstall Succeeded.')


def disable():
    """Handler entry point: disable scheduled patching."""
    hutil.do_parse_context('Disable')
    try:
        # Ensure the same configuration is executed only once
        hutil.exit_if_seq_smaller()
        MyPatching.disable()
        hutil.do_exit(0, 'Disable', 'success', '0', 'Disable Succeeded.')
    except Exception as e:
        hutil.log_and_syslog(logging.ERROR,
                             "Failed to disable the extension with error: %s, "
                             "stack trace: %s" % (str(e), traceback.format_exc()))
        hutil.do_exit(1, 'Disable', 'error', '0', 'Disable Failed.')


def update():
    """Handler entry point: update (no-op beyond status report)."""
    # Fix: the context label was misspelled 'Upadate'.
    hutil.do_parse_context('Update')
    hutil.do_exit(0, 'Update', 'success', '0', 'Update Succeeded.')


def download():
    """Handler entry point: download patches only."""
    hutil.do_parse_context('Download')
    try:
        MyPatching.parse_settings(settings)
        MyPatching.download()
        current_config = MyPatching.get_current_config()
        # NOTE(review): the operation is reported as 'Enable', not
        # 'Download' — presumably so status surfaces under the enable
        # operation in Azure; confirm before changing.
        hutil.do_exit(0, 'Enable', 'success', '0',
                      'Download Succeeded. Current Configuration: ' + current_config)
    except Exception as e:
        current_config = MyPatching.get_current_config()
        hutil.log_and_syslog(logging.ERROR,
                             "Failed to download updates with error: %s, "
                             "stack trace: %s" % (str(e), traceback.format_exc()))
        hutil.do_exit(1, 'Enable', 'error', '0',
                      'Download Failed. Current Configuration: ' + current_config)
Current Configuation: ' + current_config) def oneoff(): hutil.do_parse_context('Oneoff') try: MyPatching.parse_settings(settings) MyPatching.patch_one_off() current_config = MyPatching.get_current_config() hutil.do_exit(0,'Enable','success','0', 'Oneoff Patch Succeeded. Current Configuation: ' + current_config) except Exception as e: current_config = MyPatching.get_current_config() hutil.log_and_syslog(logging.ERROR, "Failed to one-off patch with error: %s, stack trace: %s" %(str(e), traceback.format_exc())) hutil.do_exit(1, 'Enable','error','0', 'Oneoff Patch Failed. Current Configuation: ' + current_config) def download_files(hutil): local = settings.get("vmStatusTest", dict()).get("local", "") if local.lower() == "true": local = True elif local.lower() == "false": local = False else: hutil.log_and_syslog(logging.WARNING, "The parameter \"local\" " "is empty or invalid. Set it as False. Continue...") local = False idle_test_script = settings.get("vmStatusTest", dict()).get('idleTestScript') healthy_test_script = settings.get("vmStatusTest", dict()).get('healthyTestScript') if (not idle_test_script and not healthy_test_script): hutil.log_and_syslog(logging.WARNING, "The parameter \"idleTestScript\" and \"healthyTestScript\" " "are both empty. Exit downloading VMStatusTest scripts...") return elif local: if (idle_test_script and idle_test_script.startswith("http")) or \ (healthy_test_script and healthy_test_script.startswith("http")): hutil.log_and_syslog(logging.WARNING, "The parameter \"idleTestScript\" or \"healthyTestScript\" " "should not be uri. Exit downloading VMStatusTest scripts...") return elif not local: if (idle_test_script and not idle_test_script.startswith("http")) or \ (healthy_test_script and not healthy_test_script.startswith("http")): hutil.log_and_syslog(logging.WARNING, "The parameter \"idleTestScript\" or \"healthyTestScript\" " "should be uri. 
Exit downloading VMStatusTest scripts...") return hutil.do_status_report('Downloading','transitioning', '0', 'Downloading VMStatusTest scripts...') vmStatusTestScripts = dict() vmStatusTestScripts[idle_test_script] = idleTestScriptName vmStatusTestScripts[healthy_test_script] = healthyTestScriptName if local: hutil.log_and_syslog(logging.INFO, "Saving VMStatusTest scripts from user's configurations...") for src,dst in vmStatusTestScripts.items(): if not src: continue file_path = save_local_file(src, dst, hutil) preprocess_files(file_path, hutil) return storage_account_name = None storage_account_key = None if settings: storage_account_name = settings.get("storageAccountName", "").strip() storage_account_key = settings.get("storageAccountKey", "").strip() if storage_account_name and storage_account_key: hutil.log_and_syslog(logging.INFO, "Downloading VMStatusTest scripts from azure storage...") for src,dst in vmStatusTestScripts.items(): if not src: continue file_path = download_blob(storage_account_name, storage_account_key, src, dst, hutil) preprocess_files(file_path, hutil) elif not(storage_account_name or storage_account_key): hutil.log_and_syslog(logging.INFO, "No azure storage account and key specified in protected " "settings. 
def download_blob(storage_account_name, storage_account_key, blob_uri, dst, hutil):
    """Download a single blob from Azure storage into the per-sequence
    download directory and return the local file path.

    Re-raises any blob-service error after logging it.
    """
    seqNo = hutil.get_seq_no()
    container_name = get_container_name_from_uri(blob_uri)
    blob_name = get_blob_name_from_uri(blob_uri)
    download_dir = prepare_download_dir(seqNo)
    download_path = os.path.join(download_dir, dst)
    # Guest agent already ensures the plugin is enabled one after another.
    # The blob download will not conflict.
    blob_service = BlobService(storage_account_name, storage_account_key)
    try:
        blob_service.get_blob_to_path(container_name, blob_name, download_path)
    except Exception as e:
        hutil.log_and_syslog(logging.ERROR,
                             ("Failed to download blob with uri:{0} "
                              "with error {1}").format(blob_uri, e))
        raise
    return download_path


def download_external_file(uri, dst, hutil):
    """Download *uri* into the per-sequence download directory as *dst*
    and return the local file path. Re-raises download errors after
    logging them.
    """
    seqNo = hutil.get_seq_no()
    download_dir = prepare_download_dir(seqNo)
    file_path = os.path.join(download_dir, dst)
    try:
        download_and_save_file(uri, file_path)
    except Exception as e:
        hutil.log_and_syslog(logging.ERROR,
                             ("Failed to download external file with uri:{0} "
                              "with error {1}").format(uri, e))
        raise
    return file_path


def save_local_file(src, dst, hutil):
    """Write the inline script content *src* to *dst* inside the
    per-sequence download directory and return the file path."""
    seqNo = hutil.get_seq_no()
    download_dir = prepare_download_dir(seqNo)
    file_path = os.path.join(download_dir, dst)
    try:
        waagent.SetFileContents(file_path, src)
    except Exception as e:
        hutil.log_and_syslog(logging.ERROR,
                             ("Failed to save file from user's configuration "
                              "with error {0}").format(e))
        raise
    return file_path


def preprocess_files(file_path, hutil):
    """Preprocess the text file: convert DOS line endings to Unix and
    strip any BOM. If it is a binary file, skip it.
    """
    is_text, code_type = is_text_file(file_path)
    if is_text:
        dos2unix(file_path)
        hutil.log_and_syslog(logging.INFO,
                             "Converting text files from DOS to Unix formats: Done")
        if code_type in ['UTF-8', 'UTF-16LE', 'UTF-16BE']:
            remove_bom(file_path)
            hutil.log_and_syslog(logging.INFO, "Removing BOM: Done")


def is_text_file(file_path):
    """Sniff the first 512 bytes of *file_path* and classify them.

    Returns (is_text, detected_encoding).
    """
    with open(file_path, 'rb') as f:
        contents = f.read(512)
    return is_text(contents)


def is_text(contents):
    """Return (True, encoding) when chardet detects one of the supported
    text encodings in *contents*, else (False, encoding)."""
    supported_encoding = ['ascii', 'UTF-8', 'UTF-16LE', 'UTF-16BE']
    code_type = chardet.detect(contents)['encoding']
    if code_type in supported_encoding:
        return True, code_type
    else:
        return False, code_type


def dos2unix(file_path):
    """Rewrite *file_path* with Unix line endings.

    Mode 'rU' (universal newlines) converts CRLF to LF on read; the
    normalized content then replaces the original file.
    """
    # Fix: mkstemp returns an *open* fd which the original leaked;
    # close it and reopen the temp file by name.
    fd, temp_file_path = tempfile.mkstemp()
    os.close(fd)
    with open(file_path, 'rU') as f:
        contents = f.read()
    with open(temp_file_path, 'wb') as f_temp:
        f_temp.write(contents)
    shutil.move(temp_file_path, file_path)


def remove_bom(file_path):
    """Strip a UTF BOM from *file_path*, re-encoding the content as UTF-8.

    Fix: if neither BOM-aware codec could decode the content, the
    original replaced the file with an *empty* temp file; now the
    original bytes are kept unchanged in that case. Also closes the fd
    leaked by mkstemp.
    """
    fd, temp_file_path = tempfile.mkstemp()
    os.close(fd)
    with open(file_path, 'rb') as f:
        contents = f.read()
    output = contents  # fall back to the original bytes
    for encoding in ["utf-8-sig", "utf-16"]:
        try:
            output = contents.decode(encoding).encode('utf-8')
            break
        except UnicodeDecodeError:
            continue
    with open(temp_file_path, 'wb') as f_temp:
        f_temp.write(output)
    shutil.move(temp_file_path, file_path)


def download_and_save_file(uri, file_path):
    """Stream *uri* to *file_path* in 1 KiB chunks.

    Fix: the original closed neither handle, leaking the connection and
    leaving buffered data possibly unflushed when the caller read the
    file back.
    """
    src = urllib2.urlopen(uri)
    try:
        dest = open(file_path, 'wb')
        try:
            buf_size = 1024
            buf = src.read(buf_size)
            while buf:
                dest.write(buf)
                buf = src.read(buf_size)
        finally:
            dest.close()
    finally:
        src.close()


def prepare_download_dir(seqNo):
    """Create (if needed) and return <cwd>/download/<seqNo>."""
    download_dir_main = os.path.join(os.getcwd(), DownloadDirectory)
    create_directory_if_not_exists(download_dir_main)
    download_dir = os.path.join(download_dir_main, seqNo)
    create_directory_if_not_exists(download_dir)
    return download_dir


def create_directory_if_not_exists(directory):
    """Create *directory* (and its parents) when it does not exist."""
    if not os.path.exists(directory):
        os.makedirs(directory)


def get_path_from_uri(uriStr):
    """Return the path component of *uriStr*."""
    uri = urlparse.urlparse(uriStr)
    return uri.path


def get_blob_name_from_uri(uri):
    """Return the blob name encoded in a blob *uri*."""
    return get_properties_from_uri(uri)['blob_name']
get_container_name_from_uri(uri): return get_properties_from_uri(uri)['container_name'] def get_properties_from_uri(uri): path = get_path_from_uri(uri) if path.endswith('/'): path = path[:-1] if path[0] == '/': path = path[1:] first_sep = path.find('/') if first_sep == -1: hutil.log_and_syslog(logging.ERROR, "Failed to extract container, blob, from {}".format(path)) blob_name = path[first_sep+1:] container_name = path[:first_sep] return {'blob_name': blob_name, 'container_name': container_name} def download_customized_vmstatustest(): maxRetry = 2 for retry in range(0, maxRetry + 1): try: download_files(hutil) break except Exception: hutil.log_and_syslog(logging.ERROR, "Failed to download files, retry=" + str(retry) + ", maxRetry=" + str(maxRetry)) if retry != maxRetry: hutil.log_and_syslog(logging.INFO, "Sleep 10 seconds") time.sleep(10) else: raise def copy_vmstatustestscript(seqNo, oneoff): src_dir = prepare_download_dir(seqNo) for filename in (idleTestScriptName, healthyTestScriptName): src = os.path.join(src_dir, filename) if os.path.isfile(src): if oneoff is not None and oneoff.lower() == "true": dst = "oneoff" else: dst = "scheduled" dst = os.path.join(os.getcwd(), dst) shutil.copy(src, dst) def delete_current_vmstatustestscript(): for filename in (idleTestScriptName, healthyTestScriptName): current_vmstatustestscript = os.path.join(os.getcwd(), "patch/"+filename) if os.path.isfile(current_vmstatustestscript): os.remove(current_vmstatustestscript) class Test(unittest.TestCase): def setUp(self): print('\n\n============================================================================================') waagent.LoggerInit('/var/log/waagent.log', '/dev/stdout') waagent.Log("%s started to handle." 
%(ExtensionShortName)) global hutil hutil = Util.HandlerUtility(waagent.Log, waagent.Error, ExtensionShortName) hutil.do_parse_context('TEST') global MyPatching MyPatching = FakePatching(hutil) if MyPatching is None: sys.exit(1) distro = DistInfo()[0] if 'centos' in distro or 'Oracle' in distro or 'redhat' in distro: MyPatching.cron_restart_cmd = 'service crond restart' try: os.remove('mrseq') except: pass waagent.SetFileContents(MyPatching.package_downloaded_path, '') waagent.SetFileContents(MyPatching.package_patched_path, '') def test_stop_between_download_and_stage1(self): print('test_stop_between_download_and_stage1') global settings current_time = time.time() settings = change_settings("startTime", time.strftime('%H:%M', time.localtime(current_time + 180))) settings = change_settings("category", "importantandrecommended") old_log_len = len(waagent.GetFileContents(log_file)) delta_time = int(time.strftime('%S', time.localtime(current_time + 120))) with self.assertRaises(SystemExit) as cm: enable() self.assertEqual(cm.exception.code, 0) # set stop flag after downloaded 40 seconds time.sleep(160 - delta_time) os.remove('mrseq') settings = change_settings("stop", "true") with self.assertRaises(SystemExit) as cm: enable() self.assertEqual(cm.exception.code, 0) self.assertTrue(MyPatching.exists_stop_flag()) # Make sure the total sleep time is greater than 180s time.sleep(20 + delta_time + 5 + 60) self.assertFalse(MyPatching.exists_stop_flag()) download_list = get_patch_list(MyPatching.package_downloaded_path) self.assertEqual(download_list, ['a', 'b', 'c', 'd', 'e', '1', '2', '3', '4']) self.assertFalse(waagent.GetFileContents(MyPatching.package_patched_path)) log_contents = waagent.GetFileContents(log_file)[old_log_len:] self.assertTrue('Installing patches is stopped/canceled' in log_contents) restore_settings() def test_stop_between_stage1_and_stage2(self): print 'test_stop_between_stage1_and_stage2' global settings current_time = time.time() settings = 
def get_patch_list(file_path, category=None):
    """Parse a package.downloaded / package.patched style file.

    Each relevant line is '<name> ... <category>'; returns the package
    names, optionally restricted to lines ending in *category*.
    """
    content = waagent.GetFileContents(file_path)
    # Fix: identity comparison idiom ('is not None') instead of '!= None'.
    if category is not None:
        result = [line.split()[0] for line in content.split('\n')
                  if line.endswith(category)]
    else:
        result = [line.split()[0] for line in content.split('\n')
                  if ' ' in line]
    return result


def get_status(operation, retkey='status'):
    """Return *retkey* from the reported status when its operation
    matches *operation*, else an empty string."""
    contents = waagent.GetFileContents(status_file)
    status = json.loads(contents)[0]['status']
    if status['operation'] == operation:
        return status[retkey]
    return ''


def change_settings(key, value):
    """Set *key* to *value* in the settings file on disk and return the
    updated settings dict (callers rebind the module-level `settings`)."""
    with open(settings_file, "r") as f:
        settings_string = f.read()
    settings = json.loads(settings_string)
    with open(settings_file, "w") as f:
        settings[key] = value
        settings_string = json.dumps(settings)
        f.write(settings_string)
    return settings


def restore_settings():
    """Rewrite default.settings with the baseline test configuration."""
    idleTestScriptLocal = """#!/usr/bin/python
# Locally.
def is_vm_idle():
    return True
"""
    healthyTestScriptLocal = """#!/usr/bin/python
# Locally.
def is_vm_healthy():
    return True
"""
    settings = {
        "disabled": "false",
        "stop": "false",
        "rebootAfterPatch": "rebootifneed",
        "category": "important",
        "installDuration": "00:30",
        "oneoff": "false",
        "intervalOfWeeks": "1",
        "dayOfWeek": "everyday",
        "startTime": "03:00",
        "vmStatusTest": {
            "local": "true",
            "idleTestScript": idleTestScriptLocal,  # idleTestScriptStorage,
            "healthyTestScript": healthyTestScriptLocal,  # healthyTestScriptStorage
        },
        "storageAccountName": "<TOCHANGE>",
        "storageAccountKey": "<TOCHANGE>",
    }
    settings_string = json.dumps(settings)
    settings_file = "default.settings"
    with open(settings_file, "w") as f:
        f.write(settings_string)
% ExtensionShortName) global hutil hutil = Util.HandlerUtility(waagent.Log, waagent.Error, ExtensionShortName) hutil.do_parse_context('TEST') global MyPatching MyPatching = FakePatching(hutil) if MyPatching is None: sys.exit(1) for a in sys.argv[1:]: if re.match("^([-/]*)(disable)", a): disable() elif re.match("^([-/]*)(uninstall)", a): uninstall() elif re.match("^([-/]*)(install)", a): install() elif re.match("^([-/]*)(enable)", a): enable() elif re.match("^([-/]*)(update)", a): update() elif re.match("^([-/]*)(download)", a): download() elif re.match("^([-/]*)(patch)", a): patch() elif re.match("^([-/]*)(oneoff)", a): oneoff() if __name__ == '__main__': main() ================================================ FILE: OmsAgent/.gitignore ================================================ packages keys/keyring.gpg keys/keyring.gpg~ keys/.gnupg/ .vscode/ ext/future Utils/ waagent waagentc ================================================ FILE: OmsAgent/HandlerManifest.json ================================================ [ { "name": "OmsAgentForLinux", "version": "1.13.19", "handlerManifest": { "installCommand": "omsagent_shim.sh -install", "uninstallCommand": "omsagent_shim.sh -uninstall", "updateCommand": "omsagent_shim.sh -update", "enableCommand": "omsagent_shim.sh -enable", "disableCommand": "omsagent_shim.sh -disable", "rebootAfterInstall": false, "reportHeartbeat": false, "updateMode": "UpdateWithInstall", "continueOnUpdateFailure": "true" } } ] ================================================ FILE: OmsAgent/ImportGPGkey.sh ================================================ #!/bin/sh if [ -z "$1" ]; then echo "Usage:" echo " $0 PUBLIC_GPG_KEY" exit 1 fi if [ -z "$2" ]; then KEYRING_NAME="keyring.gpg" else KEYRING_NAME=$2 fi TARGET_DIR="$(dirname $1)" HOME=$TARGET_DIR gpg --no-default-keyring --keyring $TARGET_DIR/$KEYRING_NAME --import $1 RETVAL=$? 
# chown omsagent $TARGET_DIR/$KEYRING_NAME exit $RETVAL ================================================ FILE: OmsAgent/README.md ================================================ # [DEPRECATED] OmsAgent Extension > :warning: The Log Analytics agent has been **deprecated** and has no support as of **August 31, 2024.** If you use the Log Analytics agent to ingest data to Azure Monitor, [migrate now to the new Azure Monitor agent](https://docs.microsoft.com/en-us/azure/azure-monitor/agents/azure-monitor-agent-migration). > [See the latest version and extension-bundle mapping.](https://docs.microsoft.com/en-us/azure/virtual-machines/extensions/oms-linux#agent-and-vm-extension-version) You can read the User Guide below. * [Learn more: Azure Virtual Machine Extensions](https://azure.microsoft.com/en-us/documentation/articles/virtual-machines-extensions-features/) OmsAgent Extension can: * Install the omsagent * Onboard to a OMS workspace # User Guide ## 1. Configuration schema ### 1.1. Public configuration Schema for the public configuration file looks like this: * `workspaceId`: (required, string) the OMS workspace id to onboard to * `stopOnMultipleConnections`: (optional, true/false) warn and stop onboarding if the machine already has a workspace connection; defaults to false * `noDigest`: (optional, true/false) RPM manager skips verification of package or header digests when reading (same as running rpm --nodigest --nofiledigest) * `skipDockerProviderInstall`: (optional, true/false) if the value is true, then skips the installation of the docker provider; default value is false ```json { "workspaceId": "<workspace-id (guid)>", "stopOnMultipleConnections": true/false, "noDigest": true/false, "skipDockerProviderInstall": true/false } ``` ### 1.2. 
Protected configuration Schema for the protected configuration file looks like this: * `workspaceKey`: (required, string) the primary/secondary shared key of the workspace * `proxy`: (optional, string) the proxy connection string - of the form \[user:pass@\]host\[:port\] * `vmResourceId`: (optional, string) the full azure resource id of the vm - of the form /subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.Compute/virtualMachines/{vmName} for Resource Manager VMs and of the form /subscriptions/{subscriptionId}/resourceGroups/{vmName}/providers/Microsoft.ClassicCompute/virtualMachines/{vmName} for Classic VMs ```json { "workspaceKey": "<workspace-key>", "proxy": "<proxy-string>", "vmResourceId": "<vm-resource-id>" } ``` ## 2. Deploying the Extension to a VM You can deploy it using Azure CLI, Azure Powershell and ARM template. ### 2.1. Using [**Azure CLI**][azure-cli] Before deploying OmsAgent Extension, you should configure your `public.json` and `protected.json` (in section 1.1 and 1.2 above). #### 2.1.1 Resource Manager You can deploy the OmsAgent Extension by running: ``` az vm extension set \ --resource-group myResourceGroup \ --vm-name myVM \ --name OmsAgentForLinux \ --publisher Microsoft.EnterpriseCloud.Monitoring \ --version <version> --protected-settings '{"workspaceKey": "omskey"}' \ --settings '{"workspaceId": "omsid"}' ``` #### 2.1.2 Classic Classic mode is used to manage legacy resources created outside of Resource Manager, and requires the [classic cli][azure-cli-classic] to manage via the command line.
You need to enable Classic Mode (also called Azure Service Management mode) in the cli by running: ``` azure config mode asm ``` You can deploy the OmsAgent Extension by running: ``` azure vm extension set <vm-name> \ OmsAgentForLinux Microsoft.EnterpriseCloud.Monitoring <version> \ --public-config-path public.json \ --private-config-path protected.json ``` In the command above, you can change version with `'*'` to use latest version available, or `'1.*'` to get newest version that does not introduce non- breaking schema changes. To learn the latest version available, run: ``` azure vm extension list ``` ### 2.2. Using [**Azure Powershell**][azure-powershell] #### 2.2.1 Resource Manager You can login to your Azure account (Azure Resource Manager mode) by running: ```powershell Login-AzureRmAccount ``` Click [**HERE**](https://azure.microsoft.com/en-us/documentation/articles/powershell-azure-resource-manager/) to learn more about how to use Azure Powershell with Azure Resource Manager. You can deploy the OmsAgent Extension by running: ```powershell $RGName = '<resource-group-name>' $VmName = '<vm-name>' $Location = '<location>' $ExtensionName = 'OmsAgentForLinux' $Publisher = 'Microsoft.EnterpriseCloud.Monitoring' $Version = '<version>' $PublicConf = '{ "workspaceId": "<workspace id>", "stopOnMultipleConnections": true/false, "noDigest": true/false, "skipDockerProviderInstall": true/false }' $PrivateConf = '{ "workspaceKey": "<workspace key>", "proxy": "<proxy string>", "vmResourceId": "<vm resource id>" }' Set-AzureRmVMExtension -ResourceGroupName $RGName -VMName $VmName -Location $Location ` -Name $ExtensionName -Publisher $Publisher ` -ExtensionType $ExtensionName -TypeHandlerVersion $Version ` -Settingstring $PublicConf -ProtectedSettingString $PrivateConf ``` #### 2.2.2 Classic You can login to your Azure account (Azure Service Management mode) by running: ```powershell Add-AzureAccount ``` You can deploy the OmsAgent Extension by running: ```powershell $VmName 
= '<vm-name>' $vm = Get-AzureVM -ServiceName $VmName -Name $VmName $ExtensionName = 'OmsAgentForLinux' $Publisher = 'Microsoft.EnterpriseCloud.Monitoring' $Version = '<version>' $PublicConf = '{ "workspaceId": "<workspace id>", "stopOnMultipleConnections": true/false, "noDigest": true/false, "skipDockerProviderInstall": true/false }' $PrivateConf = '{ "workspaceKey": "<workspace key>", "proxy": "<proxy string>", "vmResourceId": "<vm resource id>" }' Set-AzureVMExtension -ExtensionName $ExtensionName -VM $vm ` -Publisher $Publisher -Version $Version ` -PrivateConfiguration $PrivateConf -PublicConfiguration $PublicConf | Update-AzureVM ``` ### 2.3. Using [**ARM Template**][arm-template] ```json { "type": "Microsoft.Compute/virtualMachines/extensions", "name": "<extension-deployment-name>", "apiVersion": "<api-version>", "location": "<location>", "dependsOn": [ "[concat('Microsoft.Compute/virtualMachines/', <vm-name>)]" ], "properties": { "publisher": "Microsoft.EnterpriseCloud.Monitoring", "type": "OmsAgentForLinux", "typeHandlerVersion": "1.4", "settings": { "workspaceId": "<workspace id>", "stopOnMultipleConnections": true/false, "noDigest": true/false, "skipDockerProviderInstall": true/false }, "protectedSettings": { "workspaceKey": "<workspace key>", "proxy": "<proxy string>", "vmResourceId": "<vm resource id>" } } } ``` ## 3. 
Scenarios ### 3.1 Onboard to OMS workspace ```json { "workspaceId": "MyWorkspaceId", "stopOnMultipleConnections": true, "noDigest": false, "skipDockerProviderInstall": true } ``` ```json { "workspaceKey": "MyWorkspaceKey", "proxy": "proxyuser:proxypassword@proxyserver:8080", "vmResourceId": "/subscriptions/c90fcea1-7cd5-4255-9e2e-25d627a2a259/resourceGroups/RGName/providers/Microsoft.Compute/virtualMachines/VMName" } ``` ## [Supported Linux Distributions](https://docs.microsoft.com/en-us/azure/azure-monitor/platform/log-analytics-agent#supported-linux-operating-systems) ## Troubleshooting * The status of the extension is reported back to Azure so that user can see the status on Azure Portal * All the execution output and errors generated by the extension are logged into the following directories - `/var/lib/waagent/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux-<version>/packages/`, `/opt/microsoft/omsagent/bin` and the tail of the output is logged into the log directory specified in HandlerEnvironment.json and reported back to Azure * The operation log of the extension is `/var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/<version>/extension.log` file. ### Common error codes and their meanings | Error Code | Meaning | Possible Action | | :---: | --- | --- | | 10 | VM is already connected to an OMS workspace | To connect the VM to the workspace specified in the extension schema, set stopOnMultipleConnections to false in public settings or remove this property. This VM gets billed once for each workspace it is connected to. | | 11 | Invalid config provided to the extension | Follow the preceding examples to set all property values necessary for deployment. | | 12 | The dpkg package manager is locked | Make sure all dpkg update operations on the machine have finished and retry. 
| | 20 | Enable called prematurely | [Update the Azure Linux Agent](https://docs.microsoft.com/en-us/azure/virtual-machines/linux/update-agent) to the latest available version. | | 40-44 | Issue with the Automatic Management scenario | Please contact support with the details from the /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/\<version\>/extension.log | | 51 | This extension is not supported on the VM's operating system | | | 52 | The extension failed due to a missing dependency | | | 53 | The extension failed due to missing or wrong configuration parameters | | | 55 | Cannot connect to the Microsoft Operations Management Suite service | Check that the system either has Internet access, or that a valid HTTP proxy has been provided. Additionally, check the correctness of the workspace ID. | Additional error codes and troubleshooting information can be found on the [OMS-Agent-for-Linux Troubleshooting Guide](https://github.com/Microsoft/OMS-Agent-for-Linux/blob/master/docs/Troubleshooting.md#). [azure-powershell]: https://azure.microsoft.com/en-us/documentation/articles/powershell-install-configure/ [azure-cli-classic]: https://docs.microsoft.com/en-us/cli/azure/install-classic-cli [azure-cli]: https://docs.microsoft.com/en-us/cli/azure/install-azure-cli [arm-template]: http://azure.microsoft.com/en-us/documentation/templates/ [arm-overview]: https://azure.microsoft.com/en-us/documentation/articles/resource-group-overview/ ================================================ FILE: OmsAgent/apply_version.sh ================================================ #!
/bin/bash source ./omsagent.version echo "OMS_EXTENSION_VERSION=$OMS_EXTENSION_VERSION" echo "OMS_SHELL_BUNDLE_VERSION=$OMS_SHELL_BUNDLE_VERSION" # updating HandlerManifest.json # check for "version": "1.12.5", sed -i "s/\"version\".*$/\"version\": \"$OMS_EXTENSION_VERSION\",/g" HandlerManifest.json # updating watcherutil.py # check OMSExtensionVersion = '1.12.5' sed -i "s/^OMSExtensionVersion = .*$/OMSExtensionVersion = '$OMS_EXTENSION_VERSION'/" watcherutil.py # updating omsagent.py # check BundleFileName = 'omsagent-0.0.0-0.universal.x64.sh' sed -i "s/^BundleFileName = .*$/BundleFileName = 'omsagent-$OMS_SHELL_BUNDLE_VERSION.universal.x64.sh'/" omsagent.py # updating manifest.xml # check <Version>...</Version> sed -i -e "s|<Version>[0-9a-z.]\{1,\}</Version>|<Version>$OMS_EXTENSION_VERSION</Version>|g" manifest.xml ================================================ FILE: OmsAgent/extension-test/README.md ================================================ # OMS Extension Automated Testing ## Requirements * If host machine is Windows: * Must active Windows Subsystem for Linux [WSL](https://docs.microsoft.com/en-us/windows/wsl/install-win10) * Create a ssh key using [ssh-keygen](https://help.github.com/articles/generating-a-new-ssh-key-and-adding-it-to-the-ssh-agent/) * [Azure CLI](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli?view=azure-cli-latest) * Putty PSCP * [Putty for Windows](https://www.putty.org/) * Putty tools for Linux: * For DPKG: 'sudo apt-get install putty-tools' * For RPM: 'sudo yum install putty-tools' * For SUSE: 'sudo zypper install putty-tools' * Python 2.7+ & [pip](https://pip.pypa.io/en/stable/installing/) * [Requests](http://docs.python-requests.org/en/master/), [ADAL](https://github.com/AzureAD/azure-activedirectory-library-for-python), [json2html](https://github.com/softvar/json2html), [rstr](https://pypi.org/project/rstr/) ```bash $ pip install requests adal json2html rstr ``` ## Images currently supported for testing: * CentOS 6 
and 7 * Oracle Linux 6 and 7 * Debian 8 and 9 * Ubuntu 14.04, 16.04, and 18.04 * Red Hat 6 and 7 * SUSE 12 ## Running Tests ### Prepare #### Resources 1. Create a resource group that will be used to store all test resources 2. Create an Azure Key Vault to store test secrets 3. Create a Log Analytics workspace where your test VMs will send data - From the workspace blade, navigate to Settings > Advanced Settings > and note the workspace Id and Key for later 4. Create a network security group, preferably in West US 2 - From the NSG blade, navigate to Settings > Inbound Security Rules > Add - Use the following settings - `Source` – IP Addresses - `Source IP Addresses/CIDR ranges` – the IP of your host machine - `Source port ranges` – * - `Destination` – Any - `Destination port ranges` – 22 - `Protocol` – Any or TCP - `Action` – Allow - `Priority` – Lowest possible number - `Name` – AllowSSH - Add 5. [Increase your VM quota](https://docs.microsoft.com/en-us/azure/azure-supportability/resource-manager-core-quotas-request) to 15 in the region you will specify below in parameters.json 6. [Optional] Register your own AAD app to allow end-to-end verification script to access Microsoft REST APIs - Azure Portal > Azure Active Directory > App Registrations (Preview) > New Registration - `Name` – A name of your choice, can be changed later - `Supported Account Types` – Accounts in this organizational directory only (Microsoft) - `Redirect URI (Optional)` – Leave blank - Register - Use Application (client) ID value displayed in app overview to replace `<app-id>` in parameters.json - In blade of new registration > Certificates & Secrets > New Client Secret - `Description` – A descriptive word or phrase of your choice - `Expires` – Never - Add - *Copy down the new client secret value!* Use this to replace `<app-secret>` in parameters.json #### Parameters 1. 
In your Azure Key Vault, manually upload secrets with the following name-value pairings: - `<tenant>` – your AAD tenant, visible in Azure Portal > Azure Active Directory > Properties > Directory ID - `<app-id>`, `<app-secret>` – verify_e2e service principal ID, secret (available in OneNote document, or use the values from the app you optionally registered in step 6 above) - `<subscription-id>` – ID of subscription that hosts your desired Log Analytics test workspace - `<tenant-id>` – ID of your Azure AD tenant - `<workspace-id>`, `<workspace-key>` – Log Analytics test workspace ID, key 2. In parameters.json, fill in the following: - `<resource group>`, `<location>` – resource group, region (e.g. westus2) in which you want your VMs created - `<username>`, `<password>` – the VM username and password (see [requirements](https://docs.microsoft.com/en-us/azure/virtual-machines/windows/faq#what-are-the-password-requirements-when-creating-a-vm)) - `<nsg resource group>` – resource group of your NSG - `<nsg>` – NSG name - `<size>` – Standard_B1ms - `<workspace>` – name of the workspace you created - `<key vault>` – name of the Key Vault you created - `<old version>` - specific version of the extension (define as empty "" if not using) #### Other 1. Allow the end-to-end verification script to read your workspace - Open workspace in Azure Portal - Access control (IAM) > Add - `Role` – Reader - `Assign access to` – Azure AD user, group, or application - `Select` – verify_e2e - Save 2. Log in to Azure using the Azure CLI and set your subscription ```bash $ az login $ az account set --subscription subscription_name ``` 3. 
Custom Log Setup: - [Custom logs Docs](https://docs.microsoft.com/en-us/azure/log-analytics/log-analytics-data-sources-custom-logs) - Add custom.log file to setup Custom_Log_CL ![AddingCustomlogFile](pictures/AddingCustomlogFile.png?raw=true) - Add location of the file on containers i.e., '/var/log/custom.log' ![AddLocationofFile](pictures/AddLocationofFile.png?raw=true) - Add Custom_Log_CL tag ![AddingCustomlogTag](pictures/AddingCustomlogTag.png?raw=true) ### Run test scripts - Available modes: - default: No options needed. Runs the install & reinstall tests on the latest agent with a 10 min wait time before verification. - `long`: Runs the tests just like the default mode but add a very longer wait time - `autoupgrade`: Runs the tests just like the default mode but waits till the agent is updated to a new version and terminates if running for more than 26 hours. - `instantupgrade`: Install the older version first and runs the default tests after force upgrade to newer version - `debug`: AZ CLI commands run with '--verbose' by default. Add 'debug' after short/long to see complete debug logs of az cli #### All images in default mode ```bash $ python -u oms_extension_tests.py ``` #### All images in default mode with debug in long run ```bash $ python -u oms_extension_tests.py long debug ``` #### Subset of images ```bash $ python -u oms_extension_tests.py image1 image2 ... ``` #### Autoupgrade of images (This option will wait until the extension is upgraded to the new version and continue to next steps after verifying data) ```bash $ python -u oms_extension_tests.py autoupgrade image1 image2 ... ``` #### Instantupgrade of images (This option will install the desired older version of extension first and then force upgrade to the latest version) Note: Must define a proper value for the `old_version` in parameters.json file else the program will encounter an undefined typeHandler error. ```bash $ python -u oms_extension_tests.py instantupgrade image1 image2 ... 
"""
Test the OMS Agent on all or a subset of images.

Setup: read parameters and setup HTML report
Test:
1. Create vm and install agent
2. Wait for data to propagate to backend and check for data
3. Remove extension
4. Reinstall extension
5. Optionally, wait for hours and check data and extension status
6. Purge extension and delete vm
Finish: compile HTML report and log file
"""
import json
import os
import os.path
import subprocess
import re
import sys
import rstr
import glob
import shutil
from time import sleep
from datetime import datetime, timedelta
from platform import system
from collections import OrderedDict
from verify_e2e import check_e2e
from json2html import *

E2E_DELAY = 15  # Delay (minutes) before checking for data
AUTOUPGRADE_DELAY = 15  # Delay (minutes) before rechecking the extension version
LONG_DELAY = 250  # Delay (minutes) before rechecking extension

# Candidate images keyed by the short distro alias accepted on the command line.
images_list = {
    'ubuntu14': 'Canonical:UbuntuServer:14.04.5-LTS:14.04.201808180',
    'ubuntu16': 'Canonical:UbuntuServer:16.04-LTS:latest',
    'ubuntu18': 'Canonical:UbuntuServer:18.04-LTS:latest',
    'debian8': 'credativ:Debian:8:latest',
    'debian9': 'credativ:Debian:9:latest',
    'redhat6': 'RedHat:RHEL:6.9:latest',
    'redhat7': 'RedHat:RHEL:7.3:latest',
    'centos6': 'OpenLogic:CentOS:6.9:latest',
    'centos7': 'OpenLogic:CentOS:7.5:latest',
    # 'oracle6': 'Oracle:Oracle-Linux:6.9:latest',
    'oracle7': 'Oracle:Oracle-Linux:7.5:latest',
    'sles12': 'SUSE:SLES:12-SP3:latest',
    'sles15': 'SUSE:SLES:15:latest'}

vmnames = []
images = {}
install_times = {}
runwith = '--verbose'

os.system('touch ./omsfiles/omsresults.log')
os.system('touch ./omsfiles/omsresults.html')
os.system('touch ./omsfiles/omsresults.status')

# Parse command-line options: any argument that is not a mode keyword is
# treated as a distro alias selecting a subset of images.
# BUG FIX: sys.argv always contains at least the script name, so the original
# `if len(sys.argv) > 0` was always true and its else-branch was dead code;
# test for real arguments instead (behaviour with no arguments is unchanged).
vms_list = []
if len(sys.argv) > 1:
    options = sys.argv[1:]
    vms_list = [i for i in options if i not in ('long', 'debug', 'autoupgrade', 'instantupgrade')]
    # ROBUSTNESS: fail fast with a readable message on a mistyped image name
    # instead of raising a bare KeyError below.
    unknown = [i for i in vms_list if i not in images_list]
    if unknown:
        print('Unknown image name(s): {0}. Valid choices: {1}'.format(', '.join(unknown), ', '.join(sorted(images_list))))
        exit()
    is_long = 'long' in options
    runwith = '--debug' if 'debug' in options else '--verbose'
    if 'autoupgrade' in options and 'instantupgrade' in options:
        print("Select only one option from 'autoupgrade' and 'instantupgrade'. You cannot run both at the same time")
        exit()
    is_autoupgrade = 'autoupgrade' in options
    is_instantupgrade = 'instantupgrade' in options
else:
    is_long = is_autoupgrade = is_instantupgrade = False

if vms_list:
    for vm in vms_list:
        images[vm] = images_list[vm]
else:
    images = images_list

print("List of VMs & Image Sources added for testing: {}".format(images))

# Load deployment parameters; refuse to run while template placeholders remain.
with open('{0}/parameters.json'.format(os.getcwd()), 'r') as f:
    parameters = f.read()
    if re.search(r'"<.*>"', parameters):
        print('Please replace placeholders in parameters.json')
        exit()
    parameters = json.loads(parameters)

resource_group = parameters['resource group']
location = parameters['location']
username = parameters['username']
nsg = parameters['nsg']
nsg_resource_group = parameters['nsg resource group']
size = parameters['size']  # Preferred: 'Standard_B1ms'
extension = 'OmsAgentForLinux'
publisher = 'Microsoft.EnterpriseCloud.Monitoring'
key_vault = parameters['key vault']

# Secrets are pulled from Key Vault via the Azure CLI rather than being stored
# in parameters.json.
subscription = str(json.loads(subprocess.check_output('az keyvault secret show --name subscription-id --vault-name {0}'.format(key_vault), shell=True))["value"])
workspace_id = str(json.loads(subprocess.check_output('az keyvault secret show --name workspace-id --vault-name {0}'.format(key_vault), shell=True))["value"])
workspace_key = str(json.loads(subprocess.check_output('az keyvault secret show --name workspace-key --vault-name {0}'.format(key_vault), shell=True))["value"])

public_settings = {"workspaceId": workspace_id}
private_settings = {"workspaceKey": workspace_key}
nsg_uri = "/subscriptions/" + subscription + "/resourceGroups/" + nsg_resource_group + "/providers/Microsoft.Network/networkSecurityGroups/" + nsg
ssh_private = parameters['ssh private']
ssh_public = ssh_private + '.pub'

# BUG FIX: old_version was previously assigned only when the parameter was
# non-empty, which produced a NameError later in instantupgrade mode when it
# was left as ""; define it unconditionally and fail fast when it is required.
old_version = parameters.get('old version')
if is_instantupgrade and not old_version:
    print('The instantupgrade mode requires a non-empty "old version" value in parameters.json')
    exit()

# Sometimes Azure VM images become unavailable or are unavailable in certain regions, lets check...
for distname, image in images.items():
    img_publisher, _, sku, _ = image.split(':')
    if subprocess.check_output('az vm image list --all --location {0} --publisher {1} --sku {2}'.format(location, img_publisher, sku), shell=True) == '[]\n':
        print('Could not find image for {0} in {1}, please double check VM image availability'.format(distname, location))
        exit()
    else:
        print('VM image availability successfully validated')

# Detect the host system and validate nsg
if system() == 'Windows':
    if os.system('az network nsg show --resource-group {0} --name {1} --query "[?n]"'.format(nsg_resource_group, nsg)) == 0:
        print("Network Security Group successfully validated")
elif system() == 'Linux':
    if os.system('az network nsg show --resource-group {0} --name {1} > /dev/null 2>&1'.format(nsg_resource_group, nsg)) == 0:
        print("Network Security Group successfully validated")
    else:
        print("""Please verify that the nsg or nsg resource group are valid and are in the right subscription.
If there is no Network Security Group, please create new one.
NSG is a must to create a VM in this testing.""")
        exit()
# Clear out intermediate artifacts left over from a previous run.
os.system('rm -rf ./*.log ./*.html ./results 2> /dev/null')

# Aggregated HTML report; each stage appends its summary here.
result_html_file = open("finalresult.html", 'a+')


# Common logic to save command itself
def write_log_command(log, cmd):
    """Echo *cmd* to stdout and append it to *log* followed by a divider."""
    print(cmd)
    log.write(cmd + '\n')
    log.write('-' * 40)
    log.write('\n')


# Common logic to append a file to another
def append_file(src, dest):
    """Append the contents of the file at path *src* to the open file *dest*."""
    with open(src, 'r') as source_file:
        dest.write(source_file.read())


# Get time difference in minutes and seconds
def get_time_diff(timevalue1, timevalue2):
    """Return (minutes, seconds) elapsed from *timevalue1* to *timevalue2*."""
    elapsed = timevalue2 - timevalue1
    total_seconds = elapsed.days * 86400 + elapsed.seconds
    return divmod(total_seconds, 60)


# Correct potential windows line endings with dos2unix command
def dos_2_unix():
    """Normalise line endings of every payload file under ./omsfiles."""
    os.system('dos2unix ./omsfiles/*')


# Secure copy required files from local to vm
def copy_to_vm(dnsname, username, ssh_private, location):
    """scp the omsfiles payload into /tmp on the target VM."""
    scp_cmd = ("scp -i {0} -o StrictHostKeyChecking=no -o LogLevel=ERROR "
               "-o UserKnownHostsFile=/dev/null -r omsfiles/* "
               "{1}@{2}.{3}.cloudapp.azure.com:/tmp/").format(ssh_private, username, dnsname.lower(), location)
    os.system(scp_cmd)


# Secure copy files from vm to local
def copy_from_vm(dnsname, username, ssh_private, location, filename):
    """scp *filename* from /home/scratch on the VM back into ./omsfiles."""
    scp_cmd = ("scp -i {0} -o StrictHostKeyChecking=no -o LogLevel=ERROR "
               "-o UserKnownHostsFile=/dev/null -r "
               "{1}@{2}.{3}.cloudapp.azure.com:/home/scratch/{4} omsfiles/.").format(ssh_private, username, dnsname.lower(), location, filename)
    os.system(scp_cmd)


# Run scripts on vm using AZ CLI
def run_command(resource_group, vmname, commandid, script):
    """Execute *script* on the VM through `az vm run-command invoke`."""
    az_cmd = 'az vm run-command invoke -g {0} -n {1} --command-id {2} --scripts "{3}" {4}'.format(resource_group, vmname, commandid, script, runwith)
    os.system(az_cmd)


# Create vm using AZ CLI
def create_vm(resource_group, vmname, image, username, ssh_public, location, dnsname, vmsize, nsg_uri):
    """Create a VM with a public DNS name, SSH key auth and the test NSG."""
    az_cmd = ('az vm create -g {0} -n {1} --image {2} --admin-username {3} '
              '--ssh-key-value @{4} --location {5} --public-ip-address-dns-name {6} '
              '--size {7} --nsg {8} {9}').format(resource_group, vmname, image, username, ssh_public, location, dnsname, vmsize, nsg_uri, runwith)
    os.system(az_cmd)
# Add extension to vm using AZ CLI
def add_extension(extension, publisher, vmname, resource_group, private_settings, public_settings, update_option):
    """Install (or update) the extension on a VM via `az vm extension set`.

    *update_option* carries extra CLI flags such as '--force-update' or
    '--version X --no-auto-upgrade'; pass '' for a plain install.
    """
    os.system('az vm extension set -n {0} --publisher {1} --vm-name {2} --resource-group {3} --protected-settings "{4}" --settings "{5}" {6} {7}'.format(extension, publisher, vmname, resource_group, private_settings, public_settings, update_option, runwith))

# Delete extension from vm using AZ CLI
def delete_extension(extension, vmname, resource_group):
    """Uninstall the extension from a VM via `az vm extension delete`."""
    os.system('az vm extension delete -n {0} --vm-name {1} --resource-group {2} {3}'.format(extension, vmname, resource_group, runwith))

# Get vm details using AZ CLI
def get_vm_resources(resource_group, vmname):
    """Return the (os_disk, nic_name, public_ip_name) belonging to the VM."""
    vm_cli_out = json.loads(subprocess.check_output('az vm show -g {0} -n {1}'.format(resource_group, vmname), shell=True))
    os_disk = vm_cli_out['storageProfile']['osDisk']['name']
    nic_name = vm_cli_out['networkProfile']['networkInterfaces'][0]['id'].split('/')[-1]
    ip_list = json.loads(subprocess.check_output('az vm list-ip-addresses -n {0} -g {1}'.format(vmname, resource_group), shell=True))
    ip_name = ip_list[0]['virtualMachine']['network']['publicIpAddresses'][0]['name']
    return os_disk, nic_name, ip_name

def get_extension_version_now(resource_group, vmname, extension):
    """Return the currently installed extension version as a tuple of ints.

    BUG FIX: the previous implementation concatenated the dotted version
    components into a single int (e.g. '1.13.3' -> 1133), which orders
    versions incorrectly ('1.2.45' -> 1245 compared greater than
    '1.13.3' -> 1133). A tuple of numeric components compares correctly and
    remains compatible with callers, which only ever compare two values
    produced by this same function with >=.
    """
    vm_ext_out = json.loads(subprocess.check_output('az vm extension show --resource-group {0} --vm-name {1} --name {2} --expand instanceView'.format(resource_group, vmname, extension), shell=True))
    version_string = str(vm_ext_out["instanceView"]["typeHandlerVersion"])
    installed_version = tuple(int(part) for part in version_string.split('.'))
    return installed_version

# Delete vm using AZ CLI
def delete_vm(resource_group, vmname):
    """Delete the VM itself (its disk/NIC/IP are cleaned up separately)."""
    os.system('az vm delete -g {0} -n {1} --yes {2}'.format(resource_group, vmname, runwith))

# Delete vm disk using AZ CLI
def delete_vm_disk(resource_group, os_disk):
    """Delete the VM's OS disk."""
    os.system('az disk delete --resource-group {0} --name {1} --yes {2}'.format(resource_group, os_disk, runwith))
# Delete vm network interface using AZ CLI
def delete_nic(resource_group, nic_name):
    # --no-wait lets NIC deletion proceed in the background.
    os.system('az network nic delete --resource-group {0} --name {1} --no-wait {2}'.format(resource_group, nic_name, runwith))

# Delete vm ip from AZ CLI
def delete_ip(resource_group, ip_name):
    os.system('az network public-ip delete --resource-group {0} --name {1} {2}'.format(resource_group, ip_name, runwith))

# Static prologue of the aggregated HTML report: table styling shared by all
# of the per-stage result tables written below.
htmlstart = """<!DOCTYPE html>
<html>
<head>
<style>
table {
    font-family: arial, sans-serif;
    border-collapse: collapse;
    width: 100%;
}

table:not(th) {
    font-weight: lighter;
}

td, th {
    border: 1px solid #dddddd;
    text-align: left;
    padding: 8px;
}

tr:nth-child(even) {
    background-color: #dddddd;
}
</style>
</head>
<body>
"""
result_html_file.write(htmlstart)

def main():
    """Orchestrate fundamental testing steps outlined in the module docstring.

    Each stage returns a string of HTML <td> cells (one per VM) which are
    collected into *messages* and rendered into the final report.
    """
    # Instant-upgrade mode installs the pinned old version first, verifies it,
    # then force-upgrades and verifies again.
    if is_instantupgrade:
        install_oms_msg = create_vm_and_install_old_extension()
        verify_oms_msg = verify_data()
        instantupgrade_status_msg = force_upgrade_extension()
        instantupgrade_verify_msg = verify_data()
    else:
        instantupgrade_verify_msg, instantupgrade_status_msg = None, None
        install_oms_msg = create_vm_and_install_extension()
        verify_oms_msg = verify_data()
    if is_autoupgrade:
        autoupgrade_status_msg = autoupgrade()
        autoupgrade_verify_msg = verify_data()
    else:
        autoupgrade_verify_msg, autoupgrade_status_msg = None, None
    remove_oms_msg = remove_extension()
    reinstall_oms_msg = reinstall_extension()
    if is_long:
        # Count down LONG_DELAY minutes on a single console line before
        # re-checking status and data.
        for i in reversed(range(1, LONG_DELAY + 1)):
            sys.stdout.write('\rLong-term delay: T-{0} minutes...'.format(i))
            sys.stdout.flush()
            sleep(60)
        print('')
        long_status_msg = check_status()
        long_verify_msg = verify_data()
    else:
        long_verify_msg, long_status_msg = None, None
    remove_extension_and_delete_vm()
    messages = (install_oms_msg, verify_oms_msg, instantupgrade_verify_msg,
                instantupgrade_status_msg, autoupgrade_verify_msg,
                autoupgrade_status_msg, remove_oms_msg, reinstall_oms_msg,
                long_verify_msg, long_status_msg)
    create_report(messages)
    mv_result_files()
# Create one VM per selected distro image, install the OMS extension on it,
# and collect per-VM logs/HTML fragments. Returns one HTML <td> cell per VM.
def create_vm_and_install_extension():
    """Create vm and install the extension, returning HTML results."""
    message = ""
    update_option = ""
    install_times.clear()
    for distname, image in images.iteritems():
        # Random 8-hex-char suffix keeps VM names unique across runs.
        uid = rstr.xeger(r'[0-9a-f]{8}')
        vmname = distname.lower() + '-' + uid
        vmnames.append(vmname)
        dnsname = vmname
        vm_log_file = distname.lower() + "result.log"
        vm_html_file = distname.lower() + "result.html"
        log_open = open(vm_log_file, 'a+')
        html_open = open(vm_html_file, 'a+')
        print("\nCreate VM and Install Extension - {0}: {1} \n".format(vmname, image))
        create_vm(resource_group, vmname, image, username, ssh_public, location, dnsname, size, nsg_uri)
        dos_2_unix()
        copy_to_vm(dnsname, username, ssh_private, location)
        # Clear any pre-existing extension before the fresh install.
        delete_extension(extension, vmname, resource_group)
        run_command(resource_group, vmname, 'RunShellScript', 'python -u /tmp/oms_extension_run_script.py -preinstall')
        add_extension(extension, publisher, vmname, resource_group, private_settings, public_settings, update_option)
        run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -postinstall')
        # Record the install time so verify_data() can wait for propagation.
        install_times.update({vmname: datetime.now()})
        run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -injectlogs')
        copy_from_vm(dnsname, username, ssh_private, location, 'omsresults.*')
        write_log_command(log_open, 'Status After Creating VM and Adding OMS Extension')
        # NOTE(review): the trailing tag is '<h1>'/'<h2>' rather than a
        # closing tag in the original output strings — preserved as-is.
        html_open.write('<h1 id="{0}"> VM: {0} <h1>'.format(distname))
        html_open.write("<h2> Install OMS Agent </h2>")
        append_file('omsfiles/omsresults.log', log_open)
        append_file('omsfiles/omsresults.html', html_open)
        log_open.close()
        html_open.close()
        # The run script writes one of three sentinel strings to this file.
        status = open('omsfiles/omsresults.status', 'r').read()
        if status == "Agent Found":
            message += """
                        <td><span style='background-color: #66ff99'>Install Success</span></td>"""
        elif status == "Onboarding Failed":
            message += """
                        <td><span style='background-color: red; color: white'>Onboarding Failed</span></td>"""
        elif status == "Agent Not Found":
            message += """
                        <td><span style='background-color: red; color: white'>Install Failed</span></td>"""
    return message
# Same flow as create_vm_and_install_extension, but pins the extension to the
# configured old_version with auto-upgrade disabled (instantupgrade mode).
def create_vm_and_install_old_extension():
    """Create vm and install a specific version of the extension, returning HTML results."""
    message = ""
    # Pin the version and keep the platform from upgrading it automatically.
    update_option = '--version {0} --no-auto-upgrade'.format(old_version)
    install_times.clear()
    for distname, image in images.iteritems():
        # Random 8-hex-char suffix keeps VM names unique across runs.
        uid = rstr.xeger(r'[0-9a-f]{8}')
        vmname = distname.lower() + '-' + uid
        vmnames.append(vmname)
        dnsname = vmname
        vm_log_file = distname.lower() + "result.log"
        vm_html_file = distname.lower() + "result.html"
        log_open = open(vm_log_file, 'a+')
        html_open = open(vm_html_file, 'a+')
        print("\nCreate VM and Install Extension {0} v-{1} - {2}: {3} \n".format(extension, old_version, vmname, image))
        create_vm(resource_group, vmname, image, username, ssh_public, location, dnsname, size, nsg_uri)
        dos_2_unix()
        copy_to_vm(dnsname, username, ssh_private, location)
        # Clear any pre-existing extension before the pinned install.
        delete_extension(extension, vmname, resource_group)
        run_command(resource_group, vmname, 'RunShellScript', 'python -u /tmp/oms_extension_run_script.py -preinstall')
        add_extension(extension, publisher, vmname, resource_group, private_settings, public_settings, update_option)
        run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -postinstall')
        # Record the install time so verify_data() can wait for propagation.
        install_times.update({vmname: datetime.now()})
        run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -injectlogs')
        copy_from_vm(dnsname, username, ssh_private, location, 'omsresults.*')
        write_log_command(log_open, "Status After Creating VM and Adding OMS Extension version: {0}".format(old_version))
        html_open.write('<h1 id="{0}"> VM: {0} <h1>'.format(distname))
        html_open.write("<h2> Install OMS Agent version: {0} </h2>".format(old_version))
        append_file('omsfiles/omsresults.log', log_open)
        append_file('omsfiles/omsresults.html', html_open)
        log_open.close()
        html_open.close()
        # The run script writes one of three sentinel strings to this file.
        status = open('omsfiles/omsresults.status', 'r').read()
        if status == "Agent Found":
            message += """
                        <td><span style='background-color: #66ff99'>Install Success</span></td>"""
        elif status == "Onboarding Failed":
            message += """
                        <td><span style='background-color: red; color: white'>Onboarding Failed</span></td>"""
        elif status == "Agent Not Found":
            message += """
                        <td><span style='background-color: red; color: white'>Install Failed</span></td>"""
    return message
# Force-upgrade the already-installed extension to the latest version on every
# test VM and re-check agent status. Returns one HTML <td> cell per VM.
def force_upgrade_extension():
    """ Force Update the extension to the latest version """
    message = ""
    update_option = '--force-update'
    install_times.clear()
    for vmname in vmnames:
        distname = vmname.split('-')[0]
        vm_log_file = distname + "result.log"
        vm_html_file = distname + "result.html"
        log_open = open(vm_log_file, 'a+')
        html_open = open(vm_html_file, 'a+')
        dnsname = vmname
        print("\n Force Upgrade Extension: {0} \n".format(vmname))
        add_extension(extension, publisher, vmname, resource_group, private_settings, public_settings, update_option)
        run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -postinstall')
        # Reset the timestamp so verify_data() waits relative to the upgrade.
        install_times.update({vmname: datetime.now()})
        run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -injectlogs')
        copy_from_vm(dnsname, username, ssh_private, location, 'omsresults.*')
        write_log_command(log_open, 'Status After Force Upgrading OMS Extension')
        html_open.write('<h2> Force Upgrade Extension: {0} <h2>'.format(vmname))
        append_file('omsfiles/omsresults.log', log_open)
        append_file('omsfiles/omsresults.html', html_open)
        log_open.close()
        html_open.close()
        # The run script writes one of three sentinel strings to this file.
        status = open('omsfiles/omsresults.status').read()
        if status == "Agent Found":
            message += """
                        <td><span style='background-color: #66ff99'>Reinstall Success</span></td>"""
        elif status == "Onboarding Failed":
            message += """
                        <td><span style='background-color: red; color: white'>Onboarding Failed</span></td>"""
        elif status == "Agent Not Found":
            message += """
                        <td><span style='background-color: red; color: white'>Reinstall Failed</span></td>"""
    return message
# Wait out the end-to-end propagation delay per VM, then query the workspace
# through verify_e2e.check_e2e and record per-source pass/fail results.
# Returns one HTML <td> summary cell per VM.
def verify_data():
    """Verify data end-to-end, returning HTML results."""
    message = ""
    for vmname in vmnames:
        distname = vmname.split('-')[0]
        vm_log_file = distname + "result.log"
        vm_html_file = distname + "result.html"
        log_open = open(vm_log_file, 'a+')
        html_open = open(vm_html_file, 'a+')
        # Delay to allow data to propagate
        while datetime.now() < (install_times[vmname] + timedelta(minutes=E2E_DELAY)):
            mins, secs = get_time_diff(datetime.now(), install_times[vmname] + timedelta(minutes=E2E_DELAY))
            sys.stdout.write('\rE2E propagation delay: {0} minutes {1} seconds...'.format(mins, secs))
            sys.stdout.flush()
            sleep(1)
        print('')
        # Query only the window since this VM's install/upgrade time.
        minutes, _ = get_time_diff(install_times[vmname], datetime.now())
        timespan = 'PT{0}M'.format(minutes)
        data = check_e2e(vmname, timespan)

        # write detailed table for vm
        html_open.write("<h2> Verify Data from OMS workspace </h2>")
        write_log_command(log_open, 'Status After Verifying Data')
        results = data[distname][0]
        log_open.write(distname + ':\n' + json.dumps(results, indent=4, separators=(',', ': ')) + '\n')
        # prepend distro column to results row before generating the table
        data = [OrderedDict([('Distro', distname)] + results.items())]
        out = json2html.convert(data)
        html_open.write(out)
        # NOTE(review): log_open/html_open are not closed in this function —
        # preserved as original behaviour.

        # write to summary table; success_count is a module-level counter
        # updated by check_e2e (6 == all data sources verified).
        from verify_e2e import success_count
        if success_count == 6:
            message += """
                        <td><span style='background-color: #66ff99'>Verify Success</td>"""
        elif 0 < success_count < 6:
            from verify_e2e import success_sources, failed_sources
            message += """
                        <td><span style='background-color: #66ff99'>{0} Success</span> <br><br><span style='background-color: red; color: white'>{1} Failed</span></td>""".format(', '.join(success_sources), ', '.join(failed_sources))
        elif success_count == 0:
            message += """
                        <td><span style='background-color: red; color: white'>Verify Failed</span></td>"""
    return message
# Poll until the platform auto-upgrades the extension to a newer version
# (giving up after ~26 hours), then re-check agent status on every VM.
# Returns one HTML <td> cell per VM.
def autoupgrade():
    """
    Waits for the extension to get updated automatically and continues with the tests after.
    Maximum wait time is 26 hours
    """
    message = ""
    install_times.clear()
    for vmname in vmnames:
        initial_version = get_extension_version_now(resource_group, vmname, extension)
        time_lapsed = 0
        # Re-poll every AUTOUPGRADE_DELAY minutes until a higher version
        # appears, warning after 24h (1440 min) and giving up at 26h (1560).
        while initial_version >= get_extension_version_now(resource_group, vmname, extension):
            sleep(AUTOUPGRADE_DELAY*60)
            time_lapsed+=AUTOUPGRADE_DELAY
            if time_lapsed < 1440:
                sys.stdout.write("waiting for new version. Time Lapsed: {0} minutes".format(time_lapsed))
                sys.stdout.flush()
            elif 1440 <= time_lapsed < 1560:
                sys.stdout.write('Process waiting for more than 24 hrs. Please check the deployment of the new version is completed or not. This wait will end in {0} minutes'.format(1560 - time_lapsed))
                sys.stdout.flush()
            elif time_lapsed >= 1560:
                print("""Process waiting for more than 26 hrs. No New version of extension has been deployed. If a new version is deployed, please check for any errors and re-run""")
                break
        distname = vmname.split('-')[0]
        vm_log_file = distname + "result.log"
        vm_html_file = distname + "result.html"
        log_open = open(vm_log_file, 'a+')
        html_open = open(vm_html_file, 'a+')
        dnsname = vmname
        print("\n Checking Status After AutoUpgrade: {0} \n".format(vmname))
        run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -postinstall')
        # Reset the timestamp so verify_data() waits relative to the upgrade.
        install_times.update({vmname: datetime.now()})
        run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -injectlogs')
        copy_from_vm(dnsname, username, ssh_private, location, 'omsresults.*')
        write_log_command(log_open, 'Status After AutoUpgrade OMS Extension')
        html_open.write('<h2> Status After AutoUpgrade OMS Extension: {0} <h2>'.format(vmname))
        append_file('omsfiles/omsresults.log', log_open)
        append_file('omsfiles/omsresults.html', html_open)
        log_open.close()
        html_open.close()
        # The run script writes one of three sentinel strings to this file.
        status = open('omsfiles/omsresults.status').read()
        if status == "Agent Found":
            message += """
                        <td><span style='background-color: #66ff99'>AutoUpgrade Success</span></td>"""
        elif status == "Onboarding Failed":
            message += """
                        <td><span style='background-color: red; color: white'>Onboarding Failed</span></td>"""
        elif status == "Agent Not Found":
            message += """
                        <td><span style='background-color: red; color: white'>AutoUpgrade Failed</span></td>"""
    return message
open('omsfiles/omsresults.status').read() if status == "Agent Found": message += """ <td><span style='background-color: #66ff99'>AutoUpgrade Success</span></td>""" elif status == "Onboarding Failed": message += """ <td><span style='background-color: red; color: white'>Onboarding Failed</span></td>""" elif status == "Agent Not Found": message += """ <td><span style='background-color: red; color: white'>AutoUpgrade Failed</span></td>""" return message def remove_extension(): """Remove the extension, returning HTML results.""" message = "" for vmname in vmnames: distname = vmname.split('-')[0] vm_log_file = distname + "result.log" vm_html_file = distname + "result.html" log_open = open(vm_log_file, 'a+') html_open = open(vm_html_file, 'a+') dnsname = vmname run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -copyomslogs') print("\nRemove Extension: {0} \n".format(vmname)) delete_extension(extension, vmname, resource_group) run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -status') copy_from_vm(dnsname, username, ssh_private, location, 'omsresults.*') write_log_command(log_open, 'Status After Removing OMS Extension') html_open.write('<h2> Remove Extension: {0} <h2>'.format(vmname)) append_file('omsfiles/omsresults.log', log_open) append_file('omsfiles/omsresults.html', html_open) log_open.close() html_open.close() status = open('omsfiles/omsresults.status', 'r').read() if status == "Agent Found": message += """ <td><span style="background-color: red; color: white">Remove Failed</span></td>""" elif status == "Onboarding Failed": message += """ <td><span style="background-color: red; color: white">Onboarding Failed</span></td>""" elif status == "Agent Not Found": message += """ <td><span style="background-color: #66ff99">Remove Success</span></td>""" return message def reinstall_extension(): """Reinstall the extension, returning HTML results.""" update_option 
= '--force-update' message = "" for vmname in vmnames: distname = vmname.split('-')[0] vm_log_file = distname + "result.log" vm_html_file = distname + "result.html" log_open = open(vm_log_file, 'a+') html_open = open(vm_html_file, 'a+') dnsname = vmname print("\n Reinstall Extension: {0} \n".format(vmname)) add_extension(extension, publisher, vmname, resource_group, private_settings, public_settings, update_option) run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -postinstall') copy_from_vm(dnsname, username, ssh_private, location, 'omsresults.*') write_log_command(log_open, 'Status After Reinstall OMS Extension') html_open.write('<h2> Reinstall Extension: {0} <h2>'.format(vmname)) append_file('omsfiles/omsresults.log', log_open) append_file('omsfiles/omsresults.html', html_open) log_open.close() html_open.close() status = open('omsfiles/omsresults.status').read() if status == "Agent Found": message += """ <td><span style='background-color: #66ff99'>Reinstall Success</span></td>""" elif status == "Onboarding Failed": message += """ <td><span style='background-color: red; color: white'>Onboarding Failed</span></td>""" elif status == "Agent Not Found": message += """ <td><span style='background-color: red; color: white'>Reinstall Failed</span></td>""" return message def check_status(): """Check agent status.""" message = "" install_times.clear() for vmname in vmnames: distname = vmname.split('-')[0] vm_log_file = distname + "result.log" vm_html_file = distname + "result.html" log_open = open(vm_log_file, 'a+') html_open = open(vm_html_file, 'a+') dnsname = vmname print("\n Checking Status: {0} \n".format(vmname)) run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -status') install_times.update({vmname: datetime.now()}) run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -injectlogs') 
copy_from_vm(dnsname, username, ssh_private, location, 'omsresults.*') write_log_command(log_open, 'Status After Long Run OMS Extension') html_open.write('<h2> Status After Long Run OMS Extension: {0} <h2>'.format(vmname)) append_file('omsfiles/omsresults.log', log_open) append_file('omsfiles/omsresults.html', html_open) log_open.close() html_open.close() status = open('omsfiles/omsresults.status').read() if status == "Agent Found": message += """ <td><span style='background-color: #66ff99'>Reinstall Success</span></td>""" elif status == "Onboarding Failed": message += """ <td><span style='background-color: red; color: white'>Onboarding Failed</span></td>""" elif status == "Agent Not Found": message += """ <td><span style='background-color: red; color: white'>Reinstall Failed</span></td>""" return message def remove_extension_and_delete_vm(): """Remove extension and delete vm.""" for vmname in vmnames: distname = vmname.split('-')[0] vm_log_file = distname + "result.log" log_open = open(vm_log_file, 'a+') dnsname = vmname run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -copyomslogs') copy_from_vm(dnsname, username, ssh_private, location, '{0}-omsagent.log'.format(distname)) print("\n Remove extension and Delete VM: {0} \n".format(vmname)) delete_extension(extension, vmname, resource_group) run_command(resource_group, vmname, 'RunShellScript', 'python -u /home/scratch/oms_extension_run_script.py -copyextlogs') copy_from_vm(dnsname, username, ssh_private, location, '{0}-extnwatcher.log'.format(distname)) disk, nic, ip = get_vm_resources(resource_group, vmname) delete_vm(resource_group, vmname) delete_vm_disk(resource_group, disk) delete_nic(resource_group, nic) delete_ip(resource_group, ip) append_file('omsfiles/{0}-extnwatcher.log'.format(distname), log_open) append_file('omsfiles/{0}-omsagent.log'.format(distname), log_open) log_open.close() def create_report(messages): """Compile the final HTML report.""" 
install_oms_msg, verify_oms_msg, instantupgrade_verify_msg, instantupgrade_status_msg, autoupgrade_verify_msg, autoupgrade_status_msg, remove_oms_msg, reinstall_oms_msg, long_verify_msg, long_status_msg = messages result_log_file = open("finalresult.log", "a+") # summary table diststh = "" resultsth = "" for vmname in vmnames: distname = vmname.split('-')[0] diststh += """ <th>{0}</th>""".format(distname) resultsth += """ <th><a href='#{0}'>{0} results</a></th>""".format(distname) if instantupgrade_verify_msg and instantupgrade_status_msg: instantupgrade_summary = """ <tr> <td>Instant Upgrade Verify Data</td> {0} </tr> <tr> <td>Instant Upgrade Status</td> {1} </tr> """.format(instantupgrade_verify_msg, instantupgrade_status_msg) else: instantupgrade_summary = "" if autoupgrade_verify_msg and autoupgrade_status_msg: autoupgrade_summary = """ <tr> <td>AutoUpgrade Verify Data</td> {0} </tr> <tr> <td>AutoUpgrade Status</td> {1} </tr> """.format(autoupgrade_verify_msg, autoupgrade_status_msg) else: autoupgrade_summary = "" # pre-compile long-running summary if long_verify_msg and long_status_msg: long_running_summary = """ <tr> <td>Long-Term Verify Data</td> {0} </tr> <tr> <td>Long-Term Status</td> {1} </tr> """.format(long_verify_msg, long_status_msg) else: long_running_summary = "" statustable = """ <table> <caption><h2>Test Result Table</h2><caption> <tr> <th>Distro</th> {0} </tr> <tr> <td>Install OMSAgent</td> {1} </tr> <tr> <td>Verify Data</td> {2} </tr> {3} {4} <tr> <td>Remove OMSAgent</td> {5} </tr> <tr> <td>Reinstall OMSAgent</td> {6} </tr> {7} <tr> <td>Result Link</td> {8} <tr> </table> """.format(diststh, install_oms_msg, verify_oms_msg, instantupgrade_summary, autoupgrade_summary, remove_oms_msg, reinstall_oms_msg, long_running_summary, resultsth) result_html_file.write(statustable) # Create final html & log file for vmname in vmnames: distname = vmname.split('-')[0] append_file(distname + "result.log", result_log_file) append_file(distname + "result.html", 
result_html_file) result_log_file.close() htmlend = """ </body> </html> """ result_html_file.write(htmlend) result_html_file.close() def mv_result_files(): if not os.path.exists('results'): os.makedirs('results') file_types = ['*result.*', 'omsfiles/*-extnwatcher.log', 'omsfiles/*-omsagent.log'] for files in file_types: for f in glob.glob(files): shutil.move(os.path.join(f), os.path.join('results/')) if __name__ == '__main__': main() ================================================ FILE: OmsAgent/extension-test/omsfiles/apache_access.log ================================================ 41.88.172.43 - - [18/Oct/2018:00:34:19 +0000] "GET /posts/posts/explore HTTP/1.0" 200 4955 "http://phillips.org/homepage/" "Mozilla/5.0 (X11; Linux i686) AppleWebKit/5342 (KHTML, like Gecko) Chrome/15.0.818.0 Safari/5342" 97.77.75.235 - - [18/Oct/2018:00:39:13 +0000] "PUT /list HTTP/1.0" 200 5022 "http://www.goodwin.com/login.htm" "Mozilla/5.0 (Windows 98; Win 9x 4.90; it-IT; rv:1.9.1.20) Gecko/2010-05-09 00:24:19 Firefox/3.6.7" 6.33.183.64 - - [18/Oct/2018:00:42:47 +0000] "PUT /apps/cart.jsp?appID=7380 HTTP/1.0" 200 4961 "http://hess-jones.com/categories/register.html" "Mozilla/5.0 (Windows CE) AppleWebKit/5332 (KHTML, like Gecko) Chrome/13.0.883.0 Safari/5332" 50.159.139.180 - - [18/Oct/2018:00:45:04 +0000] "GET /posts/posts/explore HTTP/1.0" 200 4928 "http://www.campbell-farrell.biz/wp-content/tag/blog/index/" "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_4) AppleWebKit/5322 (KHTML, like Gecko) Chrome/15.0.818.0 Safari/5322" 82.191.47.90 - - [18/Oct/2018:00:50:01 +0000] "GET /wp-admin HTTP/1.0" 200 4931 "http://erickson.net/main/" "Mozilla/5.0 (iPod; U; CPU iPhone OS 4_2 like Mac OS X; sl-SI) AppleWebKit/533.19.5 (KHTML, like Gecko) Version/4.0.5 Mobile/8B113 Safari/6533.19.5" 154.50.38.159 - - [18/Oct/2018:00:51:46 +0000] "GET /wp-content HTTP/1.0" 200 4979 "http://www.palmer.com/index.php" "Mozilla/5.0 (iPod; U; CPU iPhone OS 3_3 like Mac OS X; sl-SI) AppleWebKit/533.7.4 (KHTML, 
like Gecko) Version/3.0.5 Mobile/8B117 Safari/6533.7.4" 140.54.30.228 - - [18/Oct/2018:00:53:07 +0000] "POST /wp-content HTTP/1.0" 200 5041 "http://www.sharp-kidd.com/faq.php" "Mozilla/5.0 (Windows 98; en-US; rv:1.9.0.20) Gecko/2015-01-02 20:55:20 Firefox/3.8" 29.153.222.134 - - [18/Oct/2018:00:54:28 +0000] "GET /search/tag/list HTTP/1.0" 200 5007 "http://martinez.com/list/wp-content/post/" "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/5321 (KHTML, like Gecko) Chrome/14.0.827.0 Safari/5321" 89.21.173.120 - - [18/Oct/2018:00:57:37 +0000] "GET /posts/posts/explore HTTP/1.0" 301 4979 "http://kane.org/homepage/" "Mozilla/5.0 (compatible; MSIE 5.0; Windows NT 5.1; Trident/4.1)" 237.144.104.90 - - [18/Oct/2018:00:58:23 +0000] "PUT /app/main/posts HTTP/1.0" 404 5031 "http://www.guerrero-schroeder.com/list/categories/search/" "Mozilla/5.0 (Windows 98; Win 9x 4.90; en-US; rv:1.9.2.20) Gecko/2015-11-28 04:08:22 Firefox/3.6.3" 239.4.131.80 - - [18/Oct/2018:00:59:46 +0000] "PUT /explore HTTP/1.0" 404 5005 "http://www.sanchez.com/terms/" "Mozilla/5.0 (Windows CE; sl-SI; rv:1.9.1.20) Gecko/2017-05-24 17:33:32 Firefox/3.8" 72.190.211.123 - - [18/Oct/2018:01:04:24 +0000] "GET /posts/posts/explore HTTP/1.0" 200 5015 "http://www.byrd-kerr.com/home/" "Mozilla/5.0 (Windows 95; it-IT; rv:1.9.0.20) Gecko/2018-03-15 01:33:17 Firefox/13.0" 91.80.110.133 - - [18/Oct/2018:01:05:39 +0000] "GET /apps/cart.jsp?appID=3963 HTTP/1.0" 200 4933 "http://guzman.org/" "Mozilla/5.0 (Windows 98; sl-SI; rv:1.9.2.20) Gecko/2016-11-21 02:54:31 Firefox/3.8" 149.90.53.105 - - [18/Oct/2018:01:06:50 +0000] "DELETE /wp-content HTTP/1.0" 200 4996 "http://www.schroeder.com/privacy/" "Mozilla/5.0 (Macintosh; U; PPC Mac OS X 10_5_0 rv:2.0; en-US) AppleWebKit/533.3.1 (KHTML, like Gecko) Version/5.0.2 Safari/533.3.1" 78.40.84.134 - - [18/Oct/2018:01:09:35 +0000] "GET /apps/cart.jsp?appID=9468 HTTP/1.0" 200 5050 "http://thomas-smith.biz/" "Mozilla/5.0 (X11; Linux x86_64; rv:1.9.5.20) Gecko/2012-05-17 06:37:48 
Firefox/6.0" 30.224.7.147 - - [18/Oct/2018:01:13:04 +0000] "GET /wp-content HTTP/1.0" 200 4980 "http://www.ayala-rodriguez.net/" "Mozilla/5.0 (Macintosh; U; PPC Mac OS X 10_6_3; rv:1.9.4.20) Gecko/2012-09-14 05:10:58 Firefox/4.0" 190.241.3.20 - - [18/Oct/2018:01:14:10 +0000] "GET /explore HTTP/1.0" 200 4875 "http://www.johnson.com/login.php" "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/5320 (KHTML, like Gecko) Chrome/13.0.829.0 Safari/5320" 245.197.148.127 - - [18/Oct/2018:01:15:28 +0000] "GET /posts/posts/explore HTTP/1.0" 200 4976 "http://griffith-miller.org/home.php" "Mozilla/5.0 (Macintosh; U; PPC Mac OS X 10_7_3 rv:2.0; en-US) AppleWebKit/535.46.3 (KHTML, like Gecko) Version/5.0.4 Safari/535.46.3" 90.8.112.249 - - [18/Oct/2018:01:20:27 +0000] "GET /apps/cart.jsp?appID=3243 HTTP/1.0" 200 4991 "http://holland-brown.com/terms/" "Opera/8.65.(Windows NT 5.0; it-IT) Presto/2.9.161 Version/10.00" 131.127.138.38 - - [18/Oct/2018:01:23:07 +0000] "GET /app/main/posts HTTP/1.0" 200 5060 "http://ware-cole.net/wp-content/tag/main/login/" "Mozilla/5.0 (Windows 95; en-US; rv:1.9.2.20) Gecko/2012-01-02 22:56:38 Firefox/3.6.12" 164.26.165.230 - - [18/Oct/2018:01:24:10 +0000] "DELETE /wp-admin HTTP/1.0" 200 5028 "http://harding-murphy.biz/author.html" "Mozilla/5.0 (X11; Linux i686) AppleWebKit/5351 (KHTML, like Gecko) Chrome/13.0.868.0 Safari/5351" 5.3.62.184 - - [18/Oct/2018:01:25:09 +0000] "GET /posts/posts/explore HTTP/1.0" 200 4928 "http://www.stafford-hill.biz/" "Mozilla/5.0 (X11; Linux i686) AppleWebKit/5350 (KHTML, like Gecko) Chrome/15.0.815.0 Safari/5350" 181.162.5.173 - - [18/Oct/2018:01:26:41 +0000] "GET /list HTTP/1.0" 200 5038 "http://www.hall.com/posts/index/" "Mozilla/5.0 (Macintosh; PPC Mac OS X 10_7_3) AppleWebKit/5322 (KHTML, like Gecko) Chrome/15.0.880.0 Safari/5322" 98.153.76.19 - - [18/Oct/2018:01:28:10 +0000] "DELETE /wp-admin HTTP/1.0" 200 5018 "http://www.shaw-cole.com/" "Opera/9.15.(X11; Linux i686; sl-SI) Presto/2.9.164 Version/10.00" 127.70.246.76 - - 
[18/Oct/2018:01:29:45 +0000] "PUT /wp-admin HTTP/1.0" 200 5071 "http://www.french.net/" "Mozilla/5.0 (X11; Linux i686; rv:1.9.6.20) Gecko/2013-12-07 23:58:36 Firefox/3.6.17" 244.88.20.30 - - [18/Oct/2018:01:30:27 +0000] "PUT /posts/posts/explore HTTP/1.0" 200 5029 "http://watson.info/register.html" "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10_7_4 rv:5.0; sl-SI) AppleWebKit/534.44.6 (KHTML, like Gecko) Version/5.0.4 Safari/534.44.6" 36.196.205.161 - - [18/Oct/2018:01:31:44 +0000] "POST /wp-content HTTP/1.0" 200 4948 "http://harris.com/app/terms/" "Mozilla/5.0 (X11; Linux x86_64; rv:1.9.6.20) Gecko/2014-10-08 23:33:35 Firefox/3.6.17" 176.75.22.168 - - [18/Oct/2018:01:32:19 +0000] "GET /posts/posts/explore HTTP/1.0" 200 5058 "http://conley.biz/tags/login.htm" "Mozilla/5.0 (Windows NT 6.2) AppleWebKit/5312 (KHTML, like Gecko) Chrome/13.0.844.0 Safari/5312" 241.178.144.215 - - [18/Oct/2018:01:36:36 +0000] "PUT /wp-content HTTP/1.0" 200 5028 "http://holden.com/login/" "Mozilla/5.0 (Windows CE) AppleWebKit/5330 (KHTML, like Gecko) Chrome/13.0.824.0 Safari/5330" 90.204.24.160 - - [18/Oct/2018:01:37:23 +0000] "PUT /list HTTP/1.0" 200 4969 "http://www.proctor-simmons.info/categories/author/" "Mozilla/5.0 (X11; Linux i686; rv:1.9.7.20) Gecko/2013-10-05 02:04:20 Firefox/3.6.6" 246.240.89.237 - - [18/Oct/2018:01:39:12 +0000] "PUT /wp-content HTTP/1.0" 200 4986 "http://bailey.org/explore/wp-content/main.php" "Mozilla/5.0 (compatible; MSIE 9.0; Windows 98; Trident/5.1)" 5.76.9.164 - - [18/Oct/2018:01:40:38 +0000] "DELETE /app/main/posts HTTP/1.0" 200 5045 "http://www.buck.info/categories/wp-content/homepage.htm" "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/5362 (KHTML, like Gecko) Chrome/15.0.847.0 Safari/5362" 197.98.233.63 - - [18/Oct/2018:01:42:03 +0000] "GET /explore HTTP/1.0" 200 5034 "http://www.martin-howard.com/explore/search/privacy.html" "Mozilla/5.0 (Windows CE; it-IT; rv:1.9.2.20) Gecko/2012-01-07 08:08:21 Firefox/12.0" 123.26.215.34 - - [18/Oct/2018:01:43:13 +0000] 
"DELETE /search/tag/list HTTP/1.0" 200 4997 "http://frazier-schmidt.com/main/main/" "Mozilla/5.0 (iPod; U; CPU iPhone OS 3_2 like Mac OS X; en-US) AppleWebKit/534.18.7 (KHTML, like Gecko) Version/3.0.5 Mobile/8B117 Safari/6534.18.7" 93.167.30.46 - - [18/Oct/2018:01:45:30 +0000] "POST /posts/posts/explore HTTP/1.0" 200 4914 "http://www.stanley-evans.com/" "Mozilla/5.0 (X11; Linux i686; rv:1.9.7.20) Gecko/2018-05-29 09:05:34 Firefox/15.0" 0.12.43.164 - - [18/Oct/2018:01:49:46 +0000] "GET /explore HTTP/1.0" 200 4924 "http://allison.com/app/explore/app/main/" "Mozilla/5.0 (X11; Linux i686; rv:1.9.5.20) Gecko/2013-09-24 01:40:33 Firefox/3.8" 28.33.105.197 - - [18/Oct/2018:01:50:50 +0000] "GET /app/main/posts HTTP/1.0" 200 4970 "http://www.jennings.com/categories/homepage.html" "Mozilla/5.0 (Windows NT 5.2) AppleWebKit/5310 (KHTML, like Gecko) Chrome/14.0.876.0 Safari/5310" 199.171.27.50 - - [18/Oct/2018:01:54:52 +0000] "GET /app/main/posts HTTP/1.0" 200 4934 "http://www.sanders-shah.net/tag/post.php" "Mozilla/5.0 (X11; Linux i686) AppleWebKit/5362 (KHTML, like Gecko) Chrome/15.0.897.0 Safari/5362" 134.36.90.225 - - [18/Oct/2018:01:56:10 +0000] "DELETE /list HTTP/1.0" 200 5053 "http://smith-rodriguez.com/explore/privacy/" "Mozilla/5.0 (iPod; U; CPU iPhone OS 4_0 like Mac OS X; it-IT) AppleWebKit/535.42.4 (KHTML, like Gecko) Version/4.0.5 Mobile/8B115 Safari/6535.42.4" 46.187.133.243 - - [18/Oct/2018:01:57:08 +0000] "GET /posts/posts/explore HTTP/1.0" 200 4952 "http://henson.net/categories/blog/author/" "Mozilla/5.0 (iPod; U; CPU iPhone OS 4_2 like Mac OS X; it-IT) AppleWebKit/532.7.1 (KHTML, like Gecko) Version/4.0.5 Mobile/8B117 Safari/6532.7.1" 140.81.9.137 - - [18/Oct/2018:02:00:59 +0000] "POST /explore HTTP/1.0" 200 4927 "http://www.simmons.org/privacy.php" "Mozilla/5.0 (Windows CE; sl-SI; rv:1.9.2.20) Gecko/2013-07-30 04:10:19 Firefox/3.8" 31.107.2.231 - - [18/Oct/2018:02:02:48 +0000] "DELETE /list HTTP/1.0" 200 5006 "http://bray.biz/" "Mozilla/5.0 (Macintosh; U; 
Intel Mac OS X 10_8_1; rv:1.9.4.20) Gecko/2012-02-25 04:00:09 Firefox/3.6.15" 48.156.39.28 - - [18/Oct/2018:02:04:12 +0000] "GET /apps/cart.jsp?appID=6250 HTTP/1.0" 301 4988 "http://www.lopez.com/" "Mozilla/5.0 (compatible; MSIE 6.0; Windows CE; Trident/3.1)" 231.240.140.141 - - [18/Oct/2018:02:05:25 +0000] "GET /app/main/posts HTTP/1.0" 200 4976 "http://chandler.com/faq.html" "Mozilla/5.0 (Windows 98) AppleWebKit/5361 (KHTML, like Gecko) Chrome/13.0.807.0 Safari/5361" 197.49.239.55 - - [18/Oct/2018:02:07:56 +0000] "PUT /posts/posts/explore HTTP/1.0" 200 5053 "http://smith.biz/posts/wp-content/posts/main/" "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10_8_5 rv:6.0; en-US) AppleWebKit/533.15.7 (KHTML, like Gecko) Version/4.0.3 Safari/533.15.7" 185.165.245.238 - - [18/Oct/2018:02:11:59 +0000] "GET /list HTTP/1.0" 200 5010 "http://mercado.info/faq.html" "Mozilla/5.0 (X11; Linux i686; rv:1.9.6.20) Gecko/2011-04-30 23:25:27 Firefox/3.6.4" 65.96.205.50 - - [18/Oct/2018:02:12:34 +0000] "GET /app/main/posts HTTP/1.0" 200 5063 "http://brown.net/list/category/category/faq/" "Mozilla/5.0 (X11; Linux i686; rv:1.9.6.20) Gecko/2015-01-26 02:39:38 Firefox/3.8" 149.35.179.83 - - [18/Oct/2018:02:14:51 +0000] "DELETE /wp-content HTTP/1.0" 200 4893 "http://www.adams-perkins.com/home/" "Mozilla/5.0 (X11; Linux i686) AppleWebKit/5342 (KHTML, like Gecko) Chrome/14.0.825.0 Safari/5342" 125.6.78.177 - - [18/Oct/2018:02:17:05 +0000] "PUT /posts/posts/explore HTTP/1.0" 404 4994 "http://www.king.com/author.html" "Mozilla/5.0 (Windows NT 6.2) AppleWebKit/5320 (KHTML, like Gecko) Chrome/13.0.873.0 Safari/5320" 107.154.205.58 - - [18/Oct/2018:02:19:19 +0000] "GET /apps/cart.jsp?appID=8180 HTTP/1.0" 200 4996 "http://www.morgan.com/category/app/author.html" "Mozilla/5.0 (Windows; U; Windows NT 5.01) AppleWebKit/533.34.5 (KHTML, like Gecko) Version/4.0 Safari/533.34.5" 169.176.221.189 - - [18/Oct/2018:02:22:57 +0000] "GET /app/main/posts HTTP/1.0" 200 5015 "http://meadows.com/list/tag/app/post/" 
"Mozilla/5.0 (Windows NT 5.01) AppleWebKit/5322 (KHTML, like Gecko) Chrome/14.0.801.0 Safari/5322" 130.28.74.78 - - [18/Oct/2018:02:26:27 +0000] "POST /app/main/posts HTTP/1.0" 200 4963 "http://www.ashley-trujillo.info/author.html" "Mozilla/5.0 (Macintosh; U; PPC Mac OS X 10_5_5; rv:1.9.4.20) Gecko/2016-03-23 15:04:20 Firefox/5.0" 133.76.165.208 - - [18/Oct/2018:02:27:00 +0000] "GET /wp-admin HTTP/1.0" 200 4978 "http://www.frazier-schwartz.info/categories/app/app/post.jsp" "Mozilla/5.0 (Macintosh; U; PPC Mac OS X 10_7_9; rv:1.9.3.20) Gecko/2013-08-13 05:42:00 Firefox/14.0" 110.206.82.119 - - [18/Oct/2018:02:28:48 +0000] "GET /list HTTP/1.0" 301 4954 "http://www.mitchell.biz/author.php" "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10_5_6; rv:1.9.2.20) Gecko/2012-10-23 18:30:25 Firefox/3.8" 168.198.62.27 - - [18/Oct/2018:02:30:12 +0000] "GET /wp-content HTTP/1.0" 200 4979 "http://www.rodriguez.com/categories/post.php" "Mozilla/5.0 (Windows 98) AppleWebKit/5350 (KHTML, like Gecko) Chrome/13.0.884.0 Safari/5350" 14.117.101.228 - - [18/Oct/2018:02:31:00 +0000] "GET /wp-admin HTTP/1.0" 200 4952 "http://stein.info/main.php" "Mozilla/5.0 (Windows 98) AppleWebKit/5320 (KHTML, like Gecko) Chrome/14.0.873.0 Safari/5320" 124.225.54.86 - - [18/Oct/2018:02:31:48 +0000] "GET /app/main/posts HTTP/1.0" 200 4970 "http://kennedy.biz/category/app/posts/home.jsp" "Mozilla/5.0 (Windows NT 6.2; en-US; rv:1.9.1.20) Gecko/2012-05-31 15:55:28 Firefox/3.6.7" 155.191.142.109 - - [18/Oct/2018:02:36:19 +0000] "PUT /explore HTTP/1.0" 200 4982 "http://farmer.com/category/search.php" "Mozilla/5.0 (Windows 98; sl-SI; rv:1.9.2.20) Gecko/2015-10-31 05:03:45 Firefox/3.8" 113.111.34.186 - - [18/Oct/2018:02:39:34 +0000] "GET /explore HTTP/1.0" 200 4922 "http://www.powell.org/login/" "Mozilla/5.0 (Windows NT 5.0) AppleWebKit/5340 (KHTML, like Gecko) Chrome/14.0.857.0 Safari/5340" 4.87.205.98 - - [18/Oct/2018:02:44:29 +0000] "GET /wp-admin HTTP/1.0" 200 4982 "http://www.santos.com/" "Mozilla/5.0 (X11; Linux 
i686; rv:1.9.7.20) Gecko/2018-03-18 03:36:36 Firefox/12.0" 110.63.127.229 - - [18/Oct/2018:02:46:05 +0000] "POST /list HTTP/1.0" 200 4964 "http://good.com/about.html" "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_4; rv:1.9.3.20) Gecko/2015-05-23 23:12:44 Firefox/3.8" 243.18.176.206 - - [18/Oct/2018:02:48:21 +0000] "PUT /list HTTP/1.0" 200 5043 "http://stephens-baldwin.com/tags/register/" "Mozilla/5.0 (Windows NT 5.2; en-US; rv:1.9.2.20) Gecko/2011-12-23 01:29:18 Firefox/3.8" 168.225.235.180 - - [18/Oct/2018:02:50:31 +0000] "PUT /posts/posts/explore HTTP/1.0" 404 5043 "http://ortega.com/" "Mozilla/5.0 (Macintosh; PPC Mac OS X 10_8_3; rv:1.9.2.20) Gecko/2011-03-26 14:18:21 Firefox/7.0" 7.129.23.77 - - [18/Oct/2018:02:52:23 +0000] "GET /posts/posts/explore HTTP/1.0" 200 4999 "http://lewis-bruce.com/category/explore/home/" "Mozilla/5.0 (Windows 98) AppleWebKit/5332 (KHTML, like Gecko) Chrome/15.0.852.0 Safari/5332" 201.131.130.135 - - [18/Oct/2018:02:53:07 +0000] "PUT /wp-content HTTP/1.0" 200 4943 "http://williams.com/" "Mozilla/5.0 (Windows 98; Win 9x 4.90; sl-SI; rv:1.9.1.20) Gecko/2014-01-14 15:51:52 Firefox/4.0" 160.40.4.98 - - [18/Oct/2018:02:54:00 +0000] "POST /posts/posts/explore HTTP/1.0" 200 5094 "http://lucas-west.com/blog/category/privacy/" "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3; rv:1.9.5.20) Gecko/2018-01-10 05:44:15 Firefox/15.0" 233.25.14.15 - - [18/Oct/2018:02:57:06 +0000] "DELETE /wp-content HTTP/1.0" 404 4992 "http://www.adams-clayton.biz/tags/search/wp-content/search/" "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3 rv:5.0; en-US) AppleWebKit/532.47.2 (KHTML, like Gecko) Version/4.0.2 Safari/532.47.2" 71.39.28.180 - - [18/Oct/2018:03:00:12 +0000] "GET /list HTTP/1.0" 200 4975 "http://walker.com/" "Mozilla/5.0 (X11; Linux i686; rv:1.9.6.20) Gecko/2016-07-11 18:18:38 Firefox/3.8" 51.153.212.169 - - [18/Oct/2018:03:04:56 +0000] "GET /app/main/posts HTTP/1.0" 301 4933 "http://chavez.com/list/categories/posts/terms/" "Mozilla/5.0 (Windows 98; Win 9x 
4.90; sl-SI; rv:1.9.1.20) Gecko/2017-05-05 14:56:33 Firefox/3.6.11" 220.43.102.130 - - [18/Oct/2018:03:07:41 +0000] "DELETE /apps/cart.jsp?appID=6069 HTTP/1.0" 200 4936 "http://douglas.com/homepage/" "Mozilla/5.0 (Windows NT 5.01; en-US; rv:1.9.2.20) Gecko/2016-07-30 00:30:02 Firefox/3.6.8" 139.218.49.46 - - [18/Oct/2018:03:11:02 +0000] "PUT /apps/cart.jsp?appID=6207 HTTP/1.0" 200 5043 "http://perez.com/home/" "Mozilla/5.0 (Windows NT 5.01; en-US; rv:1.9.1.20) Gecko/2012-10-30 07:57:39 Firefox/11.0" 226.16.197.119 - - [18/Oct/2018:03:12:44 +0000] "GET /posts/posts/explore HTTP/1.0" 200 4979 "http://www.jones.com/main.asp" "Mozilla/5.0 (Macintosh; PPC Mac OS X 10_8_0; rv:1.9.6.20) Gecko/2011-01-20 10:45:37 Firefox/3.6.20" 40.119.68.10 - - [18/Oct/2018:03:14:45 +0000] "DELETE /posts/posts/explore HTTP/1.0" 200 4964 "http://www.ortega.com/author.htm" "Mozilla/5.0 (X11; Linux x86_64; rv:1.9.5.20) Gecko/2011-07-22 13:30:52 Firefox/3.6.15" 168.208.53.165 - - [18/Oct/2018:03:16:06 +0000] "GET /explore HTTP/1.0" 200 5014 "http://www.perez-miller.com/category/privacy.html" "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_0; rv:1.9.5.20) Gecko/2011-01-02 02:10:51 Firefox/3.8" 11.52.197.212 - - [18/Oct/2018:03:20:15 +0000] "DELETE /explore HTTP/1.0" 500 5042 "http://www.rubio.info/post/" "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10_7_5 rv:3.0; it-IT) AppleWebKit/535.43.7 (KHTML, like Gecko) Version/4.0.5 Safari/535.43.7" 21.7.60.251 - - [18/Oct/2018:03:22:55 +0000] "GET /list HTTP/1.0" 301 5072 "http://stevenson.org/index/" "Mozilla/5.0 (Macintosh; U; PPC Mac OS X 10_7_3; rv:1.9.6.20) Gecko/2012-04-25 16:25:18 Firefox/10.0" 20.111.53.239 - - [18/Oct/2018:03:27:08 +0000] "POST /apps/cart.jsp?appID=2130 HTTP/1.0" 200 4990 "http://www.hunt-raymond.com/main/search/" "Mozilla/5.0 (Windows 95; en-US; rv:1.9.0.20) Gecko/2016-01-26 09:23:42 Firefox/3.8" 163.155.239.245 - - [18/Oct/2018:03:27:45 +0000] "DELETE /posts/posts/explore HTTP/1.0" 200 4939 
"http://adkins.com/categories/posts/search/" "Mozilla/5.0 (Macintosh; U; PPC Mac OS X 10_7_1; rv:1.9.2.20) Gecko/2010-12-04 23:04:59 Firefox/3.8" 133.140.40.180 - - [18/Oct/2018:03:31:50 +0000] "GET /posts/posts/explore HTTP/1.0" 200 4979 "http://www.malone.com/tags/about.php" "Opera/8.30.(Windows CE; sl-SI) Presto/2.9.189 Version/10.00" 199.152.210.117 - - [18/Oct/2018:03:34:19 +0000] "GET /wp-admin HTTP/1.0" 200 4986 "http://www.camacho.com/explore/search/faq.php" "Mozilla/5.0 (X11; Linux i686; rv:1.9.7.20) Gecko/2018-09-05 13:26:24 Firefox/3.8" 180.22.162.153 - - [18/Oct/2018:03:38:22 +0000] "GET /apps/cart.jsp?appID=2363 HTTP/1.0" 200 4993 "http://dean-cherry.com/homepage.php" "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/5330 (KHTML, like Gecko) Chrome/14.0.850.0 Safari/5330" 67.123.236.154 - - [18/Oct/2018:03:40:26 +0000] "PUT /search/tag/list HTTP/1.0" 200 4972 "http://powers.com/home.jsp" "Mozilla/5.0 (Macintosh; U; PPC Mac OS X 10_5_3; rv:1.9.3.20) Gecko/2010-02-12 23:31:10 Firefox/3.6.5" 70.12.70.204 - - [18/Oct/2018:03:45:02 +0000] "DELETE /explore HTTP/1.0" 200 4937 "http://www.romero.com/tag/explore/main.html" "Mozilla/5.0 (X11; Linux i686; rv:1.9.6.20) Gecko/2016-05-05 13:52:25 Firefox/3.8" 85.226.90.231 - - [18/Oct/2018:03:49:37 +0000] "GET /app/main/posts HTTP/1.0" 200 4989 "http://www.bradley-bailey.com/faq.asp" "Mozilla/5.0 (X11; Linux i686) AppleWebKit/5351 (KHTML, like Gecko) Chrome/13.0.880.0 Safari/5351" 235.65.112.180 - - [18/Oct/2018:03:52:37 +0000] "GET /search/tag/list HTTP/1.0" 200 5030 "http://www.mendoza.com/tags/blog/tag/category.html" "Mozilla/5.0 (Windows CE; en-US; rv:1.9.0.20) Gecko/2011-08-04 17:20:08 Firefox/14.0" 158.47.154.156 - - [18/Oct/2018:03:54:48 +0000] "GET /list HTTP/1.0" 200 4998 "http://www.campos.com/search.php" "Mozilla/5.0 (compatible; MSIE 6.0; Windows CE; Trident/4.1)" 125.79.156.46 - - [18/Oct/2018:03:56:03 +0000] "GET /search/tag/list HTTP/1.0" 200 5010 "http://www.golden.com/post.htm" "Mozilla/5.0 (Macintosh; 
U; PPC Mac OS X 10_7_1; rv:1.9.2.20) Gecko/2014-10-07 01:16:43 Firefox/3.6.10" 184.232.250.128 - - [18/Oct/2018:04:00:14 +0000] "PUT /wp-content HTTP/1.0" 200 4984 "http://www.turner.info/" "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_6_4; rv:1.9.2.20) Gecko/2013-09-09 13:08:21 Firefox/15.0" 99.91.125.62 - - [18/Oct/2018:04:02:08 +0000] "PUT /posts/posts/explore HTTP/1.0" 200 5065 "http://www.leblanc.com/" "Mozilla/5.0 (Windows NT 6.2) AppleWebKit/5322 (KHTML, like Gecko) Chrome/15.0.851.0 Safari/5322" 44.235.108.106 - - [18/Oct/2018:04:05:59 +0000] "PUT /wp-content HTTP/1.0" 200 5103 "http://www.matthews.info/search/blog/main/category/" "Mozilla/5.0 (Windows NT 6.1) AppleWebKit/5350 (KHTML, like Gecko) Chrome/13.0.818.0 Safari/5350" 137.133.193.233 - - [18/Oct/2018:04:09:31 +0000] "GET /wp-admin HTTP/1.0" 200 5008 "http://curtis.com/" "Mozilla/5.0 (Windows NT 5.1; en-US; rv:1.9.1.20) Gecko/2017-08-29 05:05:21 Firefox/3.6.5" 123.45.94.23 - - [18/Oct/2018:04:10:46 +0000] "GET /posts/posts/explore HTTP/1.0" 200 5048 "http://www.white-miller.com/search.htm" "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10_5_8 rv:2.0; it-IT) AppleWebKit/531.39.6 (KHTML, like Gecko) Version/5.0.1 Safari/531.39.6" 45.199.49.213 - - [18/Oct/2018:04:11:37 +0000] "GET /wp-content HTTP/1.0" 200 5078 "http://www.shaffer.info/" "Mozilla/5.0 (Macintosh; U; PPC Mac OS X 10_6_0; rv:1.9.4.20) Gecko/2014-11-11 14:14:32 Firefox/3.6.4" 8.115.73.60 - - [18/Oct/2018:04:16:16 +0000] "GET /list HTTP/1.0" 200 5013 "http://moore.com/privacy.jsp" "Opera/9.90.(X11; Linux i686; en-US) Presto/2.9.190 Version/10.00" 8.36.203.85 - - [18/Oct/2018:04:20:02 +0000] "DELETE /wp-content HTTP/1.0" 200 4908 "http://martin.com/search/faq/" "Mozilla/5.0 (Macintosh; U; PPC Mac OS X 10_6_6; rv:1.9.4.20) Gecko/2012-12-19 23:52:21 Firefox/3.6.2" 204.50.213.48 - - [18/Oct/2018:04:20:55 +0000] "GET /search/tag/list HTTP/1.0" 200 5016 "http://russell.com/post.html" "Mozilla/5.0 (Macintosh; PPC Mac OS X 10_7_6 rv:6.0; en-US) 
AppleWebKit/535.19.1 (KHTML, like Gecko) Version/5.0 Safari/535.19.1" 201.165.240.2 - - [18/Oct/2018:04:23:44 +0000] "GET /search/tag/list HTTP/1.0" 200 4885 "http://www.reynolds-hunter.com/app/explore/home/" "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_5_1; rv:1.9.6.20) Gecko/2016-10-06 18:45:10 Firefox/3.6.11" 2.248.77.71 - - [18/Oct/2018:04:24:20 +0000] "GET /search/tag/list HTTP/1.0" 200 5014 "http://www.erickson.com/categories/posts/list/search/" "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10_5_1; rv:1.9.2.20) Gecko/2010-07-10 10:14:15 Firefox/3.8" 180.166.108.32 - - [18/Oct/2018:04:27:51 +0000] "GET /search/tag/list HTTP/1.0" 200 4991 "http://www.wall.info/privacy.jsp" "Mozilla/5.0 (Windows NT 4.0) AppleWebKit/5310 (KHTML, like Gecko) Chrome/14.0.800.0 Safari/5310" 112.68.225.108 - - [18/Oct/2018:04:29:45 +0000] "GET /apps/cart.jsp?appID=1353 HTTP/1.0" 404 5023 "http://blair-miller.com/" "Mozilla/5.0 (Macintosh; U; PPC Mac OS X 10_6_0; rv:1.9.4.20) Gecko/2011-10-19 08:56:36 Firefox/9.0" ================================================ FILE: OmsAgent/extension-test/omsfiles/custom.log ================================================ 2018-10-18 01:47:10 We need to rent a room for our party. 2018-10-18 01:47:10 Yeah, I think it's a good environment for learning English. 2018-10-18 01:47:10 Everyone was busy, so I went to the movie alone. 2018-10-18 01:47:10 A purple pig and a green donkey flew a kite in the middle of the night and ended up sunburnt. 2018-10-18 01:47:10 Yeah, I think it's a good environment for learning English. 2018-10-18 01:47:10 I am never at home on Sundays. 2018-10-18 01:47:10 There were white out conditions in the town; subsequently, the roads were impassable. 2018-10-18 01:47:10 I hear that Nancy is very pretty. 2018-10-18 01:47:10 There were white out conditions in the town; subsequently, the roads were impassable. 2018-10-18 01:47:10 I am counting my calories, yet I really want dessert. 
2018-10-18 01:47:10 A purple pig and a green donkey flew a kite in the middle of the night and ended up sunburnt. 2018-10-18 01:47:10 I am counting my calories, yet I really want dessert. 2018-10-18 01:47:10 Cats are good pets, for they are clean and are not noisy. 2018-10-18 01:47:10 Where do random thoughts come from? 2018-10-18 01:47:10 I am never at home on Sundays. 2018-10-18 01:47:10 The memory we used to share is no longer coherent. 2018-10-18 01:47:10 She folded her handkerchief neatly. 2018-10-18 01:47:10 The memory we used to share is no longer coherent. 2018-10-18 01:47:10 She was too short to see over the fence. 2018-10-18 01:47:10 Rock music approaches at high velocity. ================================================ FILE: OmsAgent/extension-test/omsfiles/customlog.conf ================================================ # This file is configured by the OMS service <source> type sudo_tail path /var/log/custom.log pos_file /var/opt/microsoft/omsagent/state/CUSTOM_LOG_BLOB.Custom_Log_CL_<workspace-id>.pos read_from_head true run_interval 60 tag oms.blob.CustomLog.CUSTOM_LOG_BLOB.Custom_Log_CL_<workspace-id>.* format none </source> ================================================ FILE: OmsAgent/extension-test/omsfiles/error.log ================================================ Version: '5.7.23-0ubuntu0.16.04.1-log' socket: '/var/run/mysqld/mysqld.sock' port: 3306 (Ubuntu) 2018-10-15T19:54:50.675213Z 2 [Note] Access denied for user 'root'@'localhost' (using password: NO) 2018-10-15T19:55:16.350986Z 0 [Note] Giving 7 client threads a chance to die gracefully 2018-10-15T19:55:16.351028Z 0 [Note] Shutting down slave threads 2018-10-15T19:55:18.351113Z 0 [Note] Forcefully disconnecting 4 remaining clients 2018-10-15T19:55:18.351141Z 0 [Warning] /usr/sbin/mysqld: Forcing close of thread 3 user: 'zabbix' 2018-10-15T19:55:18.351160Z 0 [Warning] /usr/sbin/mysqld: Forcing close of thread 4 user: 'zabbix' 2018-10-15T19:55:18.351186Z 0 [Warning] /usr/sbin/mysqld: 
Forcing close of thread 5 user: 'zabbix' 2018-10-15T19:55:18.351193Z 0 [Warning] /usr/sbin/mysqld: Forcing close of thread 9 user: 'zabbix' 2018-10-15T19:55:18.351205Z 0 [Note] Event Scheduler: Purging the queue. 0 events 2018-10-15T19:55:18.351518Z 0 [Note] Binlog end 2018-10-15T19:55:18.352972Z 0 [Note] Shutting down plugin 'ngram' 2018-10-15T19:55:18.352983Z 0 [Note] Shutting down plugin 'partition' 2018-10-15T19:55:18.352986Z 0 [Note] Shutting down plugin 'BLACKHOLE' 2018-10-15T19:55:18.352989Z 0 [Note] Shutting down plugin 'ARCHIVE' 2018-10-15T19:55:18.353039Z 0 [Note] Shutting down plugin 'MEMORY' 2018-10-15T19:55:18.353044Z 0 [Note] Shutting down plugin 'INNODB_SYS_VIRTUAL' 2018-10-15T19:55:18.353048Z 0 [Note] Shutting down plugin 'INNODB_SYS_DATAFILES' 2018-10-15T19:55:18.353051Z 0 [Note] Shutting down plugin 'INNODB_SYS_TABLESPACES' 2018-10-15T19:55:18.353053Z 0 [Note] Shutting down plugin 'INNODB_SYS_FOREIGN_COLS' 2018-10-15T19:55:18.353056Z 0 [Note] Shutting down plugin 'INNODB_SYS_FOREIGN' 2018-10-15T19:55:18.353059Z 0 [Note] Shutting down plugin 'INNODB_SYS_FIELDS' 2018-10-15T19:55:18.353062Z 0 [Note] Shutting down plugin 'INNODB_SYS_COLUMNS' 2018-10-15T19:55:18.353065Z 0 [Note] Shutting down plugin 'INNODB_SYS_INDEXES' 2018-10-15T19:55:18.353068Z 0 [Note] Shutting down plugin 'INNODB_SYS_TABLESTATS' 2018-10-15T19:55:18.353071Z 0 [Note] Shutting down plugin 'INNODB_SYS_TABLES' 2018-10-15T19:55:18.353073Z 0 [Note] Shutting down plugin 'INNODB_FT_INDEX_TABLE' 2018-10-15T19:55:18.353076Z 0 [Note] Shutting down plugin 'INNODB_FT_INDEX_CACHE' 2018-10-15T19:55:18.353079Z 0 [Note] Shutting down plugin 'INNODB_FT_CONFIG' 2018-10-15T19:55:18.353082Z 0 [Note] Shutting down plugin 'INNODB_FT_BEING_DELETED' 2018-10-15T19:55:18.353085Z 0 [Note] Shutting down plugin 'INNODB_FT_DELETED' 2018-10-15T19:55:18.353087Z 0 [Note] Shutting down plugin 'INNODB_FT_DEFAULT_STOPWORD' 2018-10-15T19:55:18.353091Z 0 [Note] Shutting down plugin 'INNODB_METRICS' 
2018-10-15T19:55:18.353094Z 0 [Note] Shutting down plugin 'INNODB_TEMP_TABLE_INFO' 2018-10-15T19:55:18.353097Z 0 [Note] Shutting down plugin 'INNODB_BUFFER_POOL_STATS' 2018-10-15T19:55:18.353099Z 0 [Note] Shutting down plugin 'INNODB_BUFFER_PAGE_LRU' 2018-10-15T19:55:18.353102Z 0 [Note] Shutting down plugin 'INNODB_BUFFER_PAGE' 2018-10-15T19:55:18.353105Z 0 [Note] Shutting down plugin 'INNODB_CMP_PER_INDEX_RESET' 2018-10-15T19:55:18.353108Z 0 [Note] Shutting down plugin 'INNODB_CMP_PER_INDEX' 2018-10-15T19:55:18.353111Z 0 [Note] Shutting down plugin 'INNODB_CMPMEM_RESET' 2018-10-15T19:55:18.353113Z 0 [Note] Shutting down plugin 'INNODB_CMPMEM' 2018-10-15T19:55:18.353116Z 0 [Note] Shutting down plugin 'INNODB_CMP_RESET' 2018-10-15T19:55:18.353119Z 0 [Note] Shutting down plugin 'INNODB_CMP' 2018-10-15T19:55:18.353122Z 0 [Note] Shutting down plugin 'INNODB_LOCK_WAITS' 2018-10-15T19:55:18.353125Z 0 [Note] Shutting down plugin 'INNODB_LOCKS' 2018-10-15T19:55:18.353128Z 0 [Note] Shutting down plugin 'INNODB_TRX' 2018-10-15T19:55:18.353131Z 0 [Note] Shutting down plugin 'InnoDB' 2018-10-15T19:55:18.354768Z 0 [Note] InnoDB: FTS optimize thread exiting. 2018-10-15T19:55:18.354950Z 0 [Note] InnoDB: Starting shutdown... 
2018-10-15T19:55:18.455235Z 0 [Note] InnoDB: Dumping buffer pool(s) to /var/lib/mysql/ib_buffer_pool 2018-10-15T19:55:18.455408Z 0 [Note] InnoDB: Buffer pool(s) dump completed at 181015 19:55:18 2018-10-15T19:55:19.782226Z 0 [Note] InnoDB: Shutdown completed; log sequence number 15199598 2018-10-15T19:55:19.784179Z 0 [Note] InnoDB: Removed temporary tablespace data file: "ibtmp1" 2018-10-15T19:55:19.784526Z 0 [Note] Shutting down plugin 'MRG_MYISAM' 2018-10-15T19:55:19.784782Z 0 [Note] Shutting down plugin 'MyISAM' 2018-10-15T19:55:19.785078Z 0 [Note] Shutting down plugin 'CSV' 2018-10-15T19:55:19.785417Z 0 [Note] Shutting down plugin 'PERFORMANCE_SCHEMA' 2018-10-15T19:55:19.785694Z 0 [Note] Shutting down plugin 'sha256_password' 2018-10-15T19:55:19.785937Z 0 [Note] Shutting down plugin 'mysql_native_password' 2018-10-15T19:55:19.786432Z 0 [Note] Shutting down plugin 'binlog' 2018-10-15T19:55:19.796160Z 0 [Note] /usr/sbin/mysqld: Shutdown complete 2018-10-15T19:55:19.858798Z 0 [Warning] Changed limits: max_open_files: 1024 (requested 5000) 2018-10-15T19:55:19.858828Z 0 [Warning] Changed limits: table_open_cache: 431 (requested 2000) 2018-10-15T19:55:20.017175Z 0 [Warning] TIMESTAMP with implicit DEFAULT value is deprecated. Please use --explicit_defaults_for_timestamp server option (see documentation for more details). 2018-10-15T19:55:20.019931Z 0 [Note] /usr/sbin/mysqld (mysqld 5.7.23-0ubuntu0.16.04.1-log) starting as process 12803 ... 
2018-10-15T19:55:20.025911Z 0 [Note] InnoDB: PUNCH HOLE support available 2018-10-15T19:55:20.026301Z 0 [Note] InnoDB: Mutexes and rw_locks use GCC atomic builtins 2018-10-15T19:55:20.026621Z 0 [Note] InnoDB: Uses event mutexes 2018-10-15T19:55:20.026886Z 0 [Note] InnoDB: GCC builtin __atomic_thread_fence() is used for memory barrier 2018-10-15T19:55:20.027229Z 0 [Note] InnoDB: Compressed tables use zlib 1.2.8 2018-10-15T19:55:20.027559Z 0 [Note] InnoDB: Using Linux native AIO 2018-10-15T19:55:20.028218Z 0 [Note] InnoDB: Number of pools: 1 2018-10-15T19:55:20.028662Z 0 [Note] InnoDB: Using CPU crc32 instructions 2018-10-15T19:55:20.031387Z 0 [Note] InnoDB: Initializing buffer pool, total size = 128M, instances = 1, chunk size = 128M 2018-10-15T19:55:20.039710Z 0 [Note] InnoDB: Completed initialization of buffer pool 2018-10-15T19:55:20.042206Z 0 [Note] InnoDB: If the mysqld execution user is authorized, page cleaner thread priority can be changed. See the man page of setpriority(). 2018-10-15T19:55:20.054579Z 0 [Note] InnoDB: Highest supported file format is Barracuda. 2018-10-15T19:55:20.111399Z 0 [Note] InnoDB: Creating shared tablespace for temporary tables 2018-10-15T19:55:20.111920Z 0 [Note] InnoDB: Setting file './ibtmp1' size to 12 MB. Physically writing the file full; Please wait ... 2018-10-15T19:55:21.176656Z 0 [Note] InnoDB: File './ibtmp1' size is now 12 MB. 2018-10-15T19:55:21.177860Z 0 [Note] InnoDB: 96 redo rollback segment(s) found. 96 redo rollback segment(s) are active. 2018-10-15T19:55:21.178348Z 0 [Note] InnoDB: 32 non-redo rollback segment(s) are active. 2018-10-15T19:55:21.180824Z 0 [Note] InnoDB: Waiting for purge to start 2018-10-15T19:55:21.231391Z 0 [Note] InnoDB: 5.7.23 started; log sequence number 15199598 2018-10-15T19:55:21.232458Z 0 [Note] Plugin 'FEDERATED' is disabled. 
2018-10-15T19:55:21.236600Z 0 [Note] InnoDB: Loading buffer pool(s) from /var/lib/mysql/ib_buffer_pool 2018-10-15T19:55:21.242685Z 0 [Warning] Failed to set up SSL because of the following SSL library error: SSL context is not usable without certificate and private key 2018-10-15T19:55:21.243032Z 0 [Note] Server hostname (bind-address): '127.0.0.1'; port: 3306 2018-10-15T19:55:21.243445Z 0 [Note] - '127.0.0.1' resolves to '127.0.0.1'; 2018-10-15T19:55:21.243731Z 0 [Note] Server socket created on IP: '127.0.0.1'. 2018-10-15T19:55:21.249632Z 0 [Note] InnoDB: Buffer pool(s) load completed at 181015 19:55:21 2018-10-15T19:55:21.254564Z 0 [Note] Event Scheduler: Loaded 0 events 2018-10-15T19:55:21.255500Z 0 [Note] /usr/sbin/mysqld: ready for connections. Version: '5.7.23-0ubuntu0.16.04.1-log' socket: '/var/run/mysqld/mysqld.sock' port: 3306 (Ubuntu) 2018-10-15T19:55:21.856724Z 2 [Note] Access denied for user 'root'@'localhost' (using password: NO) ================================================ FILE: OmsAgent/extension-test/omsfiles/mysql-slow.log ================================================ /usr/sbin/mysqld, Version: 5.7.23-0ubuntu0.16.04.1-log ((Ubuntu)). 
started with: Tcp port: 3306 Unix socket: /var/run/mysqld/mysqld.sock Time Id Command Argument # Time: 2018-10-15T19:56:12.584806Z # User@Host: root[root] @ localhost [] Id: 8 # Query_time: 10.000575 Lock_time: 0.000000 Rows_sent: 1 Rows_examined: 0 use test; SET timestamp=1539633372; select sleep(10); # Time: 2018-10-15T19:56:36.398796Z # User@Host: root[root] @ localhost [] Id: 8 # Query_time: 10.000546 Lock_time: 0.000000 Rows_sent: 1 Rows_examined: 0 SET timestamp=1539633396; select sleep(10); ================================================ FILE: OmsAgent/extension-test/omsfiles/mysql.log ================================================ 2018-10-15T18:38:50.315829Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:38:50.863314Z 21 Query select taskid,type,clock,ttl from task where status in (1,2) order by taskid 2018-10-15T18:38:50.863721Z 16 Query select h.hostid,h.host,h.name,t.httptestid,t.name,t.agent,t.authentication,t.http_user,t.http_password,t.http_proxy,t.retries,t.ssl_cert_file,t.ssl_key_file,t.ssl_key_password,t.verify_peer,t.verify_host,t.delay from httptest t,hosts h where t.hostid=h.hostid and t.nextcheck<=1539628730 and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:38:50.864137Z 16 Query select min(t.nextcheck) from httptest t,hosts h where t.hostid=h.hostid and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:38:51.317472Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 
2018-10-15T18:38:52.318953Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:38:52.947568Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:38:52.947842Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where itemid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:38:52.947973Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is null and itemid is null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:38:53.320549Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:38:54.322093Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:38:55.323700Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:38:55.864344Z 21 Query select taskid,type,clock,ttl from task where status in (1,2) order by taskid 2018-10-15T18:38:55.864883Z 16 Query select 
h.hostid,h.host,h.name,t.httptestid,t.name,t.agent,t.authentication,t.http_user,t.http_password,t.http_proxy,t.retries,t.ssl_cert_file,t.ssl_key_file,t.ssl_key_password,t.verify_peer,t.verify_host,t.delay from httptest t,hosts h where t.hostid=h.hostid and t.nextcheck<=1539628735 and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:38:55.865129Z 16 Query select min(t.nextcheck) from httptest t,hosts h where t.hostid=h.hostid and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:38:55.948479Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:38:55.948664Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where itemid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:38:55.948814Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is null and itemid is null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:38:56.325397Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:38:57.326948Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:38:58.328609Z 33 Query select 
a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:38:58.949362Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:38:58.949594Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where itemid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:38:58.949700Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is null and itemid is null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:38:59.330164Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:00.331671Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:00.847135Z 17 Query select m.maintenanceid,m.maintenance_type,m.active_since,tp.timeperiod_type,tp.every,tp.month,tp.dayofweek,tp.day,tp.start_time,tp.period,tp.start_date from maintenances m,maintenances_windows mw,timeperiods tp where m.maintenanceid=mw.maintenanceid and mw.timeperiodid=tp.timeperiodid and m.active_since<=1539628740 and m.active_till>1539628740 2018-10-15T18:39:00.847466Z 17 Query begin 2018-10-15T18:39:00.847520Z 17 Query select hostid,host,maintenance_type,maintenance_from from hosts where status=0 and flags<>2 and 
maintenance_status=1 2018-10-15T18:39:00.847687Z 17 Query commit 2018-10-15T18:39:00.865215Z 21 Query select taskid,type,clock,ttl from task where status in (1,2) order by taskid 2018-10-15T18:39:00.865899Z 16 Query select h.hostid,h.host,h.name,t.httptestid,t.name,t.agent,t.authentication,t.http_user,t.http_password,t.http_proxy,t.retries,t.ssl_cert_file,t.ssl_key_file,t.ssl_key_password,t.verify_peer,t.verify_host,t.delay from httptest t,hosts h where t.hostid=h.hostid and t.nextcheck<=1539628740 and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:39:00.866166Z 16 Query select min(t.nextcheck) from httptest t,hosts h where t.hostid=h.hostid and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:39:01.333227Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:01.950246Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:01.950520Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where itemid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:01.950663Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is null and itemid is null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:02.334772Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid 
from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:03.336170Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:04.338531Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:04.951288Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:04.951515Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where itemid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:04.951615Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is null and itemid is null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:05.340207Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:05.866136Z 21 Query select taskid,type,clock,ttl from task where status in (1,2) order by taskid 2018-10-15T18:39:05.866969Z 16 Query select h.hostid,h.host,h.name,t.httptestid,t.name,t.agent,t.authentication,t.http_user,t.http_password,t.http_proxy,t.retries,t.ssl_cert_file,t.ssl_key_file,t.ssl_key_password,t.verify_peer,t.verify_host,t.delay from httptest t,hosts h where t.hostid=h.hostid and 
t.nextcheck<=1539628745 and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:39:05.867202Z 16 Query select min(t.nextcheck) from httptest t,hosts h where t.hostid=h.hostid and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:39:06.341692Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:07.343211Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:07.952124Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:07.952332Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where itemid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:07.952436Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is null and itemid is null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:08.344797Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:09.346291Z 33 Query select 
a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:10.347130Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:10.867114Z 21 Query select taskid,type,clock,ttl from task where status in (1,2) order by taskid 2018-10-15T18:39:10.867965Z 16 Query select h.hostid,h.host,h.name,t.httptestid,t.name,t.agent,t.authentication,t.http_user,t.http_password,t.http_proxy,t.retries,t.ssl_cert_file,t.ssl_key_file,t.ssl_key_password,t.verify_peer,t.verify_host,t.delay from httptest t,hosts h where t.hostid=h.hostid and t.nextcheck<=1539628750 and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:39:10.868197Z 16 Query select min(t.nextcheck) from httptest t,hosts h where t.hostid=h.hostid and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:39:10.953013Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:10.953242Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where itemid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:10.953370Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is null and itemid is null order by actionid,triggerid,itemid,escalationid 
2018-10-15T18:39:11.348383Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:12.348909Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:13.349494Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:13.953982Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:13.954400Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where itemid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:13.954532Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is null and itemid is null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:14.350097Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:14.444846Z 12 Query DROP DATABASE test 2018-10-15T18:39:14.576240Z 12 Query SELECT DATABASE() 2018-10-15T18:39:15.310599Z 12 Query CREATE DATABASE test 2018-10-15T18:39:15.350757Z 33 Query select 
a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:15.868143Z 21 Query select taskid,type,clock,ttl from task where status in (1,2) order by taskid 2018-10-15T18:39:15.869085Z 16 Query select h.hostid,h.host,h.name,t.httptestid,t.name,t.agent,t.authentication,t.http_user,t.http_password,t.http_proxy,t.retries,t.ssl_cert_file,t.ssl_key_file,t.ssl_key_password,t.verify_peer,t.verify_host,t.delay from httptest t,hosts h where t.hostid=h.hostid and t.nextcheck<=1539628755 and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:39:15.869995Z 16 Query select min(t.nextcheck) from httptest t,hosts h where t.hostid=h.hostid and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:39:16.352828Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:16.683314Z 12 Query SELECT DATABASE() 2018-10-15T18:39:16.684005Z 12 Init DB test 2018-10-15T18:39:16.685263Z 12 Query show databases 2018-10-15T18:39:16.685810Z 12 Query show tables 2018-10-15T18:39:16.955061Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:16.955820Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where itemid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:16.956073Z 28 Query select 
escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is null and itemid is null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:17.355676Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:17.769543Z 12 Query CREATE TABLE IF NOT EXISTS data (id INT AUTO_INCREMENT, title VARCHAR(255) NOT NULL, description TEXT, PRIMARY KEY (id)) 2018-10-15T18:39:18.357242Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:19.358825Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:19.956729Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:19.957576Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where itemid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:19.958050Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is null and itemid is null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:20.360434Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left 
join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:20.870503Z 21 Query select taskid,type,clock,ttl from task where status in (1,2) order by taskid 2018-10-15T18:39:20.871246Z 16 Query select h.hostid,h.host,h.name,t.httptestid,t.name,t.agent,t.authentication,t.http_user,t.http_password,t.http_proxy,t.retries,t.ssl_cert_file,t.ssl_key_file,t.ssl_key_password,t.verify_peer,t.verify_host,t.delay from httptest t,hosts h where t.hostid=h.hostid and t.nextcheck<=1539628760 and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:39:20.871544Z 16 Query select min(t.nextcheck) from httptest t,hosts h where t.hostid=h.hostid and mod(t.httptestid,1)=0 and t.status=0 and h.proxy_hostid is null and h.status=0 and (h.maintenance_status=0 or h.maintenance_type=0) 2018-10-15T18:39:21.361831Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:22.363301Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid 2018-10-15T18:39:22.958623Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:22.958924Z 28 Query select escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where itemid is not null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:22.959097Z 28 Query select 
escalationid,actionid,triggerid,eventid,r_eventid,nextcheck,esc_step,status,itemid,acknowledgeid from escalations where triggerid is null and itemid is null order by actionid,triggerid,itemid,escalationid 2018-10-15T18:39:23.364725Z 33 Query select a.alertid,a.mediatypeid,a.sendto,a.subject,a.message,a.status,a.retries,e.source,e.object,e.objectid from alerts a left join events e on a.eventid=e.eventid where alerttype=0 and a.status=3 order by a.alertid ================================================ FILE: OmsAgent/extension-test/omsfiles/oms_extension_run_script.py ================================================ import datetime import os import os.path import platform import re import subprocess import sys import time if "check_output" not in dir(subprocess): # duck punch it in! def check_output(*popenargs, **kwargs): r"""Run command with arguments and return its output as a byte string. Backported from Python 2.7 as it's implemented as pure python on stdlib. >>> check_output(['/usr/bin/python', '--version']) Python 2.6.2 """ process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) output, unused_err = process.communicate() retcode = process.poll() if retcode: cmd = kwargs.get("args") if cmd is None: cmd = popenargs[0] error = subprocess.CalledProcessError(retcode, cmd) error.logput = output raise error return output subprocess.check_output = check_output # Create directory and copy files if not os.path.isdir('/home/scratch/'): os.system('mkdir /home/scratch/ \ && cp /tmp/*.py /home/scratch/ \ && cp /tmp/*.log /home/scratch/ \ && cp /tmp/*.conf /home/scratch/') out_file = '/home/scratch/omsresults.log' open_file = open(out_file, 'w+') def main(): # Determine the operation being executed vm_supported, vm_dist, vm_ver = is_vm_supported_for_extension() linux_detect_installer() if len(sys.argv) == 2: option = sys.argv[1] if re.match('^([-/]*)(preinstall)', option): install_additional_packages() elif re.match('^([-/]*)(postinstall)', option): 
detect_workspace_id() config_start_oms_services() restart_services() result_commands() service_control_commands() write_html() dist_status() elif re.match('^([-/]*)(status)', option): result_commands() service_control_commands() write_html() dist_status() elif re.match('^([-/]*)(injectlogs)', option): time.sleep(120) inject_logs() elif re.match('^([-/]*)(copyomslogs)', option): detect_workspace_id() copy_oms_logs() elif re.match('^([-/]*)(copyextlogs)', option): copy_extension_logs() else: print("No operation specified. run with 'preinstall' or 'postinstall' or 'status' or 'copyextlogs'") def is_vm_supported_for_extension(): global vm_supported, vm_dist, vm_ver supported_dists = {'redhat' : ['6', '7'], # CentOS 'centos' : ['6', '7'], # CentOS 'red hat' : ['6', '7'], # Oracle, RHEL 'oracle' : ['6', '7'], # Oracle 'debian' : ['8', '9'], # Debian 'ubuntu' : ['14.04', '16.04', '18.04'], # Ubuntu 'suse' : ['12'], 'sles' : ['15']} # SLES try: vm_dist, vm_ver, vm_id = platform.linux_distribution() except AttributeError: vm_dist, vm_ver, vm_id = platform.dist() if not vm_dist and not vm_ver: # SLES 15 with open('/etc/os-release', 'r') as fp: for line in fp: if line.startswith('ID='): vm_dist = line.split('=')[1] vm_dist = vm_dist.split('-')[0] vm_dist = vm_dist.replace('\"', '').replace('\n', '') elif line.startswith('VERSION_ID='): vm_ver = line.split('=')[1] vm_ver = vm_ver.split('.')[0] vm_ver = vm_ver.replace('\"', '').replace('\n', '') vm_supported = False # Find this VM distribution in the supported list for supported_dist in supported_dists.keys(): if not vm_dist.lower().startswith(supported_dist): continue # Check if this VM distribution version is supported vm_ver_split = vm_ver.split('.') for supported_ver in supported_dists[supported_dist]: supported_ver_split = supported_ver.split('.') vm_ver_match = True for idx, supported_ver_num in enumerate(supported_ver_split): try: supported_ver_num = int(supported_ver_num) vm_ver_num = int(vm_ver_split[idx]) except 
IndexError: vm_ver_match = False break if vm_ver_num is not supported_ver_num: vm_ver_match = False break if vm_ver_match: vm_supported = True break if vm_supported: break return vm_supported, vm_dist, vm_ver def replace_items(infile, old_word, new_word): """Replace old_word with new_world in file infile.""" if not os.path.isfile(infile): print("Error on replace_word, not a regular file: " + infile) sys.exit(1) f1 = open(infile, 'r').read() f2 = open(infile, 'w') m = f1.replace(old_word, new_word) f2.write(m) def detect_workspace_id(): """Detect the workspace id where the agent is onboarded.""" global workspace_id x = subprocess.check_output('/opt/microsoft/omsagent/bin/omsadmin.sh -l', shell=True) try: workspace_id = re.search('[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}', x).group(0) except AttributeError: workspace_id = None def linux_detect_installer(): """Check what installer (dpkg or rpm) should be used.""" global INSTALLER INSTALLER = None if vm_supported and (vm_dist.startswith('Ubuntu') or vm_dist.startswith('debian')): INSTALLER = 'APT' elif vm_supported and (vm_dist.startswith('CentOS') or vm_dist.startswith('Oracle') or vm_dist.startswith('Red Hat')): INSTALLER = 'YUM' elif vm_supported and vm_dist.startswith('SUSE Linux'): INSTALLER = 'ZYPPER' def install_additional_packages(): """Install additional packages command here.""" if INSTALLER == 'APT': os.system('apt-get -y update && apt-get -y install wget apache2 git \ && service apache2 start') elif INSTALLER == 'YUM': os.system('yum update -y && yum install -y wget httpd git \ && service httpd start') elif INSTALLER == 'ZYPPER': os.system('zypper update -y && zypper install -y wget httpd git \ && service apache2 start') def enable_dsc(): """Enable DSC""" os.system('/opt/microsoft/omsconfig/Scripts/OMS_MetaConfigHelper.py --enable') def disable_dsc(): """Disable DSC""" os.system('/opt/microsoft/omsconfig/Scripts/OMS_MetaConfigHelper.py --disable') Pending_mof = 
'/etc/opt/omi/conf/omsconfig/configuration/Pending.mof' Current_mof = '/etc/opt/omi/conf/omsconfig/configuration/Pending.mof' if os.path.isfile(Pending_mof) or os.path.isfile(Current_mof): os.remove(Pending_mof) os.remove(Current_mof) def copy_config_files(): """Convert, copy, and set permissions for agent configuration files.""" os.system('cat /home/scratch/perf.conf >> /etc/opt/microsoft/omsagent/{0}/conf/omsagent.conf \ && cp /home/scratch/rsyslog-oms.conf /etc/opt/omi/conf/omsconfig/rsyslog-oms.conf \ && cp /home/scratch/rsyslog-oms.conf /etc/rsyslog.d/95-omsagent.conf \ && chown omsagent:omiusers /etc/rsyslog.d/95-omsagent.conf \ && chmod 644 /etc/rsyslog.d/95-omsagent.conf \ && cp /home/scratch/customlog.conf /etc/opt/microsoft/omsagent/{0}/conf/omsagent.d/customlog.conf \ && chown omsagent:omiusers /etc/opt/microsoft/omsagent/{0}/conf/omsagent.d/customlog.conf \ && cp /etc/opt/microsoft/omsagent/sysconf/omsagent.d/apache_logs.conf /etc/opt/microsoft/omsagent/{0}/conf/omsagent.d/apache_logs.conf \ && chown omsagent:omiusers /etc/opt/microsoft/omsagent/{0}/conf/omsagent.d/apache_logs.conf \ && cp /etc/opt/microsoft/omsagent/sysconf/omsagent.d/mysql_logs.conf /etc/opt/microsoft/omsagent/{0}/conf/omsagent.d/mysql_logs.conf \ && chown omsagent:omiusers /etc/opt/microsoft/omsagent/{0}/conf/omsagent.d/mysql_logs.conf'.format(workspace_id)) replace_items('/etc/opt/microsoft/omsagent/{0}/conf/omsagent.conf'.format(workspace_id), '<workspace-id>', workspace_id) replace_items('/etc/opt/microsoft/omsagent/{0}/conf/omsagent.d/customlog.conf'.format(workspace_id), '<workspace-id>', workspace_id) def apache_mysql_conf(): """Configure Apache and MySQL, set up empty log files, and add permissions.""" apache_conf_file = '/etc/opt/microsoft/omsagent/{0}/conf/omsagent.d/apache_logs.conf'.format(workspace_id) mysql_conf_file = '/etc/opt/microsoft/omsagent/{0}/conf/omsagent.d/mysql_logs.conf'.format(workspace_id) apache_access_conf_path_string = 
'/usr/local/apache2/logs/access_log /var/log/apache2/access.log /var/log/httpd/access_log /var/log/apache2/access_log' apache_error_conf_path_string = '/usr/local/apache2/logs/error_log /var/log/apache2/error.log /var/log/httpd/error_log /var/log/apache2/error_log' os.system('chown omsagent:omiusers {0}'.format(apache_conf_file)) os.system('chown omsagent:omiusers {0}'.format(mysql_conf_file)) os.system('mkdir -p /var/log/mysql \ && touch /var/log/mysql/mysql.log /var/log/mysql/error.log /var/log/mysql/mysql-slow.log \ && touch /var/log/custom.log \ && chmod +r /var/log/mysql/* \ && chmod +rx /var/log/mysql \ && chmod +r /var/log/custom.log') if INSTALLER == 'APT': replace_items(apache_conf_file, apache_access_conf_path_string, '/var/log/apache2/access.log') replace_items(apache_conf_file, apache_error_conf_path_string, '/var/log/apache2/error.log') os.system('mkdir -p /var/log/apache2 \ && touch /var/log/apache2/access.log /var/log/apache2/error.log \ && chmod +r /var/log/apache2/* \ && chmod +rx /var/log/apache2') elif INSTALLER == 'YUM': replace_items(apache_conf_file, apache_access_conf_path_string, '/var/log/httpd/access_log') replace_items(apache_conf_file, apache_error_conf_path_string, '/var/log/httpd/error_log') os.system('mkdir -p /var/log/httpd \ && touch /var/log/httpd/access_log /var/log/httpd/error_log \ && chmod +r /var/log/httpd/* \ && chmod +rx /var/log/httpd') elif INSTALLER == 'ZYPPER': replace_items(apache_conf_file, apache_access_conf_path_string, '/var/log/apache2/access_log') replace_items(apache_conf_file, apache_error_conf_path_string, '/var/log/apache2/error_log') os.system('mkdir -p /var/log/apache2 \ && touch /var/log/apache2/access_log /var/log/apache2/error_log \ && chmod +r /var/log/apache2/* \ && chmod +rx /var/log/apache2') def inject_logs(): """Inject logs (after) agent is running in order to simulate real Apache/MySQL/Custom logs output.""" # set apache timestamps to current time to ensure they are searchable with 1 hour period in 
log analytics now = datetime.datetime.utcnow().strftime('[%d/%b/%Y:%H:%M:%S +0000]') os.system(r"sed -i 's|\(\[.*\]\)|{0}|' /home/scratch/apache_access.log".format(now)) if INSTALLER == 'APT': os.system('cat /home/scratch/apache_access.log >> /var/log/apache2/access.log \ && chown root:root /var/log/apache2/access.log \ && chmod 644 /var/log/apache2/access.log') elif INSTALLER == 'YUM': os.system('cat /home/scratch/apache_access.log >> /var/log/httpd/access_log \ && chown root:root /var/log/httpd/access_log \ && chmod 644 /var/log/httpd/access_log') elif INSTALLER == 'ZYPPER': os.system('cat /home/scratch/apache_access.log >> /var/log/apache2/access_log \ && chown root:root /var/log/apache2/access_log \ && chmod 644 /var/log/apache2/access_log') os.system('cat /home/scratch/mysql.log >> /var/log/mysql/mysql.log \ && cat /home/scratch/error.log >> /var/log/mysql/error.log \ && cat /home/scratch/mysql-slow.log >> /var/log/mysql/mysql-slow.log \ && cat /home/scratch/custom.log >> /var/log/custom.log') def config_start_oms_services(): """Orchestrate overall configuration prior to agent start.""" os.system('/opt/omi/bin/omiserver -d') disable_dsc() copy_config_files() apache_mysql_conf() def restart_services(): """Restart rsyslog, OMI, and OMS.""" time.sleep(10) os.system('service rsyslog restart \ && /opt/omi/bin/service_control restart \ && /opt/microsoft/omsagent/bin/service_control restart') def append_file(filename, destFile): f = open(filename, 'r') destFile.write(f.read()) f.close() def exec_command(cmd): """Run the provided command, check, and return its output.""" try: out = subprocess.check_output(cmd, shell=True) return out except subprocess.CalledProcessError as e: print(e.returncode) return e.returncode def write_log_output(log, out): """Save command output to the log file.""" if(type(out) != str): out = str(out) log.write(out + '\n') log.write('-' * 80) log.write('\n') def write_log_command(log, cmd): """Print command and save command to log file.""" 
print(cmd) log.write(cmd + '\n') log.write('=' * 40) log.write('\n') def check_pkg_status(pkg): """Check pkg install status and return output and derived status.""" if INSTALLER == 'APT': cmd = 'dpkg -s {0}'.format(pkg) output = exec_command(cmd) if (os.system('{0} | grep deinstall > /dev/null 2>&1'.format(cmd)) == 0 or os.system('dpkg -s omsagent > /dev/null 2>&1') != 0): status = 'Not Installed' else: status = 'Install Ok' elif INSTALLER == 'YUM' or INSTALLER == 'ZYPPER': cmd = 'rpm -qi {0}'.format(pkg) output = exec_command(cmd) if os.system('{0} > /dev/null 2>&1'.format(cmd)) == 0: status = 'Install Ok' else: status = 'Not Installed' write_log_command(open_file, cmd) write_log_output(open_file, output) return (output, status) def result_commands(): """Determine and store status of agent.""" global waagentOut, onboardStatus, omiRunStatus, psefomsagent, omsagentRestart, omiRestart global omiInstallOut, omsagentInstallOut, omsconfigInstallOut, scxInstallOut, omiInstallStatus, omsagentInstallStatus, omsconfigInstallStatus, scxInstallStatus cmd = 'waagent --version' waagentOut = exec_command(cmd) write_log_command(open_file, cmd) write_log_output(open_file, waagentOut) cmd = '/opt/microsoft/omsagent/bin/omsadmin.sh -l' onboardStatus = exec_command(cmd) write_log_command(open_file, cmd) write_log_output(open_file, onboardStatus) cmd = 'scxadmin -status' omiRunStatus = exec_command(cmd) write_log_command(open_file, cmd) write_log_output(open_file, omiRunStatus) omiInstallOut, omiInstallStatus = check_pkg_status('omi') omsagentInstallOut, omsagentInstallStatus = check_pkg_status('omsagent') omsconfigInstallOut, omsconfigInstallStatus = check_pkg_status('omsconfig') scxInstallOut, scxInstallStatus = check_pkg_status('scx') # OMS agent process check cmd = 'ps -ef | egrep "omsagent|omi"' psefomsagent = exec_command(cmd) write_log_command(open_file, cmd) write_log_output(open_file, psefomsagent) time.sleep(10) # OMS agent restart cmd = 
'/opt/microsoft/omsagent/bin/service_control restart' omsagentRestart = exec_command(cmd) write_log_command(open_file, cmd) write_log_output(open_file, omsagentRestart) # OMI agent restart cmd = '/opt/omi/bin/service_control restart' omiRestart = exec_command(cmd) write_log_command(open_file, cmd) write_log_output(open_file, omiRestart) def service_control_commands(): """Determine and store results of various service commands.""" global serviceStop, serviceDisable, serviceEnable, serviceStart # OMS stop (shutdown the agent) cmd = '/opt/microsoft/omsagent/bin/service_control stop' serviceStop = exec_command(cmd) write_log_command(open_file, cmd) write_log_output(open_file, serviceStop) # OMS disable (disable agent from starting upon system start) cmd = '/opt/microsoft/omsagent/bin/service_control disable' serviceDisable = exec_command(cmd) write_log_command(open_file, cmd) write_log_output(open_file, serviceDisable) # OMS enable (enable agent to start upon system start) cmd = '/opt/microsoft/omsagent/bin/service_control enable' serviceEnable = exec_command(cmd) write_log_command(open_file, cmd) write_log_output(open_file, serviceEnable) # OMS start (start the agent) cmd = '/opt/microsoft/omsagent/bin/service_control start' serviceStart = exec_command(cmd) write_log_command(open_file, cmd) write_log_output(open_file, serviceStart) def write_html(): """Use stored command results to create an HTML report of the test results.""" os.system('rm /home/scratch/omsresults.html') html_file = '/home/scratch/omsresults.html' f = open(html_file, 'w+') message = """ <div class="text" style="white-space: pre-wrap" > <table> <caption><h4>OMS Install Results</h4><caption> <tr> <th>Package</th> <th>Status</th> <th>Output</th> </tr> <tr> <td>OMI</td> <td>{0}</td> <td>{1}</td> </tr> <tr> <td>OMSAgent</td> <td>{2}</td> <td>{3}</td> </tr> <tr> <td>OMSConfig</td> <td>{4}</td> <td>{5}</td> </tr> <tr> <td>SCX</td> <td>{6}</td> <td>{7}</td> </tr> </table> <table> <caption><h4>OMS Command 
Outputs</h4><caption> <tr> <th>Command</th> <th>Output</th> </tr> <tr> <td>waagent --version</td> <td>{8}</td> </tr> <tr> <td>/opt/microsoft/omsagent/bin/omsadmin.sh -l</td> <td>{9}</td> </tr> <tr> <td>scxadmin -status</td> <td>{10}</td> </tr> <tr> <td>ps -ef | egrep "omsagent|omi"</td> <td>{11}</td> </tr> <tr> <td>/opt/microsoft/omsagent/bin/service_control restart</td> <td>{12}</td> <tr> <tr> <td>/opt/omi/bin/service_control restart</td> <td>{13}</td> <tr> <tr> <td>/opt/microsoft/omsagent/bin/service_control stop</td> <td>{14}</td> <tr> <tr> <td>/opt/microsoft/omsagent/bin/service_control disable</td> <td>{15}</td> <tr> <tr> <td>/opt/microsoft/omsagent/bin/service_control enable</td> <td>{16}</td> <tr> <tr> <td>/opt/microsoft/omsagent/bin/service_control stop</td> <td>{17}</td> <tr> </table> </div> """.format(omiInstallStatus, omiInstallOut, omsagentInstallStatus, omsagentInstallOut, omsconfigInstallStatus, omsconfigInstallOut, scxInstallStatus, scxInstallOut, waagentOut, onboardStatus, omiRunStatus, psefomsagent, omsagentRestart, omiRestart, serviceStop, serviceDisable, serviceEnable, serviceStart) f.write(message) f.close() def dist_status(): f = open('/home/scratch/omsresults.status', 'w+') if os.system('/opt/microsoft/omsagent/bin/omsadmin.sh -l') == 0: detect_workspace_id() x_out = subprocess.check_output('/opt/microsoft/omsagent/bin/omsadmin.sh -l', shell=True) if x_out.rstrip() == "No Workspace": status_message = "Onboarding Failed" elif re.search('[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}', x_out).group(0) == workspace_id: status_message = "Agent Found" else: status_message = "Agent Not Found" f.write(status_message) f.close() def sorted_dir(folder): def getmtime(name): path = os.path.join(folder, name) return os.path.getmtime(path) return sorted(os.listdir(folder), key=getmtime, reverse=True) def copy_oms_logs(): omslogfile = "" split_name = vm_dist.split(' ') split_ver = vm_ver.split('.') if vm_dist.startswith('Red 
Hat'): omslogfile = '/home/scratch/{0}-omsagent.log'.format((split_name[0]+split_name[1]).lower() + split_ver[0]) else: omslogfile = '/home/scratch/{0}-omsagent.log'.format(split_name[0].lower() + split_ver[0]) omslogfileOpen = open(omslogfile, 'a+') omsagent_file = '/var/opt/microsoft/omsagent/{0}/log/omsagent.log'.format(workspace_id) write_log_command(omslogfileOpen, '\n OmsAgent Logs:\n') append_file(omsagent_file, omslogfileOpen) def copy_extension_logs(): extlogfile = "" split_name = vm_dist.split(' ') split_ver = vm_ver.split('.') if vm_dist.startswith('Red Hat'): extlogfile = '/home/scratch/{0}-extnwatcher.log'.format((split_name[0]+split_name[1]).lower() + split_ver[0]) else: extlogfile = '/home/scratch/{0}-extnwatcher.log'.format(split_name[0].lower() + split_ver[0]) extlogfileOpen = open(extlogfile, 'a+') oms_azure_ext_dir = '/var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/' ext_contents = sorted_dir(oms_azure_ext_dir) if ext_contents[0].startswith('extension') or ext_contents[0].startswith('watcher'): write_log_command(extlogfileOpen, '\n Extension Logs:\n') append_file(oms_azure_ext_dir + '/extension.log', extlogfileOpen) write_log_command(extlogfileOpen, '\n Watcher Logs:\n') append_file(oms_azure_ext_dir + '/watcher.log', extlogfileOpen) else: write_log_command(extlogfileOpen, '\n Extension Logs:\n') append_file(oms_azure_ext_dir + ext_contents[0] + '/extension.log', extlogfileOpen) write_log_command(extlogfileOpen, '\n Watcher Logs:\n') append_file(oms_azure_ext_dir + ext_contents[0] + '/watcher.log', extlogfileOpen) extlogfileOpen.close() if __name__ == '__main__': main() ================================================ FILE: OmsAgent/extension-test/omsfiles/perf.conf ================================================ <source> type oms_omi object_name "Logical Disk" instance_regex ".*" counter_name_regex "(% Used Inodes|Free Megabytes|% Used Space|Disk Transfers/sec|Disk Reads/sec|Disk Writes/sec)" interval 10s omi_mapping_path 
/etc/opt/microsoft/omsagent/<workspace-id>/conf/omsagent.d/omi_mapping.json </source> <source> type oms_omi object_name "Processor" instance_regex ".*" counter_name_regex "(% Processor Time|% Privileged Time)" interval 10s omi_mapping_path /etc/opt/microsoft/omsagent/<workspace-id>/conf/omsagent.d/omi_mapping.json </source> <source> type oms_omi object_name "Network" instance_regex ".*" counter_name_regex "(Total Bytes Transmitted|Total Bytes Received)" interval 10s omi_mapping_path /etc/opt/microsoft/omsagent/<workspace-id>/conf/omsagent.d/omi_mapping.json </source> <source> type oms_omi object_name "Memory" instance_regex ".*" counter_name_regex "(Available MBytes Memory|% Used Memory|% Used Swap Space)" interval 10s omi_mapping_path /etc/opt/microsoft/omsagent/<workspace-id>/conf/omsagent.d/omi_mapping.json </source> ================================================ FILE: OmsAgent/extension-test/omsfiles/rsyslog-oms.conf ================================================ # OMS Syslog collection for workspace a0d166ba-98b9-402e-b805-172ed62150a4 daemon.=alert;daemon.=crit;daemon.=debug;daemon.=emerg;daemon.=err;daemon.=info;daemon.=notice;daemon.=warning @127.0.0.1:25224 kern.=alert;kern.=crit;kern.=debug;kern.=emerg;kern.=err;kern.=info;kern.=notice;kern.=warning @127.0.0.1:25224 syslog.=alert;syslog.=crit;syslog.=debug;syslog.=emerg;syslog.=err;syslog.=info;syslog.=notice;syslog.=warning @127.0.0.1:25224 cron.=alert;cron.=crit;cron.=debug;cron.=emerg;cron.=err;cron.=info;cron.=notice;cron.=warning @127.0.0.1:25224 ================================================ FILE: OmsAgent/extension-test/parameters.json ================================================ { "resource": "https://management.azure.com", "authority host url": "https://login.microsoftonline.com", "resource group": "<resource-group-name>", "location": "<location>", "username": "<username>", "ssh private": "<ssh-private-keyfile-path>", "nsg resource group": "<nsg-resource-group>", "nsg": "<nsg>", "size": 
"<size>", "workspace": "<workspace-name>", "key vault": "<key-vault-name>", "old version": "<old-extesion-version>" } ================================================ FILE: OmsAgent/extension-test/verify_e2e.py ================================================ '''Verify end-to-end data transmission.''' import json import os import re import sys import subprocess import adal import requests ENDPOINT = ('https://management.azure.com/subscriptions/{}/resourcegroups/' '{}/providers/Microsoft.OperationalInsights/workspaces/{}/api/' 'query?api-version=2017-01-01-preview') def check_e2e(hostname, timespan = 'PT30M'): ''' Verify data from computer with provided hostname is present in the Log Analytics workspace specified in parameters.json, append results to e2eresults.json ''' global success_count global success_sources global failed_sources success_count = 0 failed_sources = [] success_sources = [] with open('{0}/parameters.json'.format(os.getcwd()), 'r') as f: parameters = f.read() if re.search(r'"<.*>"', parameters): print('Please replace placeholders in parameters.json') exit() parameters = json.loads(parameters) key_vault = parameters['key vault'] tenant_id = str(json.loads(subprocess.check_output('az keyvault secret show --name tenant-id --vault-name {0}'.format(key_vault), shell=True))["value"]) app_id = str(json.loads(subprocess.check_output('az keyvault secret show --name app-id --vault-name {0}'.format(key_vault), shell=True))["value"]) app_secret = str(json.loads(subprocess.check_output('az keyvault secret show --name app-secret --vault-name {0}'.format(key_vault), shell=True))["value"]) authority_url = parameters['authority host url'] + '/' + tenant_id context = adal.AuthenticationContext(authority_url) token = context.acquire_token_with_client_credentials( parameters['resource'], app_id, app_secret) head = {'Authorization': 'Bearer ' + token['accessToken']} subscription = str(json.loads(subprocess.check_output('az keyvault secret show --name subscription-id 
--vault-name {0}'.format(key_vault), shell=True))["value"]) resource_group = parameters['resource group'] workspace = parameters['workspace'] url = ENDPOINT.format(subscription, resource_group, workspace) sources = ['Heartbeat', 'Syslog', 'Perf', 'ApacheAccess_CL', 'MySQL_CL', 'Custom_Log_CL'] distro = hostname.split('-')[0] results = {} results[distro] = {} print('Verifying data from computer {}'.format(hostname)) for s in sources: query = '%s | where Computer == \'%s\' | take 1' % (s, hostname) r = requests.post(url, headers=head, json={'query':query, 'timespan':timespan}) if r.status_code == requests.codes.ok: r = (json.loads(r.text)['Tables'])[0] if len(r['Rows']) < 1: results[distro][s] = 'Failure: no logs' failed_sources.append(s) else: results[distro][s] = 'Success' success_count += 1 success_sources.append(s) else: results[distro][s] = 'Failure: {} {}'.format(r.status_code, r.text) results[distro] = [results[distro]] print(results) return results def main(): '''Check for data with given hostname.''' if len(sys.argv) == 2: check_e2e(sys.argv[1]) else: print('Hostname not provided') exit() if __name__ == '__main__': main() ================================================ FILE: OmsAgent/keys/dscgpgkey.asc ================================================ -----BEGIN PGP PUBLIC KEY BLOCK----- Version: GnuPG v1.4.7 (GNU/Linux) mQENBFcDALYBCADAKoZhZlJxGNGWzqV+1OG1xiQeoowKhssGAKvd+buXCGISZJwT LXZqIcIiLP7pqdcZWtE9bSc7yBY2MalDp9Liu0KekywQ6VVX1T72NPf5Ev6x6DLV 7aVWsCzUAF+eb7DC9fPuFLEdxmOEYoPjzrQ7cCnSV4JQxAqhU4T6OjbvRazGl3ag OeizPXmRljMtUUttHQZnRhtlzkmwIrUivbfFPD+fEoHJ1+uIdfOzZX8/oKHKLe2j H632kvsNzJFlROVvGLYAk2WRcLu+RjjggixhwiB+Mu/A8Tf4V6b+YppS44q8EvVr M+QvY7LNSOffSO6Slsy9oisGTdfE39nC7pVRABEBAAG0NU1pY3Jvc29mdCAoUmVs ZWFzZSBTaWduaW5nKSA8ZHNjZ3Bna2V5QG1pY3Jvc29mdC5jb20+iQE1BBMBAgAf BQJXAwC2AhsDBgsJCAcDAgQVAggDAxYCAQIeAQIXgAAKCRAgVBo93jISlLZYB/44 DIa5AX9csM1N0+kddBHb23NSRkEFMlD+rTjiTk/Nsrh8RghPlHlXEd/Rpxf2c+xJ TjPrpdL0dHzou5ZEdTVtCeVCV0YA2cZk+RfhthHnX5M1m0suu5HgSEHfKyqlfJwZ 
uYapagLoE4jXbQnw9UJgdSpa8OFjOcyZ9oNCn9IHG3W7JAV1+upUBKM/iwHTuVrQ yrbYBlqVRWi4s3nDpqEZMBSq1KJucHIt2uOqAlz9hRUXjWNsD+Ff+Nn1EvkDdzn5 KrRUgA9bSp6FPBEluIO/QFA6aTW4MrujCHCrpiDPxFGg7WTOXS8tg5AJ/d/l/pOp 5/E3CO1YTCgEMl34eOdU =JQx7 -----END PGP PUBLIC KEY BLOCK----- ================================================ FILE: OmsAgent/keys/msgpgkey.asc ================================================ -----BEGIN PGP PUBLIC KEY BLOCK----- Version: GnuPG v1.4.7 (GNU/Linux) mQENBFcDBSwBCADAKoZhZlJxGNGWzqV+1OG1xiQeoowKhssGAKvd+buXCGISZJwT LXZqIcIiLP7pqdcZWtE9bSc7yBY2MalDp9Liu0KekywQ6VVX1T72NPf5Ev6x6DLV 7aVWsCzUAF+eb7DC9fPuFLEdxmOEYoPjzrQ7cCnSV4JQxAqhU4T6OjbvRazGl3ag OeizPXmRljMtUUttHQZnRhtlzkmwIrUivbfFPD+fEoHJ1+uIdfOzZX8/oKHKLe2j H632kvsNzJFlROVvGLYAk2WRcLu+RjjggixhwiB+Mu/A8Tf4V6b+YppS44q8EvVr M+QvY7LNSOffSO6Slsy9oisGTdfE39nC7pVRABEBAAG0NE1pY3Jvc29mdCAoUmVs ZWFzZSBTaWduaW5nKSA8bXNncGdrZXlAbWljcm9zb2Z0LmNvbT6JATUEEwECAB8F AlcDBSwCGwMGCwkIBwMCBBUCCAMDFgIBAh4BAheAAAoJEMTsSeVEvEF40uoIAJdJ yxhQLo/VntUHUrTita63CbUCDw1AAb3ltgXPIfSSnhotEb8KQrJjghu8XO3/Swre geB6DuYm77tUIHOoA3SiOXi67EfhwM1iaRDzorf+U/59R0evQ57IWrA/g4Ceh0CJ picFwLUe0BKKVgxtTvOxPa08P1znA5IVWR6fruqHyy9TbYYSYYV7B+Cw3KS+JCzw fV/nH0F9slgxgcwhzezk1b0glGfCuiswnK7nHxHYW7B+vjfRd+Seq8lM1CYozbe5 6TPbfgyisiEsZDulEU0jpGa2q1UwnKaP1A7mgTxRgLmmg/EzC3MTzvSqvQI6Xvme nHX/CNyXbumiyqsH3Tw= =yTH1 -----END PGP PUBLIC KEY BLOCK----- ================================================ FILE: OmsAgent/manifest.xml ================================================ <?xml version='1.0' encoding='utf-8' ?> <ExtensionImage xmlns="http://schemas.microsoft.com/windowsazure"> <ProviderNameSpace>Microsoft.EnterpriseCloud.Monitoring</ProviderNameSpace> <Type>OmsAgentForLinux</Type> <Version>1.13.19</Version> <Label>Microsoft Operations Management Suite Agent for Linux</Label> <HostingResources>VmRole</HostingResources> <MediaLink></MediaLink> <Description>Microsoft Operations Management Suite Agent for Linux</Description> <IsInternalExtension>true</IsInternalExtension> 
<Eula>https://github.com/Microsoft/OMS-Agent-for-Linux/blob/master/LICENSE</Eula> <PrivacyUri>http://www.microsoft.com/privacystatement/en-us/OnlineServices/Default.aspx</PrivacyUri> <HomepageUri>https://github.com/Microsoft/OMS-Agent-for-Linux</HomepageUri> <IsJsonExtension>true</IsJsonExtension> <SupportedOS>Linux</SupportedOS> <CompanyName>Microsoft</CompanyName> <!--%REGIONS%--> </ExtensionImage> ================================================ FILE: OmsAgent/omsagent.py ================================================ #!/usr/bin/env python # # OmsAgentForLinux Extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function import sys import os import os.path import signal import pwd import grp import re import traceback import time import platform import subprocess import json import base64 import inspect import watcherutil import shutil from threading import Thread try: from Utils.WAAgentUtil import waagent import Utils.HandlerUtil as HUtil except Exception as e: # These utils have checks around the use of them; this is not an exit case print('Importing utils failed with error: {0}'.format(e)) if sys.version_info[0] == 3: import urllib.request as urllib from urllib.parse import urlparse import urllib.error as urlerror elif sys.version_info[0] == 2: import urllib2 as urllib from urlparse import urlparse import urllib2 as urlerror # This monkey patch duplicates the one made in the waagent import above. 
# It is necessary because on 2.6, the waagent monkey patch appears to be overridden # by the python-future subprocess.check_output backport. if sys.version_info < (2,7): def check_output(*popenargs, **kwargs): r"""Backport from subprocess module from python 2.7""" if 'stdout' in kwargs: raise ValueError('stdout argument not allowed, it will be overridden.') process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) output, unused_err = process.communicate() retcode = process.poll() if retcode: cmd = kwargs.get("args") if cmd is None: cmd = popenargs[0] raise subprocess.CalledProcessError(retcode, cmd, output=output) return output # Exception classes used by this module. class CalledProcessError(Exception): def __init__(self, returncode, cmd, output=None): self.returncode = returncode self.cmd = cmd self.output = output def __str__(self): return "Command '%s' returned non-zero exit status %d" % (self.cmd, self.returncode) subprocess.check_output = check_output subprocess.CalledProcessError = CalledProcessError # Global Variables ProceedOnSigningVerificationFailure = True PackagesDirectory = 'packages' keysDirectory = 'keys' # Below file version will be replaced during OMS-Build time. 
BundleFileName = 'omsagent-0.0.0-0.universal.x64.sh'
GUIDRegex = r'[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}'
GUIDOnlyRegex = r'^' + GUIDRegex + '$'
SCOMCertIssuerRegex = r'^[\s]*Issuer:[\s]*CN=SCX-Certificate/title=SCX' + GUIDRegex + ', DC=.*$'
SCOMPort = 1270
PostOnboardingSleepSeconds = 5
InitialRetrySleepSeconds = 30
IsUpgrade = False

# Paths
OMSAdminPath = '/opt/microsoft/omsagent/bin/omsadmin.sh'
OMSAgentServiceScript = '/opt/microsoft/omsagent/bin/service_control'
OMIConfigEditorPath = '/opt/omi/bin/omiconfigeditor'
OMIServerConfPath = '/etc/opt/omi/conf/omiserver.conf'
EtcOMSAgentPath = '/etc/opt/microsoft/omsagent/'
VarOMSAgentPath = '/var/opt/microsoft/omsagent/'
SCOMCertPath = '/etc/opt/microsoft/scx/ssl/scx.pem'
ExtensionStateSubdirectory = 'state'

# Commands
# Always use upgrade - will handle install if scx, omi are not installed or upgrade if they are.
InstallCommandTemplate = '{0} --upgrade {1}'
UninstallCommandTemplate = '{0} --remove'
WorkspaceCheckCommand = '{0} -l'.format(OMSAdminPath)
OnboardCommandWithOptionalParams = '{0} -w {1} -s {2} {3}'
RestartOMSAgentServiceCommand = '{0} restart'.format(OMSAgentServiceScript)
DisableOMSAgentServiceCommand = '{0} disable'.format(OMSAgentServiceScript)
InstallExtraPackageCommandApt = 'apt-get -y update && apt-get -y install {0}'
SkipDigestCmdTemplate = '{0} --noDigest'

# Cloud Environments; each maps to the Log Analytics ingestion domain below.
PublicCloudName = "AzurePublicCloud"
FairfaxCloudName = "AzureUSGovernmentCloud"
MooncakeCloudName = "AzureChinaCloud"
USNatCloudName = "USNat" # EX
USSecCloudName = "USSec" # RX
DefaultCloudName = PublicCloudName # Fallback
CloudDomainMap = {
    PublicCloudName: "opinsights.azure.com",
    FairfaxCloudName: "opinsights.azure.us",
    MooncakeCloudName: "opinsights.azure.cn",
    USNatCloudName: "opinsights.azure.eaglex.ic.gov",
    USSecCloudName: "opinsights.azure.microsoft.scloud"
}

# Error codes
DPKGLockedErrorCode = 55 #56, temporary as it excludes from SLA
InstallErrorCurlNotInstalled = 55 #64,
temporary as it excludes from SLA EnableErrorOMSReturned403 = 5 EnableErrorOMSReturnedNon200 = 6 EnableErrorResolvingHost = 7 EnableErrorOnboarding = 8 EnableCalledBeforeSuccessfulInstall = 52 # since install is a missing dependency UnsupportedOpenSSL = 55 #60, temporary as it excludes from SLA UnsupportedGpg = 55 # OneClick error codes OneClickErrorCode = 40 ManagedIdentityExtMissingErrorCode = 41 ManagedIdentityExtErrorCode = 42 MetadataAPIErrorCode = 43 OMSServiceOneClickErrorCode = 44 MissingorInvalidParameterErrorCode = 11 UnwantedMultipleConnectionsErrorCode = 10 CannotConnectToOMSErrorCode = 55 UnsupportedOperatingSystem = 51 # Configuration HUtilObject = None SettingsSequenceNumber = None HandlerEnvironment = None SettingsDict = None # OneClick Constants ManagedIdentityExtListeningURLPath = '/var/lib/waagent/ManagedIdentity-Settings' GUIDRegex = '[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}' OAuthTokenResource = 'https://management.core.windows.net/' OMSServiceValidationEndpoint = 'https://global.oms.opinsights.azure.com/ManagedIdentityService.svc/Validate' AutoManagedWorkspaceCreationSleepSeconds = 20 # agent permissions AgentUser='omsagent' AgentGroup='omiusers' """ What need to be packaged to make the signing work: keys dscgpgkey.asc msgpgkey.asc packages omsagent-*.universal.x64.asc omsagent-*.universal.x64.sha256sums """ def verifyShellBundleSigningAndChecksum(): cert_directory = os.path.join(os.getcwd(), PackagesDirectory) keys_directory = os.path.join(os.getcwd(), keysDirectory) # import GPG key dscGPGKeyFilePath = os.path.join(keys_directory, 'dscgpgkey.asc') if not os.path.isfile(dscGPGKeyFilePath): raise Exception("Unable to find the dscgpgkey.asc file at " + dscGPGKeyFilePath) importGPGKeyCommand = "sh ImportGPGkey.sh " + dscGPGKeyFilePath exit_code, output = run_command_with_retries_output(importGPGKeyCommand, retries = 0, retry_check = retry_skip, check_error = False) # Check that we can find the keyring file 
    keyringFilePath = os.path.join(keys_directory, 'keyring.gpg')
    if not os.path.isfile(keyringFilePath):
        raise Exception("Unable to find the Extension keyring file at " + keyringFilePath)

    # Check that we can find the asc file
    bundleFileName, file_ext = os.path.splitext(BundleFileName)
    ascFilePath = os.path.join(cert_directory, bundleFileName + ".asc")
    if not os.path.isfile(ascFilePath):
        raise Exception("Unable to find the OMS shell bundle asc file at " + ascFilePath)

    # check that we can find the SHA256 sums file
    sha256SumsFilePath = os.path.join(cert_directory, bundleFileName + ".sha256sums")
    if not os.path.isfile(sha256SumsFilePath):
        raise Exception("Unable to find the OMS shell bundle SHA256 sums file at " + sha256SumsFilePath)

    # Verify the SHA256 sums file with the keyring and asc files
    # HOME is overridden so gpg uses the extension's key directory rather
    # than the invoking user's.
    verifySha256SumsCommand = "HOME=" + keysDirectory + " gpg --no-default-keyring --keyring " + keyringFilePath + " --verify " + ascFilePath + " " + sha256SumsFilePath
    exit_code, output = run_command_with_retries_output(verifySha256SumsCommand, retries = 0,
                                                        retry_check = retry_skip,
                                                        check_error = False)
    if exit_code != 0:
        raise Exception("Failed to verify SHA256 sums file at " + sha256SumsFilePath)

    # Perform SHA256 sums to verify shell bundle
    hutil_log_info("Perform SHA256 sums to verify shell bundle")
    performSha256SumsCommand = "cd %s; sha256sum -c %s" % (cert_directory, sha256SumsFilePath)
    exit_code, output = run_command_with_retries_output(performSha256SumsCommand, retries = 0,
                                                        retry_check = retry_skip,
                                                        check_error = False)
    if exit_code != 0:
        raise Exception("Failed to verify shell bundle with the SHA256 sums file at " + sha256SumsFilePath)


def main():
    """
    Main method
    Parse out operation from argument, invoke the operation, and finish.
""" init_waagent_logger() waagent_log_info('OmsAgentForLinux started to handle.') global IsUpgrade # Determine the operation being executed operation = None try: option = sys.argv[1] if re.match('^([-/]*)(disable)', option): operation = 'Disable' elif re.match('^([-/]*)(uninstall)', option): operation = 'Uninstall' elif re.match('^([-/]*)(install)', option): operation = 'Install' elif re.match('^([-/]*)(enable)', option): operation = 'Enable' elif re.match('^([-/]*)(update)', option): operation = 'Update' IsUpgrade = True elif re.match('^([-/]*)(telemetry)', option): operation = 'Telemetry' except Exception as e: waagent_log_error(str(e)) if operation is None: log_and_exit('Unknown', 1, 'No valid operation provided') # Set up for exit code and any error messages exit_code = 0 message = '{0} succeeded'.format(operation) # Clean status file to mitigate diskspace issues on small VMs status_files = [ "/var/opt/microsoft/omsconfig/status/dscperformconsistency", "/var/opt/microsoft/omsconfig/status/dscperforminventory", "/var/opt/microsoft/omsconfig/status/dscsetlcm", "/var/opt/microsoft/omsconfig/status/omsconfighost" ] for sf in status_files: if os.path.isfile(sf): if sf.startswith("/var/opt/microsoft/omsconfig/status"): try: os.remove(sf) except Exception as e: hutil_log_info('Error removing telemetry status file before installation: {0}'.format(sf)) hutil_log_info('Exception info: {0}'.format(traceback.format_exc())) exit_code = check_disk_space_availability() if exit_code != 0: message = '{0} failed due to low disk space'.format(operation) log_and_exit(operation, exit_code, message) exit_if_gpg_unavailable(operation) # Invoke operation try: global HUtilObject HUtilObject = parse_context(operation) # Verify shell bundle signing try: hutil_log_info("Start signing verification") verifyShellBundleSigningAndChecksum() hutil_log_info("ShellBundle signing verification succeeded") except Exception as ex: errmsg = "ShellBundle signing verification failed with '%s'" % 
ex.message if ProceedOnSigningVerificationFailure: hutil_log_error(errmsg) else: log_and_exit(operation, errmsg) # invoke operation exit_code, output = operations[operation]() # Exit code 1 indicates a general problem that doesn't have a more # specific error code; it often indicates a missing dependency if exit_code == 1 and operation == 'Install': message = 'Install failed with exit code 1. Please check that ' \ 'dependencies are installed. For details, check logs ' \ 'in /var/log/azure/Microsoft.EnterpriseCloud.' \ 'Monitoring.OmsAgentForLinux' elif exit_code == 127 and operation == 'Install': # happens if shell bundle couldn't be extracted due to low space or missing dependency exit_code = 52 # since it is a missing dependency message = 'Install failed with exit code 127. Please check that ' \ 'dependencies are installed. For details, check logs ' \ 'in /var/log/azure/Microsoft.EnterpriseCloud.' \ 'Monitoring.OmsAgentForLinux' elif exit_code is DPKGLockedErrorCode and operation == 'Install': message = 'Install failed with exit code {0} because the ' \ 'package manager on the VM is currently locked: ' \ 'please wait and try again'.format(DPKGLockedErrorCode) elif exit_code != 0: message = '{0} failed with exit code {1} {2}'.format(operation, exit_code, output) except OmsAgentForLinuxException as e: exit_code = e.error_code message = e.get_error_message(operation) except Exception as e: exit_code = 1 message = '{0} failed with error: {1}\n' \ 'Stacktrace: {2}'.format(operation, e, traceback.format_exc()) # Finish up and log messages log_and_exit(operation, exit_code, message) def check_disk_space_availability(): """ Check if there is the required space on the machine. """ try: if get_free_space_mb("/var") < 500 or get_free_space_mb("/etc") < 500 or get_free_space_mb("/opt") < 500: # 52 is the exit code for missing dependency i.e. 
            # disk space
            # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr
            return 52
        else:
            return 0
    except:
        # Best-effort check: treat an inability to measure disk space as OK
        # rather than blocking the operation.
        print('Failed to check disk usage.')
        return 0


def get_free_space_mb(dirname):
    """
    Get the free space in MB in the directory path.
    """
    st = os.statvfs(dirname)
    return (st.f_bavail * st.f_frsize) // (1024 * 1024)


def stop_telemetry_process():
    """
    Kill any running telemetry watcher whose pid is recorded in
    omstelemetry.pid, then remove the pid file.
    """
    pids_filepath = os.path.join(os.getcwd(),'omstelemetry.pid')

    # kill existing telemetry watcher
    if os.path.exists(pids_filepath):
        with open(pids_filepath, "r") as f:
            for pid in f.readlines():
                # Verify the pid actually belongs to omsagent.
                cmd_file = os.path.join("/proc", str(pid.strip("\n")), "cmdline")
                if os.path.exists(cmd_file):
                    with open(cmd_file, "r") as pidf:
                        cmdline = pidf.readlines()
                        if cmdline[0].find("omsagent.py") >= 0 and cmdline[0].find("-telemetry") >= 0:
                            kill_cmd = "kill " + pid
                            run_command_and_log(kill_cmd)

        run_command_and_log("rm "+pids_filepath)

def start_telemetry_process():
    """
    Start telemetry process that performs periodic monitoring activities
    :return: None
    """
    stop_telemetry_process()

    #start telemetry watcher
    omsagent_filepath = os.path.join(os.getcwd(),'omsagent.py')
    # Re-invoke this script under the same major Python version with the
    # -telemetry flag; output goes to daemon.log.
    args = ['python{0}'.format(sys.version_info[0]), omsagent_filepath, '-telemetry']
    log = open(os.path.join(os.getcwd(), 'daemon.log'), 'w')
    hutil_log_info('start watcher process '+str(args))
    subprocess.Popen(args, stdout=log, stderr=log)

def telemetry():
    """
    Telemetry entry point: record our pid, then run the watcher and
    self-monitoring threads until they finish.
    """
    pids_filepath = os.path.join(os.getcwd(), 'omstelemetry.pid')
    py_pid = os.getpid()
    with open(pids_filepath, 'w') as f:
        f.write(str(py_pid) + '\n')

    # HUtilObject may be None if handler modules failed to import
    if HUtilObject is not None:
        watcher = watcherutil.Watcher(HUtilObject.error, HUtilObject.log)
        watcher_thread = Thread(target = watcher.watch)
        self_mon_thread = Thread(target = watcher.monitor_health)

        watcher_thread.start()
        self_mon_thread.start()

        watcher_thread.join()
        self_mon_thread.join()

    return 0, ""

def prepare_update():
    """
    Copy / move configuration directory to the backup
""" # First check if backup directory was previously created for given workspace. # If it is created with all the files , we need not move the files again. public_settings, _ = get_settings() workspaceId = public_settings.get('workspaceId') etc_remove_path = os.path.join(EtcOMSAgentPath, workspaceId) etc_move_path = os.path.join(EtcOMSAgentPath, ExtensionStateSubdirectory, workspaceId) if (not os.path.isdir(etc_move_path)): shutil.move(etc_remove_path, etc_move_path) return 0, "" def restore_state(workspaceId): """ Copy / move state from backup to the expected location. """ try: etc_backup_path = os.path.join(EtcOMSAgentPath, ExtensionStateSubdirectory, workspaceId) etc_final_path = os.path.join(EtcOMSAgentPath, workspaceId) if (os.path.isdir(etc_backup_path) and not os.path.isdir(etc_final_path)): shutil.move(etc_backup_path, etc_final_path) except Exception as e: hutil_log_error("Error while restoring the state. Exception : "+traceback.format_exc()) def install(): """ Ensure that this VM distro and version are supported. Install the OMSAgent shell bundle, using retries. Note: install operation times out from WAAgent at 15 minutes, so do not wait longer. """ exit_if_vm_not_supported('Install') public_settings, protected_settings = get_settings() if public_settings is None: raise ParameterMissingException('Public configuration must be ' \ 'provided') workspaceId = public_settings.get('workspaceId') check_workspace_id(workspaceId) # Take the backup of the state for given workspace. 
restore_state(workspaceId) # In the case where a SCOM connection is already present, we should not # create conflicts by installing the OMSAgent packages stopOnMultipleConnections = public_settings.get('stopOnMultipleConnections') if (stopOnMultipleConnections is not None and stopOnMultipleConnections is True): detect_multiple_connections(workspaceId) package_directory = os.path.join(os.getcwd(), PackagesDirectory) bundle_path = os.path.join(package_directory, BundleFileName) os.chmod(bundle_path, 100) skipDockerProviderInstall = public_settings.get( 'skipDockerProviderInstall') if (skipDockerProviderInstall is not None and skipDockerProviderInstall is True): cmd = InstallCommandTemplate.format( bundle_path, '--skip-docker-provider-install') else: cmd = InstallCommandTemplate.format(bundle_path, '') noDigest = public_settings.get( 'noDigest') if (noDigest is not None and noDigest is True): cmd = SkipDigestCmdTemplate.format(cmd) hutil_log_info('Running command "{0}"'.format(cmd)) # Retry, since install can fail due to concurrent package operations exit_code, output = run_command_with_retries_output(cmd, retries = 10, retry_check = retry_if_dpkg_locked_or_curl_is_not_found, final_check = final_check_if_dpkg_locked) return exit_code, output def check_kill_process(pstring): for line in os.popen("ps ax | grep " + pstring + " | grep -v grep"): fields = line.split() pid = fields[0] os.kill(int(pid), signal.SIGKILL) def uninstall(): """ Uninstall the OMSAgent shell bundle. This is a somewhat soft uninstall. It is not a purge. 
Note: uninstall operation times out from WAAgent at 5 minutes """ package_directory = os.path.join(os.getcwd(), PackagesDirectory) bundle_path = os.path.join(package_directory, BundleFileName) global IsUpgrade os.chmod(bundle_path, 100) cmd = UninstallCommandTemplate.format(bundle_path) hutil_log_info('Running command "{0}"'.format(cmd)) # Retry, since uninstall can fail due to concurrent package operations try: exit_code, output = run_command_with_retries_output(cmd, retries = 5, retry_check = retry_if_dpkg_locked_or_curl_is_not_found, final_check = final_check_if_dpkg_locked) except Exception as e: # try to force clean the installation try: check_kill_process("omsagent") exit_code = 0 except Exception as ex: exit_code = 1 message = 'Uninstall failed with error: {0}\n' \ 'Stacktrace: {1}'.format(ex, traceback.format_exc()) if IsUpgrade: IsUpgrade = False else: remove_workspace_configuration() return exit_code, output def enable(): """ Onboard the OMSAgent to the specified OMS workspace. This includes enabling the OMS process on the VM. This call will return non-zero or throw an exception if the settings provided are incomplete or incorrect. 
    Note: enable operation times out from WAAgent at 5 minutes
    """
    exit_if_vm_not_supported('Enable')

    public_settings, protected_settings = get_settings()

    if public_settings is None:
        raise ParameterMissingException('Public configuration must be ' \
                                        'provided')
    if protected_settings is None:
        raise ParameterMissingException('Private configuration must be ' \
                                        'provided')

    vmResourceId = protected_settings.get('vmResourceId')

    # If vmResourceId is not provided in private settings, get it from metadata API
    if vmResourceId is None or not vmResourceId:
        vmResourceId = get_vmresourceid_from_metadata()
        hutil_log_info('vmResourceId from Metadata API is {0}'.format(vmResourceId))

    if vmResourceId is None:
        hutil_log_info('This may be a classic VM')

    enableAutomaticManagement = public_settings.get('enableAutomaticManagement')

    if (enableAutomaticManagement is not None
            and enableAutomaticManagement is True):
        hutil_log_info('enableAutomaticManagement is set to true; the ' \
                       'workspace ID and key will be determined by the OMS ' \
                       'service.')

        workspaceInfo = retrieve_managed_workspace(vmResourceId)
        if (workspaceInfo is None or 'WorkspaceId' not in workspaceInfo
                or 'WorkspaceKey' not in workspaceInfo):
            raise OneClickException('Workspace info was not determined')
        else:
            # Note: do NOT log workspace keys!
            hutil_log_info('Managed workspaceInfo has been retrieved')
            workspaceId = workspaceInfo['WorkspaceId']
            workspaceKey = workspaceInfo['WorkspaceKey']
            try:
                check_workspace_id_and_key(workspaceId, workspaceKey)
            except InvalidParameterError as e:
                raise OMSServiceOneClickException('Received invalid ' \
                                                  'workspace info: ' \
                                                  '{0}'.format(e))
    else:
        workspaceId = public_settings.get('workspaceId')
        workspaceKey = protected_settings.get('workspaceKey')
        check_workspace_id_and_key(workspaceId, workspaceKey)

    # Check if omsadmin script is available
    if not os.path.exists(OMSAdminPath):
        log_and_exit('Enable', EnableCalledBeforeSuccessfulInstall,
                     'OMSAgent onboarding script {0} does not exist. Enable ' \
                     'cannot be called before install.'.format(OMSAdminPath))

    vmResourceIdParam = '-a {0}'.format(vmResourceId)

    proxy = protected_settings.get('proxy')
    proxyParam = ''
    if proxy is not None:
        proxyParam = '-p {0}'.format(proxy)

    # get domain from protected settings
    domain = protected_settings.get('domain')
    if domain is None:
        # detect opinsights domain using IMDS
        domain = get_azure_cloud_domain()
    else:
        hutil_log_info("Domain retrieved from protected settings '{0}'".format(domain))

    domainParam = ''
    if domain:
        domainParam = '-d {0}'.format(domain)

    optionalParams = '{0} {1} {2}'.format(domainParam, proxyParam, vmResourceIdParam)
    onboard_cmd = OnboardCommandWithOptionalParams.format(OMSAdminPath,
                                                          workspaceId,
                                                          workspaceKey,
                                                          optionalParams)

    hutil_log_info('Handler initiating onboarding.')
    # log_cmd = False because the command line contains the workspace key
    exit_code, output = run_command_with_retries_output(onboard_cmd, retries = 5,
                                         retry_check = retry_onboarding,
                                         final_check = raise_if_no_internet,
                                         check_error = True, log_cmd = False)

    # now ensure the permissions and ownership is set recursively
    try:
        workspaceId = public_settings.get('workspaceId')
        etc_final_path = os.path.join(EtcOMSAgentPath, workspaceId)
        if (os.path.isdir(etc_final_path)):
            uid = pwd.getpwnam(AgentUser).pw_uid
            gid = grp.getgrnam(AgentGroup).gr_gid
            os.chown(etc_final_path, uid, gid)
            os.system('chmod {1} {0}'.format(etc_final_path, 750))
            for root, dirs, files in os.walk(etc_final_path):
                # directories get 750, files get 640
                for d in dirs:
                    os.chown(os.path.join(root, d), uid, gid)
                    os.system('chmod {1} {0}'.format(os.path.join(root, d), 750))
                for f in files:
                    os.chown(os.path.join(root, f), uid, gid)
                    os.system('chmod {1} {0}'.format(os.path.join(root, f), 640))
    except:
        hutil_log_info('Failed to set permissions for OMS directories, could potentially have issues uploading.')

    if exit_code == 0:
        # Create a marker file to denote the workspace that was
        # onboarded using the extension.
This will allow supporting # multi-homing through the extension like Windows does extension_marker_path = os.path.join(EtcOMSAgentPath, workspaceId, 'conf/.azure_extension_marker') if os.path.exists(extension_marker_path): hutil_log_info('Extension marker file {0} already ' \ 'created'.format(extension_marker_path)) else: try: open(extension_marker_path, 'w').close() hutil_log_info('Created extension marker file ' \ '{0}'.format(extension_marker_path)) except IOError as e: try: open(extension_marker_path, 'w+').close() hutil_log_info('Created extension marker file ' \ '{0}'.format(extension_marker_path)) except IOError as ex: hutil_log_error('Error creating {0} with error: ' \ '{1}'.format(extension_marker_path, ex)) # we are having some kind of permissions issue creating the marker file output = "Couldn't create marker file" exit_code = 52 # since it is a missing dependency # Sleep to prevent bombarding the processes, then restart all processes # to resolve any issues with auto-started processes from --upgrade time.sleep(PostOnboardingSleepSeconds) if HUtilObject and HUtilObject.is_seq_smaller(): log_output = "Current sequence number {0} is smaller than or egual to the sequence number of the most recent executed configuration, skipping omsagent process restart.".format(HUtilObject._context._seq_no) hutil_log_info(log_output) else: hutil_log_info('Restart omsagent service via service_control script.') run_command_and_log(RestartOMSAgentServiceCommand) #start telemetry process if enable is successful start_telemetry_process() #save sequence number HUtilObject.save_seq() return exit_code, output def remove_workspace_configuration(): """ This is needed to distinguish between extension removal vs extension upgrade. Its a workaround for waagent upgrade routine calling 'remove' on an old version before calling 'upgrade' on new extension version issue. In upgrade case, we need workspace configuration to persist when in remove case we need all the files be removed. 
This method will remove all the files/folders from the workspace path in Etc and Var. """ public_settings, _ = get_settings() workspaceId = public_settings.get('workspaceId') etc_remove_path = os.path.join(EtcOMSAgentPath, workspaceId) var_remove_path = os.path.join(VarOMSAgentPath, workspaceId) shutil.rmtree(etc_remove_path, True) shutil.rmtree(var_remove_path, True) hutil_log_info('Moved oms etc configuration directory and cleaned up var directory') def is_arc_installed(): """ Check if the system is on an Arc machine """ # Using systemctl to check this since Arc only supports VMs that have systemd check_arc = os.system('systemctl status himdsd 1>/dev/null 2>&1') return check_arc == 0 def get_arc_endpoint(): """ Find the endpoint for Arc Hybrid IMDS """ endpoint_filepath = '/lib/systemd/system.conf.d/azcmagent.conf' endpoint = '' try: with open(endpoint_filepath, 'r') as f: data = f.read() endpoint = data.split("\"IMDS_ENDPOINT=")[1].split("\"\n")[0] except: hutil_log_error('Unable to load Arc IMDS endpoint from {0}'.format(endpoint_filepath)) return endpoint def get_imds_endpoint(): """ Find the endpoint for IMDS, whether Arc or not """ azure_imds_endpoint = 'http://169.254.169.254/metadata/instance?api-version=2018-10-01' if (is_arc_installed()): hutil_log_info('Arc is installed, loading Arc-specific IMDS endpoint') imds_endpoint = get_arc_endpoint() if imds_endpoint: imds_endpoint += '/metadata/instance?api-version=2019-08-15' else: # Fall back to the traditional IMDS endpoint; the cloud domain and VM # resource id detection logic are resilient to failed queries to IMDS imds_endpoint = azure_imds_endpoint hutil_log_info('Falling back to default Azure IMDS endpoint') else: imds_endpoint = azure_imds_endpoint hutil_log_info('Using IMDS endpoint "{0}"'.format(imds_endpoint)) return imds_endpoint def get_vmresourceid_from_metadata(): imds_endpoint = get_imds_endpoint() req = urllib.Request(imds_endpoint) req.add_header('Metadata', 'True') try: response = 
json.loads(urllib.urlopen(req).read()) if ('compute' not in response or response['compute'] is None): return None # classic vm if response['compute']['vmScaleSetName']: return '/subscriptions/{0}/resourceGroups/{1}/providers/Microsoft.Compute/virtualMachineScaleSets/{2}/virtualMachines/{3}'.format(response['compute']['subscriptionId'],response['compute']['resourceGroupName'],response['compute']['vmScaleSetName'],response['compute']['name']) else: return '/subscriptions/{0}/resourceGroups/{1}/providers/Microsoft.Compute/virtualMachines/{2}'.format(response['compute']['subscriptionId'],response['compute']['resourceGroupName'],response['compute']['name']) except urlerror.HTTPError as e: hutil_log_error('Request to Metadata service URL ' \ 'failed with an HTTPError: {0}'.format(e)) hutil_log_info('Response from Metadata service: ' \ '{0}'.format(e.read())) return None except: hutil_log_error('Unexpected error from Metadata service') return None def get_azure_environment_from_imds(): imds_endpoint = get_imds_endpoint() req = urllib.Request(imds_endpoint) req.add_header('Metadata', 'True') try: response = json.loads(urllib.urlopen(req).read()) if ('compute' not in response or response['compute'] is None): return None # classic vm if ('azEnvironment' not in response['compute'] or response['compute']['azEnvironment'] is None): return None # classic vm return response['compute']['azEnvironment'] except urlerror.HTTPError as e: hutil_log_error('Request to Metadata service URL ' \ 'failed with an HTTPError: {0}'.format(e)) hutil_log_info('Response from Metadata service: ' \ '{0}'.format(e.read())) return None except: hutil_log_error('Unexpected error from Metadata service') return None def get_azure_cloud_domain(): try: environment = get_azure_environment_from_imds() if environment: for cloud, domain in CloudDomainMap.items(): if environment.lower() == cloud.lower(): hutil_log_info('Detected cloud environment "{0}" via IMDS. 
The domain "{1}" will be used.'.format(cloud, domain)) return domain hutil_log_info('Unknown cloud environment "{0}"'.format(environment)) except Exception as e: hutil_log_error('Failed to detect cloud environment: {0}'.format(e)) hutil_log_info('Falling back to default domain "{0}"'.format(CloudDomainMap[DefaultCloudName])) return CloudDomainMap[DefaultCloudName] def retrieve_managed_workspace(vm_resource_id): """ EnableAutomaticManagement has been set to true; the ManagedIdentity extension and the VM Resource ID are also required for the OneClick scenario Using these and the Metadata API, we will call the OMS service to determine what workspace ID and key to onboard to """ # Check for OneClick scenario requirements: if not os.path.exists(ManagedIdentityExtListeningURLPath): raise ManagedIdentityExtMissingException # Determine the Tenant ID using the Metadata API tenant_id = get_tenant_id_from_metadata_api(vm_resource_id) # Retrieve an OAuth token using the ManagedIdentity extension if tenant_id is not None: hutil_log_info('Tenant ID from Metadata API is {0}'.format(tenant_id)) access_token = get_access_token(tenant_id, OAuthTokenResource) else: return None # Query OMS service for the workspace info for onboarding if tenant_id is not None and access_token is not None: return get_workspace_info_from_oms(vm_resource_id, tenant_id, access_token) else: return None def disable(): """ Disable all OMS workspace processes on the VM. Note: disable operation times out from WAAgent at 15 minutes """ #stop the telemetry process stop_telemetry_process() # Check if the service control script is available if not os.path.exists(OMSAgentServiceScript): log_and_exit('Disable', 1, 'OMSAgent service control script {0} does' \ 'not exist. 
Disable cannot be called ' \ 'before install.'.format(OMSAgentServiceScript)) return 1 exit_code, output = run_command_and_log(DisableOMSAgentServiceCommand) return exit_code, output # Dictionary of operations strings to methods operations = {'Disable' : disable, 'Uninstall' : uninstall, 'Install' : install, 'Enable' : enable, # For update call we will only prepare the update by taking some backup of the state # since omsagent.py->install() will be called # everytime upgrade is done due to upgradeMode = # "UpgradeWithInstall" set in HandlerManifest 'Update' : prepare_update, 'Telemetry' : telemetry } def parse_context(operation): """ Initialize a HandlerUtil object for this operation. If the required modules have not been imported, this will return None. """ hutil = None if ('Utils.WAAgentUtil' in sys.modules and 'Utils.HandlerUtil' in sys.modules): try: logFileName = 'extension.log' if (operation == 'Telemetry'): logFileName = 'watcher.log' hutil = HUtil.HandlerUtility(waagent.Log, waagent.Error, logFileName=logFileName) hutil.do_parse_context(operation) # parse_context may throw KeyError if necessary JSON key is not # present in settings except KeyError as e: waagent_log_error('Unable to parse context with error: ' \ '{0}'.format(e)) raise ParameterMissingException return hutil def is_vm_supported_for_extension(): """ Checks if the VM this extension is running on is supported by OMSAgent Returns for platform.linux_distribution() vary widely in format, such as '7.3.1611' returned for a VM with CentOS 7, so the first provided digits must match The supported distros of the OMSAgent-for-Linux are allowed to utilize this VM extension. 
All other distros will get error code 51 """ supported_dists = {'redhat' : ['7', '8', '9'], 'red hat' : ['7', '8', '9'], 'rhel' : ['7', '8', '9'], # Red Hat 'centos' : ['7', '8'], # CentOS 'oracle' : ['7', '8'], 'ol': ['7', '8'], # Oracle 'debian' : ['8', '9', '10', '11'], # Debian 'ubuntu' : ['14.04', '16.04', '18.04', '20.04', '22.04'], # Ubuntu 'suse' : ['12', '15'], 'sles' : ['12', '15'], # SLES 'opensuse' : ['15'], # openSUSE 'rocky' : ['8', '9'], # Rocky 'alma' : ['8', '9'], # Alma 'amzn' : ['2'] # AWS } vm_dist, vm_ver, vm_supported = '', '', False parse_manually = False # platform commands used below aren't available after Python 3.6 if sys.version_info < (3,7): try: vm_dist, vm_ver, vm_id = platform.linux_distribution() except AttributeError: try: vm_dist, vm_ver, vm_id = platform.dist() except AttributeError: hutil_log_info("Falling back to /etc/os-release distribution parsing") # Some python versions *IF BUILT LOCALLY* (ex 3.5) give string responses (ex. 'bullseye/sid') to platform.dist() function # This causes exception in the method below. 
Thus adding a check to switch to manual parsing in this case try: temp_vm_ver = int(vm_ver.split('.')[0]) except: parse_manually = True else: parse_manually = True # Fallback if either of the above platform commands fail, or we switch to manual parsing if (not vm_dist and not vm_ver) or parse_manually: try: with open('/etc/os-release', 'r') as fp: for line in fp: if line.startswith('ID='): vm_dist = line.split('=')[1] vm_dist = vm_dist.split('-')[0] vm_dist = vm_dist.replace('\"', '').replace('\n', '') elif line.startswith('VERSION_ID='): vm_ver = line.split('=')[1] vm_ver = vm_ver.replace('\"', '').replace('\n', '') except: return vm_supported, 'Indeterminate operating system', '' # Find this VM distribution in the supported list for supported_dist in list(supported_dists.keys()): if not vm_dist.lower().startswith(supported_dist): continue # Check if this VM distribution version is supported vm_ver_split = vm_ver.split('.') for supported_ver in supported_dists[supported_dist]: supported_ver_split = supported_ver.split('.') # If vm_ver is at least as precise (at least as many digits) as # supported_ver and matches all the supported_ver digits, then # this VM is guaranteed to be supported vm_ver_match = True for idx, supported_ver_num in enumerate(supported_ver_split): try: supported_ver_num = int(supported_ver_num) vm_ver_num = int(vm_ver_split[idx]) except IndexError: vm_ver_match = False break if vm_ver_num is not supported_ver_num: vm_ver_match = False break if vm_ver_match: vm_supported = True break if vm_supported: break return vm_supported, vm_dist, vm_ver def exit_if_vm_not_supported(operation): """ Check if this VM distro and version are supported by the OMSAgent. If this VM is not supported, log the proper error code and exit. 
""" vm_supported, vm_dist, vm_ver = is_vm_supported_for_extension() if not vm_supported: log_and_exit(operation, UnsupportedOperatingSystem, 'Unsupported operating system: ' \ '{0} {1}'.format(vm_dist, vm_ver)) return 0 def exit_if_openssl_unavailable(operation): """ Check if the openssl commandline interface is available to use If not, throw error to return UnsupportedOpenSSL error code """ exit_code, output = run_get_output('which openssl', True, False) if exit_code != 0: log_and_exit(operation, UnsupportedOpenSSL, 'OpenSSL is not available') return 0 def exit_if_gpg_unavailable(operation): """ Check if gpg is available to use If not, attempt to install If install fails, throw error to return UnsupportedGpg error code """ # Check if VM is Debian (Debian 10 doesn't have gpg) vm_supp, vm_dist, _ = is_vm_supported_for_extension() if (vm_supp and (vm_dist.lower().startswith('debian'))): # Check if GPG already on VM check_exit_code, _ = run_get_output('which gpg', True, False) if check_exit_code != 0: # GPG not on VM, attempt to install hutil_log_info('GPG not found, attempting to install') exit_code, output = run_get_output(InstallExtraPackageCommandApt.format('gpg')) if exit_code != 0: log_and_exit(operation, UnsupportedGpg, 'GPG could not be installed: {0}'.format(output)) else: hutil_log_info('GPG successfully installed') else: hutil_log_info('GPG already present on VM') return 0 def check_workspace_id_and_key(workspace_id, workspace_key): """ Validate formats of workspace_id and workspace_key """ check_workspace_id(workspace_id) # Validate that workspace_key is of the correct format (base64-encoded) if workspace_key is None: raise ParameterMissingException('Workspace key must be provided') try: encoded_key = base64.b64encode(base64.b64decode(workspace_key)) if sys.version_info >= (3,): # in python 3, base64.b64encode will return bytes, so decode to str for comparison encoded_key = encoded_key.decode() if encoded_key != workspace_key: raise 
InvalidParameterError('Workspace key is invalid') except TypeError: raise InvalidParameterError('Workspace key is invalid') def check_workspace_id(workspace_id): """ Validate that workspace_id matches the GUID regex """ if workspace_id is None: raise ParameterMissingException('Workspace ID must be provided') search = re.compile(GUIDOnlyRegex, re.M) if not search.match(workspace_id): raise InvalidParameterError('Workspace ID is invalid') def detect_multiple_connections(workspace_id): """ If the VM already has a workspace/SCOM configured, then we should disallow a new connection when stopOnMultipleConnections is used Throw an exception in these cases: - The workspace with the given workspace_id has not been onboarded to the VM, but at least one other workspace has been - The workspace with the given workspace_id has not been onboarded to the VM, and the VM is connected to SCOM If the extension operation is connecting to an already-configured workspace, it is not a stopping case """ other_connection_exists = False if os.path.exists(OMSAdminPath): exit_code, utfoutput = run_get_output(WorkspaceCheckCommand, chk_err = False) # output may contain unicode characters not supported by ascii # for e.g., generates the following error if used without conversion: UnicodeDecodeError: 'ascii' codec can't decode byte 0xc3 in position 18: ordinal not in range(128) # default encoding in python is ascii in python < 3 if sys.version_info < (3,): output = utfoutput.decode('utf8').encode('utf8') else: output = utfoutput if output.strip().lower() != 'no workspace': for line in output.split('\n'): if workspace_id in line: hutil_log_info('The workspace to be enabled has already ' \ 'been configured on the VM before; ' \ 'continuing despite ' \ 'stopOnMultipleConnections flag') return else: # Note: if scom workspace dir is created, a line containing # "Workspace(SCOM Workspace): scom" will be here # If any other line is here, it may start sending data later other_connection_exists = True 
    else:
        # omsadmin.sh is not installed yet; inspect the config directories on
        # disk for previously-onboarded workspaces instead
        for dir_name, sub_dirs, files in os.walk(EtcOMSAgentPath):
            for sub_dir in sub_dirs:
                sub_dir_name = os.path.basename(sub_dir)
                workspace_search = re.compile(GUIDOnlyRegex, re.M)
                if sub_dir_name == workspace_id:
                    hutil_log_info('The workspace to be enabled has already ' \
                                   'been configured on the VM before; ' \
                                   'continuing despite ' \
                                   'stopOnMultipleConnections flag')
                    return
                elif (workspace_search.match(sub_dir_name)
                        or sub_dir_name == 'scom'):
                    other_connection_exists = True

    if other_connection_exists:
        err_msg = ('This machine is already connected to some other Log ' \
                   'Analytics workspace, please set ' \
                   'stopOnMultipleConnections to false in public ' \
                   'settings or remove this property, so this machine ' \
                   'can connect to new workspaces, also it means this ' \
                   'machine will get billed multiple times for each ' \
                   'workspace it report to. ' \
                   '(LINUXOMSAGENTEXTENSION_ERROR_MULTIPLECONNECTIONS)')
        # This exception will get caught by the main method
        raise UnwantedMultipleConnectionsException(err_msg)
    else:
        detect_scom_connection()

def detect_scom_connection():
    """
    If these two conditions are met, then we can assume the VM is monitored
    by SCOM:
    1. SCOMPort is open and omiserver is listening on it
    2. scx certificate is signed by SCOM server

    To determine it check for existence of below two conditions:
    1. SCOMPort is open and omiserver is listening on it: /etc/omi/conf/omiserver.conf can be parsed to determine it.
    2. scx certificate is signed by SCOM server: scom cert is present @ /etc/opt/omi/ssl/omi-host-<hostname>.pem (/etc/opt/microsoft/scx/ssl/scx.pem is a softlink to this).
If the VM is monitored by SCOM then issuer field of the certificate will have a value like CN=SCX-Certificate/title=<GUID>, DC=<SCOM server hostname> (e.g CN=SCX-Certificate/title=SCX94a1f46d-2ced-4739-9b6a-1f06156ca4ac, DC=NEB-OM-1502733) Otherwise, if a scom configuration directory has been created, we assume SCOM is in use """ scom_port_open = None # return when determine this is false cert_signed_by_scom = False if os.path.exists(OMSAdminPath): scom_port_open = detect_scom_using_omsadmin() if scom_port_open is False: return # If omsadmin.sh option is not available, use omiconfigeditor if (scom_port_open is None and os.path.exists(OMIConfigEditorPath) and os.path.exists(OMIServerConfPath)): scom_port_open = detect_scom_using_omiconfigeditor() if scom_port_open is False: return # If omiconfigeditor option is not available, directly parse omiserver.conf if scom_port_open is None and os.path.exists(OMIServerConfPath): scom_port_open = detect_scom_using_omiserver_conf() if scom_port_open is False: return if scom_port_open is None: hutil_log_info('SCOM port could not be determined to be open') return # Parse the certificate to determine if SCOM issued it if os.path.exists(SCOMCertPath): exit_if_openssl_unavailable('Install') cert_cmd = 'openssl x509 -in {0} -noout -text'.format(SCOMCertPath) cert_exit_code, cert_output = run_get_output(cert_cmd, chk_err = False, log_cmd = False) if cert_exit_code == 0: issuer_re = re.compile(SCOMCertIssuerRegex, re.M) if issuer_re.search(cert_output): hutil_log_info('SCOM cert exists and is signed by SCOM server') cert_signed_by_scom = True else: hutil_log_info('SCOM cert exists but is not signed by SCOM ' \ 'server') else: hutil_log_error('Error reading SCOM cert; cert could not be ' \ 'determined to be signed by SCOM server') else: hutil_log_info('SCOM cert does not exist') if scom_port_open and cert_signed_by_scom: err_msg = ('This machine may already be connected to a System ' \ 'Center Operations Manager server. 
Please set ' \ 'stopOnMultipleConnections to false in public settings ' \ 'or remove this property to allow connection to the Log ' \ 'Analytics workspace. ' \ '(LINUXOMSAGENTEXTENSION_ERROR_MULTIPLECONNECTIONS)') raise UnwantedMultipleConnectionsException(err_msg) def detect_scom_using_omsadmin(): """ This method assumes that OMSAdminPath exists; if packages have not been installed yet, this may not exist Returns True if omsadmin.sh indicates that SCOM port is open """ omsadmin_cmd = '{0} -o'.format(OMSAdminPath) exit_code, output = run_get_output(omsadmin_cmd, False, False) # Guard against older omsadmin.sh versions if ('illegal option' not in output.lower() and 'unknown option' not in output.lower()): if exit_code == 0: hutil_log_info('According to {0}, SCOM port is ' \ 'open'.format(omsadmin_cmd)) return True elif exit_code == 1: hutil_log_info('According to {0}, SCOM port is not ' \ 'open'.format(omsadmin_cmd)) return False def detect_scom_using_omiconfigeditor(): """ This method assumes that the relevant files exist Returns True if omiconfigeditor indicates that SCOM port is open """ omi_cmd = '{0} httpsport -q {1} < {2}'.format(OMIConfigEditorPath, SCOMPort, OMIServerConfPath) exit_code, output = run_get_output(omi_cmd, False, False) # Guard against older omiconfigeditor versions if ('illegal option' not in output.lower() and 'unknown option' not in output.lower()): if exit_code == 0: hutil_log_info('According to {0}, SCOM port is ' \ 'open'.format(omi_cmd)) return True elif exit_code == 1: hutil_log_info('According to {0}, SCOM port is not ' \ 'open'.format(omi_cmd)) return False def detect_scom_using_omiserver_conf(): """ This method assumes that the relevant files exist Returns True if omiserver.conf indicates that SCOM port is open """ with open(OMIServerConfPath, 'r') as omiserver_file: omiserver_txt = omiserver_file.read() httpsport_search = r'^[\s]*httpsport[\s]*=(.*)$' httpsport_re = re.compile(httpsport_search, re.M) httpsport_matches = 
httpsport_re.search(omiserver_txt) if (httpsport_matches is not None and httpsport_matches.group(1) is not None): ports = httpsport_matches.group(1) ports = ports.replace(',', ' ') ports_list = ports.split(' ') if str(SCOMPort) in ports_list: hutil_log_info('SCOM port is listed in ' \ '{0}'.format(OMIServerConfPath)) return True else: hutil_log_info('SCOM port is not listed in ' \ '{0}'.format(OMIServerConfPath)) else: hutil_log_info('SCOM port is not listed in ' \ '{0}'.format(OMIServerConfPath)) return False def run_command_and_log(cmd, check_error = True, log_cmd = True): """ Run the provided shell command and log its output, including stdout and stderr. The output should not contain any PII, but the command might. In this case, log_cmd should be set to False. """ exit_code, output = run_get_output(cmd, check_error, log_cmd) if log_cmd: hutil_log_info('Output of command "{0}": \n{1}'.format(cmd.rstrip(), output)) else: hutil_log_info('Output: \n{0}'.format(output)) # For details, check logs in /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log if exit_code == 17: if "Failed dependencies:" in output: # 52 is the exit code for missing dependency # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr exit_code = 52 output = "Installation failed due to missing dependencies. For details, check logs in /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log" elif "waiting for transaction lock" in output or "dpkg: error processing package systemd" in output or "dpkg-deb" in output or "dpkg:" in output: # 52 is the exit code for missing dependency # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr exit_code = 52 output = "There seems to be an issue in your package manager dpkg or rpm being in lock state. 
For details, check logs in /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log" elif "Errors were encountered while processing:" in output: # 52 is the exit code for missing dependency # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr exit_code = 52 output = "There seems to be an issue while processing triggers in systemd. For details, check logs in /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log" elif "Cannot allocate memory" in output: # 52 is the exit code for missing dependency # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr exit_code = 52 output = "There seems to be insufficient memory for the installation. For details, check logs in /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log" elif exit_code == 19: if "rpmdb" in output or "cannot open Packages database" in output or "dpkg (subprocess): cannot set security execution context for maintainer script" in output or "is locked by another process" in output: # OMI (19) happens to be the first package we install and if we get rpmdb failures, its a system issue # 52 is the exit code for missing dependency i.e. rpmdb, libc6 or libpam-runtime # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr exit_code = 52 output = "There seems to be an issue in your package manager dpkg or rpm being in lock state. 
For details, check logs in /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log" elif "libc6 is not installed" in output or "libpam-runtime is not installed" in output or "exited with status 52" in output or "/bin/sh is needed" in output: # OMI (19) happens to be the first package we install and if we get rpmdb failures, its a system issue # 52 is the exit code for missing dependency i.e. rpmdb, libc6 or libpam-runtime # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr exit_code = 52 output = "Installation failed due to missing dependencies. For details, check logs in /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log" elif exit_code == 33: if "Permission denied" in output: # Enable failures # 52 is the exit code for missing dependency. # DSC metaconfig generation failure due to permissions. # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr exit_code = 52 output = "Installation failed due to insufficient permissions. Please ensure omsagent user is part of the sudoer file and has sufficient permissions, and omsconfig MetaConfig.mof can be generated. For details, check logs in /var/opt/microsoft/omsconfig/omsconfig.log and /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log" elif exit_code == 18: # Install failures # DSC install failure # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr output = "Installation failed due to omsconfig package not being able to install. 
For details, check logs in /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log" elif exit_code == 5: if "Reason: InvalidWorkspaceKey" in output or "Reason: MissingHeader" in output: # Enable failures # 53 is the exit code for configuration errors # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr exit_code = 53 output = "Installation failed due to incorrect workspace key. Please check if the workspace key is correct. For details, check logs in /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log" elif exit_code == 8: if "Check the correctness of the workspace ID and shared key" in output or "internet connectivity" in output: # Enable failures # 53 is the exit code for configuration errors # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr exit_code = 53 output = "Installation failed due to curl error while onboarding. Please check the internet connectivity or the workspace key. For details, check logs in /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log" if exit_code != 0 and exit_code != 52: if "dpkg:" in output or "dpkg :" in output or "rpmdb:" in output or "rpm.lock" in output or "locked by another process" in output: # If we get rpmdb failures, its a system issue. # 52 is the exit code for missing dependency i.e. rpmdb, libc6 or libpam-runtime # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr exit_code = 52 output = "There seems to be an issue in your package manager dpkg or rpm being in lock state when installing omsagent bundle for one of the dependencies. 
For details, check logs in /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log" if "conflicts with file from package" in output or "Failed dependencies:" in output or "Please install curl" in output or "is needed by" in output or "check_version_installable" in output or "Error: curl was not installed" in output or "Please install the ctypes package" in output or "gpg is not installed" in output: # If we get rpmdb failures, its a system issue # 52 is the exit code for missing dependency i.e. rpmdb, libc6 or libpam-runtime # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr exit_code = 52 output = "Installation failed due to missing dependencies. For details, check logs in /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log" if "Permission denied" in output: # Install/Enable failures # 52 is the exit code for missing dependency. # https://github.com/Azure/azure-marketplace/wiki/Extension-Build-Notes-Best-Practices#error-codes-and-messages-output-to-stderr exit_code = 52 output = "Installation failed due to insufficient permissions. Please ensure omsagent user is part of the sudoer file and has sufficient permissions to install and onboard. For details, check logs in /var/log/azure/Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux/extension.log" return exit_code, output def run_command_with_retries(cmd, retries, retry_check, final_check = None, check_error = True, log_cmd = True, initial_sleep_time = InitialRetrySleepSeconds, sleep_increase_factor = 1): """ Caller provides a method, retry_check, to use to determine if a retry should be performed. 
This must be a function with two parameters: exit_code and output The final_check can be provided as a method to perform a final check after retries have been exhausted Logic used: will retry up to retries times with initial_sleep_time in between tries If the retry_check returns True for retry_verbosely, we will try cmd with the standard -v verbose flag added """ try_count = 0 sleep_time = initial_sleep_time run_cmd = cmd run_verbosely = False while try_count <= retries: if run_verbosely: run_cmd = cmd + ' -v' exit_code, output = run_command_and_log(run_cmd, check_error, log_cmd) should_retry, retry_message, run_verbosely = retry_check(exit_code, output) if not should_retry: break try_count += 1 hutil_log_info(retry_message) time.sleep(sleep_time) sleep_time *= sleep_increase_factor if final_check is not None: exit_code = final_check(exit_code, output) return exit_code def run_command_with_retries_output(cmd, retries, retry_check, final_check = None, check_error = True, log_cmd = True, initial_sleep_time = InitialRetrySleepSeconds, sleep_increase_factor = 1): """ Caller provides a method, retry_check, to use to determine if a retry should be performed. 
This must be a function with two parameters: exit_code and output The final_check can be provided as a method to perform a final check after retries have been exhausted Logic used: will retry up to retries times with initial_sleep_time in between tries If the retry_check retuns True for retry_verbosely, we will try cmd with the standard -v verbose flag added """ try_count = 0 sleep_time = initial_sleep_time run_cmd = cmd run_verbosely = False while try_count <= retries: if run_verbosely: run_cmd = cmd + ' -v' exit_code, output = run_command_and_log(run_cmd, check_error, log_cmd) should_retry, retry_message, run_verbosely = retry_check(exit_code, output) if not should_retry: break try_count += 1 hutil_log_info(retry_message) time.sleep(sleep_time) sleep_time *= sleep_increase_factor if final_check is not None: exit_code = final_check(exit_code, output) return exit_code, output def is_dpkg_locked(exit_code, output): """ If dpkg is locked, the output will contain a message similar to 'dpkg status database is locked by another process' """ if exit_code != 0: dpkg_locked_search = r'^.*dpkg.+lock.*$' dpkg_locked_re = re.compile(dpkg_locked_search, re.M) if dpkg_locked_re.search(output): return True return False def was_curl_found(exit_code, output): """ Returns false if exit_code indicates that curl was not installed; this can occur when package lists need to be updated, or when some archives are out-of-date """ if exit_code is InstallErrorCurlNotInstalled: return False return True def retry_skip(exit_code, output): """ skip retires """ return False, '', False def retry_if_dpkg_locked_or_curl_is_not_found(exit_code, output): """ Some commands fail because the package manager is locked (apt-get/dpkg only); this will allow retries on failing commands. Sometimes curl's dependencies (i.e. 
libcurl) are not installed; if this is the case on a VM with apt-get, 'apt-get -f install' should be run Sometimes curl is not installed and is also not found in the package list; if this is the case on a VM with apt-get, update the package list """ retry_verbosely = False dpkg_locked = is_dpkg_locked(exit_code, output) curl_found = was_curl_found(exit_code, output) apt_get_exit_code, apt_get_output = run_get_output('which apt-get', chk_err = False, log_cmd = False) if dpkg_locked: return True, 'Retrying command because package manager is locked.', \ retry_verbosely elif (not curl_found and apt_get_exit_code == 0 and ('apt-get -f install' in output or 'Unmet dependencies' in output.lower())): hutil_log_info('Installing all dependencies of curl:') run_command_and_log('apt-get -f install') return True, 'Retrying command because curl and its dependencies ' \ 'needed to be installed', retry_verbosely elif not curl_found and apt_get_exit_code == 0: hutil_log_info('Updating package lists to make curl available') run_command_and_log('apt-get update') return True, 'Retrying command because package lists needed to be ' \ 'updated', retry_verbosely else: return False, '', False def final_check_if_dpkg_locked(exit_code, output): """ If dpkg is still locked after the retries, we want to return a specific error code """ dpkg_locked = is_dpkg_locked(exit_code, output) if dpkg_locked: exit_code = DPKGLockedErrorCode return exit_code def retry_onboarding(exit_code, output): """ Retry under any of these conditions: - If the onboarding request returns 403: this may indicate that the agent GUID and certificate should be re-generated - If the onboarding request returns a different non-200 code: the OMS service may be temporarily unavailable - If the onboarding curl command returns an unaccounted-for error code, we should retry with verbose logging """ retry_verbosely = False if exit_code is EnableErrorOMSReturned403: return True, 'Retrying the onboarding command to attempt generating 
' \ 'a new agent ID and certificate.', retry_verbosely elif exit_code is EnableErrorOMSReturnedNon200: return True, 'Retrying; the OMS service may be temporarily ' \ 'unavailable.', retry_verbosely elif exit_code is EnableErrorOnboarding: return True, 'Retrying with verbose logging.', True return False, '', False def raise_if_no_internet(exit_code, output): """ Raise the CannotConnectToOMSException exception if the onboarding script returns the error code to indicate that the OMS service can't be resolved """ if exit_code is EnableErrorResolvingHost: raise CannotConnectToOMSException return exit_code def get_settings(): """ Retrieve the configuration for this extension operation """ global SettingsDict public_settings = None protected_settings = None if HUtilObject is not None: public_settings = HUtilObject.get_public_settings() protected_settings = HUtilObject.get_protected_settings() elif SettingsDict is not None: public_settings = SettingsDict['public_settings'] protected_settings = SettingsDict['protected_settings'] else: SettingsDict = {} handler_env = get_handler_env() try: config_dir = str(handler_env['handlerEnvironment']['configFolder']) except: config_dir = os.path.join(os.getcwd(), 'config') seq_no = get_latest_seq_no() settings_path = os.path.join(config_dir, '{0}.settings'.format(seq_no)) try: with open(settings_path, 'r') as settings_file: settings_txt = settings_file.read() settings = json.loads(settings_txt) h_settings = settings['runtimeSettings'][0]['handlerSettings'] public_settings = h_settings['publicSettings'] SettingsDict['public_settings'] = public_settings except: hutil_log_error('Unable to load handler settings from ' \ '{0}'.format(settings_path)) if ('protectedSettings' in h_settings and 'protectedSettingsCertThumbprint' in h_settings and h_settings['protectedSettings'] is not None and h_settings['protectedSettingsCertThumbprint'] is not None): encoded_settings = h_settings['protectedSettings'] settings_thumbprint = 
h_settings['protectedSettingsCertThumbprint'] encoded_cert_path = os.path.join('/var/lib/waagent', '{0}.crt'.format( settings_thumbprint)) encoded_key_path = os.path.join('/var/lib/waagent', '{0}.prv'.format( settings_thumbprint)) decoded_settings = base64.standard_b64decode(encoded_settings) decrypt_cmd = 'openssl smime -inform DER -decrypt -recip {0} ' \ '-inkey {1}'.format(encoded_cert_path, encoded_key_path) try: session = subprocess.Popen([decrypt_cmd], shell = True, stdin = subprocess.PIPE, stderr = subprocess.STDOUT, stdout = subprocess.PIPE) output = session.communicate(decoded_settings) except OSError: pass protected_settings_str = output[0] if protected_settings_str is None: log_and_exit('Enable', 1, 'Failed decrypting ' \ 'protectedSettings') protected_settings = '' try: protected_settings = json.loads(protected_settings_str) except: hutil_log_error('JSON exception decoding protected settings') SettingsDict['protected_settings'] = protected_settings return public_settings, protected_settings def update_status_file(operation, exit_code, exit_status, message): """ Mimic HandlerUtil method do_status_report in case hutil method is not available Write status to status file """ handler_env = get_handler_env() try: extension_version = str(handler_env['version']) status_dir = str(handler_env['handlerEnvironment']['statusFolder']) except: extension_version = "1.0" status_dir = os.path.join(os.getcwd(), 'status') status_txt = [{ "version" : extension_version, "timestampUTC" : time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), "status" : { "name" : "Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux", "operation" : operation, "status" : exit_status, "code" : exit_code, "formattedMessage" : { "lang" : "en-US", "message" : message } } }] status_json = json.dumps(status_txt) # Find the most recently changed config file and then use the # corresponding status file latest_seq_no = get_latest_seq_no() status_path = os.path.join(status_dir, 
'{0}.status'.format(latest_seq_no)) status_tmp = '{0}.tmp'.format(status_path) with open(status_tmp, 'w+') as tmp_file: tmp_file.write(status_json) os.rename(status_tmp, status_path) def get_handler_env(): """ Set and retrieve the contents of HandlerEnvironment.json as JSON """ global HandlerEnvironment if HandlerEnvironment is None: handler_env_path = os.path.join(os.getcwd(), 'HandlerEnvironment.json') try: with open(handler_env_path, 'r') as handler_env_file: handler_env_txt = handler_env_file.read() handler_env = json.loads(handler_env_txt) if type(handler_env) == list: handler_env = handler_env[0] HandlerEnvironment = handler_env except Exception as e: waagent_log_error(str(e)) return HandlerEnvironment def get_latest_seq_no(): """ Determine the latest operation settings number to use """ global SettingsSequenceNumber if SettingsSequenceNumber is None: handler_env = get_handler_env() try: config_dir = str(handler_env['handlerEnvironment']['configFolder']) except: config_dir = os.path.join(os.getcwd(), 'config') latest_seq_no = -1 cur_seq_no = -1 latest_time = None try: for dir_name, sub_dirs, file_names in os.walk(config_dir): for file_name in file_names: file_basename = os.path.basename(file_name) match = re.match(r'[0-9]{1,10}\.settings', file_basename) if match is None: continue cur_seq_no = int(file_basename.split('.')[0]) file_path = os.path.join(config_dir, file_name) cur_time = os.path.getmtime(file_path) if latest_time is None or cur_time > latest_time: latest_time = cur_time latest_seq_no = cur_seq_no except: pass if latest_seq_no < 0: latest_seq_no = 0 SettingsSequenceNumber = latest_seq_no return SettingsSequenceNumber def run_get_output(cmd, chk_err = False, log_cmd = True): """ Mimic waagent mothod RunGetOutput in case waagent is not available Run shell command and return exit code and output """ if 'Utils.WAAgentUtil' in sys.modules: # WALinuxAgent-2.0.14 allows only 2 parameters for RunGetOutput # If checking the number of parameters fails, pass 
2 try: sig = inspect.signature(waagent.RunGetOutput) params = sig.parameters waagent_params = len(params) except: try: spec = inspect.getargspec(waagent.RunGetOutput) params = spec.args waagent_params = len(params) except: waagent_params = 2 if waagent_params >= 3: exit_code, output = waagent.RunGetOutput(cmd, chk_err, log_cmd) else: exit_code, output = waagent.RunGetOutput(cmd, chk_err) else: try: output = subprocess.check_output(cmd, stderr = subprocess.STDOUT, shell = True) output = output.decode('latin-1') exit_code = 0 except subprocess.CalledProcessError as e: exit_code = e.returncode output = e.output.decode('latin-1') output = output.encode('utf-8', 'ignore') # On python 3, encode returns a byte object, so we must decode back to a string if sys.version_info >= (3,): output = output.decode() return exit_code, output.strip() def get_tenant_id_from_metadata_api(vm_resource_id): """ Retrieve the Tenant ID using the Metadata API of the VM resource ID Since we have not authenticated, the Metadata API will throw a 401, but the headers of the 401 response will contain the tenant ID """ tenant_id = None metadata_endpoint = get_metadata_api_endpoint(vm_resource_id) metadata_request = urllib.Request(metadata_endpoint) try: # This request should fail with code 401 metadata_response = urllib.urlopen(metadata_request) hutil_log_info('Request to Metadata API did not fail as expected; ' \ 'attempting to use headers from response to ' \ 'determine Tenant ID') metadata_headers = metadata_response.headers except urlerror.HTTPError as e: metadata_headers = e.headers if metadata_headers is not None and 'WWW-Authenticate' in metadata_headers: auth_header = metadata_headers['WWW-Authenticate'] auth_header_regex = r'authorization_uri=\"https:\/\/login\.windows\.net/(' + GUIDRegex + ')\"' auth_header_search = re.compile(auth_header_regex) auth_header_matches = auth_header_search.search(auth_header) if not auth_header_matches: raise MetadataAPIException('The WWW-Authenticate header 
in the ' \ 'response does not contain expected ' \ 'authorization_uri format') else: tenant_id = auth_header_matches.group(1) else: raise MetadataAPIException('Expected information from Metadata API ' \ 'is not present') return tenant_id def get_metadata_api_endpoint(vm_resource_id): """ Extrapolate Metadata API endpoint from VM Resource ID Example VM resource ID: /subscriptions/306ee7f1-3d0a-4605-9f39-ff253cc02708/resourceGroups/LinuxExtVMResourceGroup/providers/Microsoft.Compute/virtualMachines/lagalbraOCUb16C Corresponding example endpoint: https://management.azure.com/subscriptions/306ee7f1-3d0a-4605-9f39-ff253cc02708/resourceGroups/LinuxExtVMResourceGroup?api-version=2016-09-01 """ # Will match for ARM and Classic VMs, Availability Sets, VM Scale Sets vm_resource_id_regex = r'^\/subscriptions\/(' + GUIDRegex + ')\/' \ 'resourceGroups\/([^\/]+)\/providers\/Microsoft' \ '\.(?:Classic){0,1}Compute\/(?:virtualMachines|' \ 'availabilitySets|virtualMachineScaleSets)' \ '\/[^\/]+$' vm_resource_id_search = re.compile(vm_resource_id_regex, re.M) vm_resource_id_matches = vm_resource_id_search.search(vm_resource_id) if not vm_resource_id_matches: raise InvalidParameterError('VM Resource ID is invalid') else: subscription_id = vm_resource_id_matches.group(1) resource_group = vm_resource_id_matches.group(2) metadata_url = 'https://management.azure.com/subscriptions/{0}' \ '/resourceGroups/{1}'.format(subscription_id, resource_group) metadata_data = urlparse.urlencode({'api-version' : '2016-09-01'}) metadata_endpoint = '{0}?{1}'.format(metadata_url, metadata_data) return metadata_endpoint def get_access_token(tenant_id, resource): """ Retrieve an OAuth token by sending an OAuth2 token exchange request to the local URL that the ManagedIdentity extension is listening to """ # Extract the endpoint that the ManagedIdentity extension is listening on with open(ManagedIdentityExtListeningURLPath, 'r') as listening_file: listening_settings_txt = listening_file.read() try: 
listening_settings = json.loads(listening_settings_txt) listening_url = listening_settings['url'] except: raise ManagedIdentityExtException('Could not extract listening URL ' \ 'from settings file') # Send an OAuth token exchange request oauth_data = {'authority' : 'https://login.microsoftonline.com/' \ '{0}'.format(tenant_id), 'resource' : resource } oauth_request = urllib.Request(listening_url + '/oauth2/token', urlparse.urlencode(oauth_data)) oauth_request.add_header('Metadata', 'true') try: oauth_response = urllib.urlopen(oauth_request) oauth_response_txt = oauth_response.read() except urlerror.HTTPError as e: hutil_log_error('Request to ManagedIdentity extension listening URL ' \ 'failed with an HTTPError: {0}'.format(e)) hutil_log_info('Response from ManagedIdentity extension: ' \ '{0}'.format(e.read())) raise ManagedIdentityExtException('Request to listening URL failed ' \ 'with HTTPError {0}'.format(e)) except: raise ManagedIdentityExtException('Unexpected error from request to ' \ 'listening URL') try: oauth_response_json = json.loads(oauth_response_txt) except: raise ManagedIdentityExtException('Error parsing JSON from ' \ 'listening URL response') if (oauth_response_json is not None and 'access_token' in oauth_response_json): return oauth_response_json['access_token'] else: raise ManagedIdentityExtException('Could not retrieve access token ' \ 'in the listening URL response') def get_workspace_info_from_oms(vm_resource_id, tenant_id, access_token): """ Send a request to the OMS service with the VM information to determine the workspace the OMSAgent should onboard to """ oms_data = {'ResourceId' : vm_resource_id, 'TenantId' : tenant_id, 'JwtToken' : access_token } oms_request_json = json.dumps(oms_data) oms_request = urllib.Request(OMSServiceValidationEndpoint) oms_request.add_header('Content-Type', 'application/json') retries = 5 initial_sleep_time = AutoManagedWorkspaceCreationSleepSeconds sleep_increase_factor = 1 try_count = 0 sleep_time = 
initial_sleep_time # Workspace may not be provisioned yet; sleep and retry if # provisioning has been accepted while try_count <= retries: try: oms_response = urllib.urlopen(oms_request, oms_request_json) oms_response_txt = oms_response.read() except urlerror.HTTPError as e: hutil_log_error('Request to OMS threw HTTPError: {0}'.format(e)) hutil_log_info('Response from OMS: {0}'.format(e.read())) raise OMSServiceOneClickException('ValidateMachineIdentity ' \ 'request returned an error ' \ 'HTTP code: {0}'.format(e)) except: raise OMSServiceOneClickException('Unexpected error from ' \ 'ValidateMachineIdentity ' \ 'request') should_retry = retry_get_workspace_info_from_oms(oms_response) if not should_retry: # TESTED break elif try_count == retries: # TESTED hutil_log_error('Retries for ValidateMachineIdentity request ran ' \ 'out: required workspace information cannot be ' \ 'extracted') raise OneClickException('Workspace provisioning did not complete ' \ 'within the allotted time') # TESTED try_count += 1 time.sleep(sleep_time) sleep_time *= sleep_increase_factor if not oms_response_txt: raise OMSServiceOneClickException('Body from ValidateMachineIdentity ' \ 'response is empty; required ' \ 'workspace information cannot be ' \ 'extracted') try: oms_response_json = json.loads(oms_response_txt) except: raise OMSServiceOneClickException('Error parsing JSON from ' \ 'ValidateMachineIdentity response') if (oms_response_json is not None and 'WorkspaceId' in oms_response_json and 'WorkspaceKey' in oms_response_json): return oms_response_json else: hutil_log_error('Could not retrieve both workspace ID and key from ' \ 'the OMS service response {0}; cannot determine ' \ 'workspace ID and key'.format(oms_response_json)) raise OMSServiceOneClickException('Required workspace information ' \ 'was not found in the ' \ 'ValidateMachineIdentity response') def retry_get_workspace_info_from_oms(oms_response): """ Return True to retry if the response from OMS for the 
ValidateMachineIdentity request incidates that the request has been accepted, but the managed workspace is still being provisioned """ try: oms_response_http_code = oms_response.getcode() except: hutil_log_error('Unable to get HTTP code from OMS repsonse') return False if (oms_response_http_code == 202 or oms_response_http_code == 204 or oms_response_http_code == 404): hutil_log_info('Retrying ValidateMachineIdentity OMS request ' \ 'because workspace is still being provisioned; HTTP ' \ 'code from OMS is {0}'.format(oms_response_http_code)) return True else: hutil_log_info('Workspace is provisioned; HTTP code from OMS is ' \ '{0}'.format(oms_response_http_code)) return False def init_waagent_logger(): """ Initialize waagent logger If waagent has not been imported, catch the exception """ try: waagent.LoggerInit('/var/log/waagent.log', '/dev/stdout', True) except Exception as e: print('Unable to initialize waagent log because of exception ' \ '{0}'.format(e)) def waagent_log_info(message): """ Log informational message, being cautious of possibility that waagent may not be imported """ if 'Utils.WAAgentUtil' in sys.modules: waagent.Log(message) else: print('Info: {0}'.format(message)) def waagent_log_error(message): """ Log error message, being cautious of possibility that waagent may not be imported """ if 'Utils.WAAgentUtil' in sys.modules: waagent.Error(message) else: print('Error: {0}'.format(message)) def hutil_log_info(message): """ Log informational message, being cautious of possibility that hutil may not be imported and configured """ if HUtilObject is not None: HUtilObject.log(message) else: print('Info: {0}'.format(message)) def hutil_log_error(message): """ Log error message, being cautious of possibility that hutil may not be imported and configured """ if HUtilObject is not None: HUtilObject.error(message) else: print('Error: {0}'.format(message)) def log_and_exit(operation, exit_code = 1, message = ''): """ Log the exit message and perform the exit 
""" if exit_code == 0: waagent_log_info(message) hutil_log_info(message) exit_status = 'success' else: waagent_log_error(message) hutil_log_error(message) exit_status = 'failed' if HUtilObject is not None: HUtilObject.do_exit(exit_code, operation, exit_status, str(exit_code), message) else: update_status_file(operation, str(exit_code), exit_status, message) sys.exit(exit_code) # Exceptions # If these exceptions are expected to be caught by the main method, they # include an error_code field with an integer with which to exit from main class OmsAgentForLinuxException(Exception): """ Base exception class for all exceptions; as such, its error code is the basic error code traditionally returned in Linux: 1 """ error_code = 1 def get_error_message(self, operation): """ Return a descriptive error message based on this type of exception """ return '{0} failed with exit code {1}'.format(operation, self.error_code) class ParameterMissingException(OmsAgentForLinuxException): """ There is a missing parameter for the OmsAgentForLinux Extension """ error_code = MissingorInvalidParameterErrorCode def get_error_message(self, operation): return '{0} failed due to a missing parameter: {1}'.format(operation, self) class InvalidParameterError(OmsAgentForLinuxException): """ There is an invalid parameter for the OmsAgentForLinux Extension ex. 
Workspace ID does not match GUID regex """ error_code = MissingorInvalidParameterErrorCode def get_error_message(self, operation): return '{0} failed due to an invalid parameter: {1}'.format(operation, self) class UnwantedMultipleConnectionsException(OmsAgentForLinuxException): """ This VM is already connected to a different Log Analytics workspace and stopOnMultipleConnections is set to true """ error_code = UnwantedMultipleConnectionsErrorCode def get_error_message(self, operation): return '{0} failed due to multiple connections: {1}'.format(operation, self) class CannotConnectToOMSException(OmsAgentForLinuxException): """ The OMSAgent cannot connect to the OMS service """ error_code = CannotConnectToOMSErrorCode # error code to indicate no internet access def get_error_message(self, operation): return 'The agent could not connect to the Microsoft Operations ' \ 'Management Suite service. Please check that the system ' \ 'either has Internet access, or that a valid HTTP proxy has ' \ 'been configured for the agent. Please also check the ' \ 'correctness of the workspace ID.' class OneClickException(OmsAgentForLinuxException): """ A generic exception for OneClick-related issues """ error_code = OneClickErrorCode def get_error_message(self, operation): return 'Encountered an issue related to the OneClick scenario: ' \ '{0}'.format(self) class ManagedIdentityExtMissingException(OneClickException): """ This extension being present is required for the OneClick scenario """ error_code = ManagedIdentityExtMissingErrorCode def get_error_message(self, operation): return 'The ManagedIdentity extension is required to be installed ' \ 'for Automatic Management to be enabled. Please set ' \ 'EnableAutomaticManagement to false in public settings or ' \ 'install the ManagedIdentityExtensionForLinux Azure VM ' \ 'extension.' 
class ManagedIdentityExtException(OneClickException): """ Thrown when we encounter an issue with ManagedIdentityExtensionForLinux """ error_code = ManagedIdentityExtErrorCode def get_error_message(self, operation): return 'Encountered an issue with the ManagedIdentity extension: ' \ '{0}'.format(self) class MetadataAPIException(OneClickException): """ Thrown when we encounter an issue with Metadata API """ error_code = MetadataAPIErrorCode def get_error_message(self, operation): return 'Encountered an issue with the Metadata API: {0}'.format(self) class OMSServiceOneClickException(OneClickException): """ Thrown when prerequisites were satisfied but could not retrieve the managed workspace information from OMS service """ error_code = OMSServiceOneClickErrorCode def get_error_message(self, operation): return 'Encountered an issue with the OMS service: ' \ '{0}'.format(self) if __name__ == '__main__' : main() ================================================ FILE: OmsAgent/omsagent.version ================================================ # Do NOT update the values here; CDPx will use the ones # defined in Build-OMS-Agent-for-Linux/omsagent.version OMS_VERSION_MAJOR=0 OMS_VERSION_MINOR=0 OMS_VERSION_PATCH_EXTENSION=0 OMS_VERSION_PATCH_SHELL_BUNDLE=0 OMS_VERSION_BUILDNR_SHELL_BUNDLE=0 OMS_VERSION_DATE=0 OMS_EXTENSION_VERSION="$OMS_VERSION_MAJOR.$OMS_VERSION_MINOR.$OMS_VERSION_PATCH_EXTENSION" OMS_SHELL_BUNDLE_VERSION="$OMS_VERSION_MAJOR.$OMS_VERSION_MINOR.$OMS_VERSION_PATCH_SHELL_BUNDLE-$OMS_VERSION_BUILDNR_SHELL_BUNDLE" ================================================ FILE: OmsAgent/omsagent_shim.sh ================================================ #!/usr/bin/env bash # The entry point for the OMS extension through which the correct python version (if any) is used to invoke omsagent.py. # We default to python2 and always invoke with the versioned python command to accomodate the RHEL 8+ python strategy. 
# Control arguments passed to the shim are redirected to omsagent.py without validation. COMMAND="./omsagent.py" PYTHON="" ARG="$@" function find_python() { local python_exec_command=$1 if command -v python2 >/dev/null 2>&1 ; then eval ${python_exec_command}="python2" elif command -v python3 >/dev/null 2>&1 ; then eval ${python_exec_command}="python3" fi } find_python PYTHON if [ -z "$PYTHON" ] then echo "No Python interpreter found, which is an OMS extension dependency. Please install either Python 2 or 3." >&2 exit 52 # Missing Dependency else ${PYTHON} --version 2>&1 fi PYTHONPATH=${PYTHONPATH} ${PYTHON} ${COMMAND} ${ARG} exit $? ================================================ FILE: OmsAgent/packaging.sh ================================================ #! /bin/bash set -e source omsagent.version usage() { local basename=`basename $0` echo "usage: ./$basename <path to omsagent-<version>.universal.x64{.sh, .sha256sums, .asc}> [path for zip output]" } input_path=$1 output_path=$2 PACKAGE_NAME="oms$OMS_EXTENSION_VERSION.zip" if [[ "$1" == "--help" ]]; then usage exit 0 elif [[ ! -d $input_path ]]; then echo "OMS files path '$input_path' not found" usage exit 1 fi if [[ "$output_path" == "" ]]; then output_path="../" fi # Packaging starts here cp -r ../Utils . cp ../Common/WALinuxAgent-2.0.16/waagent . # cleanup packages # copy shell bundle to packages/ cp $input_path/omsagent-$OMS_SHELL_BUNDLE_VERSION.universal.x64.* packages/ # sync the file copy sync if [[ -f $output_path/$PACKAGE_NAME ]]; then echo "Removing existing $PACKAGE_NAME ..." 
rm -f $output_path/$PACKAGE_NAME fi echo "Packaging extension $PACKAGE_NAME to $output_path" excluded_files="omsagent.version packaging.sh apply_version.sh update_version.sh" zip -r $output_path/$PACKAGE_NAME * -x $excluded_files "./test/*" "./extension-test/*" "./references" # cleanup newly added dir or files rm -rf Utils/ waagent ================================================ FILE: OmsAgent/references ================================================ Utils/ ================================================ FILE: OmsAgent/test/MockUtil.py ================================================ #!/usr/bin/env python # #OmsAgent extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. class MockUtil(): def __init__(self, test): self.test = test def get_log_dir(self): return "/tmp" def log(self, msg): print(msg) def error(self, msg): print(msg) def get_seq_no(self): return "0" def do_status_report(self, operation, status, status_code, message): self.test.assertNotEqual(None, message) def do_exit(self,exit_code,operation,status,code,message): self.test.assertNotEqual(None, message) ================================================ FILE: OmsAgent/test/env.py ================================================ #!/usr/bin/env python # #OmsAgent extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import os # append installer directory to sys.path root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(root) ================================================ FILE: OmsAgent/test/test_install.py ================================================ #!/usr/bin/env python # # OmsAgent extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest import env import omsagent as oa import os from MockUtil import MockUtil os.chdir(env.root) class TestInstall(unittest.TestCase): def test_install(self): hutil = MockUtil(self) oa.install(hutil) if __name__ == '__main__': unittest.main() ================================================ FILE: OmsAgent/update_version.sh ================================================ #! 
/bin/bash set -x if [[ "$1" == "--help" ]]; then echo "update_version.sh <MAJOR> <MINOR> <PATCH> <BUILDNR>" exit 0 fi UPDATE_DATE=`date +%Y%m%d` OMS_BUILDVERSION_MAJOR=$1 OMS_BUILDVERSION_MINOR=$2 OMS_BUILDVERSION_PATCH=$3 OMS_BUILDVERSION_BUILDNR=$4 if [[ "$OMS_BUILDVERSION_MAJOR" == "" ]]; then echo "MAJOR version is empty" exit 1 fi if [[ "$OMS_BUILDVERSION_MINOR" == "" ]]; then echo "MINOR version is empty" exit 1 fi if [[ "$OMS_BUILDVERSION_PATCH" == "" ]]; then echo "PATH version is empty" exit 1 fi if [[ "$OMS_BUILDVERSION_BUILDNR" == "" ]]; then echo "BUILDNR version is empty" exit 1 fi sed -i "s/^OMS_VERSION_MAJOR=.*$/OMS_VERSION_MAJOR=$OMS_BUILDVERSION_MAJOR/" omsagent.version sed -i "s/^OMS_VERSION_MINOR=.*$/OMS_VERSION_MINOR=$OMS_BUILDVERSION_MINOR/" omsagent.version sed -i "s/^OMS_VERSION_PATCH_EXTENSION=.*$/OMS_VERSION_PATCH_EXTENSION=$OMS_BUILDVERSION_PATCH/" omsagent.version sed -i "s/^OMS_VERSION_PATCH_SHELL_BUNDLE=.*$/OMS_VERSION_PATCH_SHELL_BUNDLE=$OMS_BUILDVERSION_PATCH/" omsagent.version sed -i "s/^OMS_VERSION_BUILDNR_SHELL_BUNDLE=.*$/OMS_VERSION_BUILDNR_SHELL_BUNDLE=$OMS_BUILDVERSION_BUILDNR/" omsagent.version sed -i "s/^OMS_VERSION_DATE=.*$/OMS_VERSION_DATE=$UPDATE_DATE/" omsagent.version ================================================ FILE: OmsAgent/watcherutil.py ================================================ #!/usr/bin/env python # # OmsAgentForLinux Extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import subprocess import os import io import datetime from datetime import datetime, timedelta import time import string import traceback import shutil import sys import json import uuid from threading import Thread import re import hashlib from omsagent import run_command_and_log from omsagent import RestartOMSAgentServiceCommand """ Write now hardcode memory threshold to watch for to 20 %. If agent is using more than 20% of memory it is definitely very high. In future we may want to set it based on customer configuration. """ # Constants. MemoryThresholdToWatchFor = 20 OmsAgentPidFile = "/var/opt/microsoft/omsagent/run/omsagent.pid" OmsAgentLogFile = "/var/opt/microsoft/omsagent/log/omsagent.log" reg_ex = re.compile('([0-9]{4}-[0-9]{2}-[0-9]{2}.*)\[(\w+)\]:(.*)') maxMessageSize = 100 OMSExtensionVersion = '1.13.19' """ We can add to the list below with more error messages to identify non recoverable errors. """ ErrorStatements = ["Errono::ENOSPC error=", "Fatal error, can not clear buffer file", "No space left on the device"] class SelfMonitorInfo(object): """ Class to hold self mon info for omsagent. 
""" def __init__(self): self._consecutive_error_count = 0 self._last_reset_success = True self._error_count = 0 self._memory_used_in_percent = 0 self._consecutive_high_memory_usage = 0 def reset(self): self._consecutive_error_count = 0 self._consecutive_high_memory_usage = 0 self._memory_used_in_percent = 0 def reset_error_info(self): self._consecutive_error_count = 0 def increment_heartbeat_missing_count(self): self._consecutive_error_count += 1 def crossed_error_threshold(self): if (self._consecutive_error_count > 3): return True else: return False def corssed_memory_threshold(self): if (self._consecutive_high_memory_usage > 3): return True else: return False def increment_high_memory_count(self): self._consecutive_high_memory_usage += 1 def reset_high_memory_count(self): self._consecutive_high_memory_usage = 0 def current_status(self): """ Python 2.6 does not support enum. """ if (self._consecutive_error_count == 0 and self._consecutive_high_memory_usage == 0): return "Green" elif (self._consecutive_error_count < 3 and self._consecutive_high_memory_usage < 3): return "Yellow" else: return "Red" class LogFileMarker(object): """ Class to hold omsagent log file marker information. """ def __init__(self): self._last_pos = 0 self._last_crc = "" def reset_marker(self): self._last_pos = 0 self._last_crc = "" class Watcher(object): """ A class that handles periodic monitoring activities. """ def __init__(self, hutil_error, hutil_log): """ Constructor. :param hutil_error: Error logging function (e.g., hutil.error). This is not a stream. :param hutil_log: Normal logging function (e.g., hutil.log). This is not a stream. 
""" self._hutil_error = hutil_error self._hutil_log = hutil_log self._consecutive_error_count = 0 self._consecutive_restarts_due_to_error = 0 def write_waagent_event(self, event): offset = str(int(time.time() * 1000000)) temp_fn = '/var/lib/waagent/events/'+str(uuid.uuid4()) with open(temp_fn,'w+') as fh: fh.write(event) fn_template = '/var/lib/waagent/events/{}.tld' fn = fn_template.format(offset) while os.path.isfile(fn): offset += 1 fn = fn_template.format(offset) shutil.move(temp_fn, fn) self._hutil_log(fn) def create_telemetry_event(self, operation, operation_success, message, duration): template = """ {{ "eventId": 1, "providerId": "69B669B9-4AF8-4C50-BDC4-6006FA76E975", "parameters": [ {{ "name": "Name", "value": "Microsoft.EnterpriseCloud.Monitoring.OmsAgentForLinux" }}, {{ "name": "Version", "value": \"""" + OMSExtensionVersion + """\" }}, {{ "name": "Operation", "value": "{}" }}, {{ "name": "OperationSuccess", "value": {} }}, {{ "name": "Message", "value": "{}" }}, {{ "name": "Duration", "value": {} }} ] }}""" operation_success_as_string = str(operation_success).lower() formatted_message = message.replace("\n", "\\n").replace("\t", "\\t").replace('"', '\"') return template.format(operation, operation_success_as_string, formatted_message, duration) def upload_telemetry(self): status_files = [ "/var/opt/microsoft/omsagent/log/ODSIngestion.status", "/var/opt/microsoft/omsagent/log/ODSIngestionBlob.status", "/var/opt/microsoft/omsagent/log/ODSIngestionAPI.status", "/var/opt/microsoft/omsconfig/status/dscperformconsistency", "/var/opt/microsoft/omsconfig/status/dscperforminventory", "/var/opt/microsoft/omsconfig/status/dscsetlcm", "/var/opt/microsoft/omsconfig/status/omsconfighost" ] for sf in status_files: if os.path.isfile(sf): mod_time = os.path.getmtime(sf) curr_time = int(time.time()) if (curr_time - mod_time < 300): with open(sf) as json_file: try: status_data = json.load(json_file) operation = status_data["operation"] operation_success = 
status_data["success"] # Truncating the message to prevent flooding the system message = status_data["message"][:maxMessageSize] event = self.create_telemetry_event(operation,operation_success,message,"300000") self._hutil_log("Writing telemetry event: "+event) self.write_waagent_event(event) self._hutil_log("Successfully processed telemetry status file: "+sf) except Exception: self._hutil_log("Error parsing telemetry status file: "+sf) self._hutil_log("Exception info: "+traceback.format_exc()) if sf.startswith("/var/opt/microsoft/omsconfig/status"): try: self._hutil_log("Cleaning up: " + sf) os.remove(sf) except Exception: self._hutil_log("Error removing telemetry status file: "+ sf) self._hutil_log("Exception info: " + traceback.format_exc()) else: self._hutil_log("Telemetry status file not updated in last 5 mins: "+sf) else: self._hutil_log("Telemetry status file does not exist: "+sf) pass def watch(self): """ Main loop performing various monitoring activities periodically. Currently iterates every 5 minutes, and other periodic activities might be added in the loop later. :return: None """ self._hutil_log('started watcher thread') while True: self._hutil_log('watcher thread waking') self.upload_telemetry() # Sleep 5 minutes self._hutil_log('watcher thread sleeping') time.sleep(60 * 5) pass def monitor_heartbeat(self, self_mon_info, log_file_marker): """ Monitor heartbeat health. OMS output plugin will update the timestamp of new heartbeat file every 5 minutes. We will check if it is updated If not, we will look into omsagent logs and look for specific error logs which indicate we are in non recoverable state. """ take_action = False if (not self.received_heartbeat_recently()): """ We haven't seen heartbeat in more than past 300 seconds """ self_mon_info.increment_heartbeat_missing_count() take_action = False if (self_mon_info.crossed_error_threshold()): # If we do not see heartbeat for last 3 iterations, take corrective action. 
take_action = True elif (self.check_for_fatal_oms_logs(log_file_marker)): # If we see hearbeat missing and error message, no need to wait for more than one # iteration. It is not a false positive. Take corrective action immediately. take_action = True if (take_action): if (self._consecutive_restarts_due_to_error < 5): self.take_corrective_action(self_mon_info) self._consecutive_restarts_due_to_error += 1 else: self._hutil_error("Last 5 restarts did not help. So we will not restart the agent immediately") # Reset historical infomration. self._consecutive_restarts_due_to_error = 0 self_mon_info.reset_error_info() else: """ If we are able to get the heartbeats, check omsagent logs to identify if there are any error logs. """ self_mon_info.reset_error_info() self._consecutive_restarts_due_to_error = 0 def received_heartbeat_recently(self): heartbeat_file = '/var/opt/microsoft/omsagent/log/ODSIngestion.status' curr_time = int(time.time()) return_val = True file_update_time = curr_time if (os.path.isfile(heartbeat_file)): file_update_time = os.path.getmtime(heartbeat_file) self._hutil_log("File update time={0}, current time={1}".format(file_update_time, curr_time)) else: self._hutil_log("Heartbeat file is not present on the disk.") file_update_time = curr_time - 1000 if (file_update_time + 360 < curr_time): return_val = False else: try: with open(heartbeat_file) as json_file: status_data = json.load(json_file) operation_success = status_data["success"] if (operation_success.lower() == "true"): self._hutil_log("Found success message from ODS Ingestion.") return_val = True else: self._hutil_log("Did not find success message in heart beat file. {0}".format(operation_success)) return_val = False except Exception as e: self._hutil_log("Error parsing ODS Ingestion status file: " + e) # Return True in case we failed to parse the file. We do not want to go into recycle loop in this scenario. 
return_val = True return return_val def monitor_resource(self, self_mon_info): """ Monitor resource utilization of omsagent. Check for memory and CPU periodically. If they cross the threshold for consecutive 3 iterations we will restart the agent. """ resource_usage = self.get_oms_agent_resource_usage() message = "Memory : {0}, CPU : {1}".format(resource_usage[0], resource_usage[1]) event = self.create_telemetry_event("agenttelemetry","True",message,"300000") self.write_waagent_event(event) self_mon_info._memory_used_in_percent = resource_usage[0] if (self_mon_info._memory_used_in_percent > 0): if (self_mon_info._memory_used_in_percent > MemoryThresholdToWatchFor): # check consecutive memory usage. self_mon_info.increment_high_memory_count() if (self_mon_info.corssed_memory_threshold()): # if we have crossed the memory threshold take corrective action. self.take_corrective_action(self_mon_info) else: self_mon_info.reset_high_memory_count() else: self_mon_info.reset_high_memory_count() def monitor_health(self): """ Role of this function is monitor the health of the oms agent. To begin with it will monitor heartbeats flowing through oms agent. We will also read oms agent logs to determine some error conditions. We don't want to interfare with log watcher function. So we will start this on a new thread. """ self_mon_info = SelfMonitorInfo() log_file_marker = LogFileMarker() # check every 6 minutes. we want to be bit pessimistic while looking for health, especially heartbeats which is emitted every 5 minutes. sleepTime = 6 * 60 # sleep before starting the monitoring. time.sleep(sleepTime) while True: try: # Monitor heartbeat and logs. self.monitor_heartbeat(self_mon_info, log_file_marker) # Monitor memory usage self.monitor_resource(self_mon_info) except IOError as e: self._hutil_error('I/O error in monitoring health of the omsagent. Exception={0}'.format(e)) except Exception as e: self._hutil_error('Error in monitoring health of the omsagent. 
Exception={0}'.format(e)) finally: time.sleep(sleepTime) def take_corrective_action(self, self_mon_info): """ Take a corrective action. """ run_command_and_log(RestartOMSAgentServiceCommand) self._hutil_log("Successfully restarted OMS linux agent, resetting self mon information.") # Reset self mon information. self_mon_info.reset() def emit_telemetry_after_corrective_action(self): """ TODO : Emit telemetry after taking corrective action. """ def get_total_seconds_from_epoch_for_fluent_logs(self, datetime_string): # fluentd logs timestamp format : 2018-08-02 19:27:34 +0000 # for python 2.7 or earlier there is no good way to convert it into seconds. # so we parse upto seconds, and parse utc specific offset seperately. try: date_time_format = '%Y-%m-%d %H:%M:%S' epoch = datetime(1970, 1, 1) # get hours and minute delta for utc offset. hours_delta_utc = int(datetime_string[21:23]) minutes_delta_utc= int(datetime_string[23:]) log_time = datetime.strptime(datetime_string[:19], date_time_format) + ((timedelta(hours=hours_delta_utc, minutes=minutes_delta_utc)) * (-1 if datetime_string[20] == "+" else 1)) return (log_time - epoch).total_seconds() except Exception as e: self._hutil_error('Error converting timestamp string to seconds. Exception={0}'.format(e)) return 0 def check_for_fatal_oms_logs(self, log_file_marker): """ This function will go through oms log file and check for the logs indicating non recoverable state. That set is hardcoded right now and we can add it to it as we learn more. If we find there is atleast one occurance of such log line from last occurance, we will return True else will return False. """ read_start_time = int(time.time()) if os.path.isfile(OmsAgentLogFile): last_crc = log_file_marker._last_crc last_pos = log_file_marker._last_pos # We do not want to propogate any exception to the caller. try: f = open(OmsAgentLogFile, "r") text = f.readline() # Handle log rotate. Check for CRC of first line of the log file. 
# Some of the agents like Splunk uses this technique. # If it matches with previous CRC, then file has not changed. # If it is not matching then file has changed and do not seek from # the last_pos rather continue from the begining. if (text != ''): crc = hashlib.sha256(text).hexdigest() self._hutil_log("Last crc = {0}, current crc= {1} position = {2}".format(last_crc, crc, last_pos)) if (last_crc == crc): if (last_pos > 0): f.seek(last_pos) else: self._hutil_log("File has changed do not seek from the offset. current crc = {0}".format(crc)) log_file_marker._last_crc = crc total_lines_read = 1 while True: text = f.readline() if (text == ''): log_file_marker._last_pos = f.tell() break total_lines_read += 1 res = reg_ex.match(text) if res: log_entry_time = self.get_total_seconds_from_epoch_for_fluent_logs(res.group(1)) if (log_entry_time + (10 * 60) < read_start_time): # ignore log line if we are reading logs older than 10 minutes. pass elif (res.group(2) == "warn" or res.group(2) == "error"): for error_statement in ErrorStatements: if (res.group(3) in error_statement): self._hutil_error("Found non recoverable error log in agent log file") # File should be closed in the finally block. return True self._hutil_log("Did not find any non recoverable logs in omsagent log file") except Exception as e: self._hutil_error ("Caught an exception {0}".format(traceback.format_exc())) finally: f.close() else: self._hutil_error ("Omsagent log file not found : {0}".format(OmsAgentLogFile)) return False def get_oms_agent_resource_usage(self): """ If we hit any exception in getting resoource usage of the omsagent return 0,0 We need not crash/fail in this case. return tuple : memory, cpu. Long run for north star we should use cgroups. cgroups tools are not available by default on all the distros and we would need to package with the agent those and use. Also at this point it is not very clear if customers would want us to create cgroups on their vms. 
""" try: mem_usage = 0.0 cpu_usage = 0.0 with open(OmsAgentPidFile, 'r') as infile: pid = infile.readline() # Get pid of omsagent process. # top output: # $1 - PID, # $2 - account, # $9 - CPU, # $10 - Memory, # $12 - Process name out = subprocess.Popen('top -bn1 | grep -i omsagent | awk \'{print $1 " " $2 " " $9 " " $10 " " $12}\'', shell=True, stdout=subprocess.PIPE) for line in out.stdout: s = line.split() if (len(s) >= 4 and s[0] == pid and s[1] == 'omsagent' and s[4] == 'omsagent'): return float(s[3]) , float(s[2]) except Exception as e: self._hutil_error('Error getting memory usage for omsagent process. Exception={0}'.format(e)) # Control will reach here only in case of error condition. In that case it is ok to return 0 as it is harmless to be cautious. return mem_usage, cpu_usage ================================================ FILE: RDMAUpdate/MANIFEST.in ================================================ include HandlerManifest.json handler.py prune test ================================================ FILE: RDMAUpdate/RDMAUpdate.pyproj ================================================ <?xml version="1.0" encoding="utf-8"?> <Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003" ToolsVersion="4.0"> <PropertyGroup> <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration> <SchemaVersion>2.0</SchemaVersion> <ProjectGuid>7883b2c9-5431-4fac-bfca-c92b8a17644a</ProjectGuid> <ProjectHome>.</ProjectHome> <StartupFile>RDMAUpdate.py</StartupFile> <SearchPath> </SearchPath> <WorkingDirectory>.</WorkingDirectory> <OutputPath>.</OutputPath> <Name>RDMAUpdate</Name> <RootNamespace>RDMAUpdate</RootNamespace> </PropertyGroup> <PropertyGroup Condition=" '$(Configuration)' == 'Debug' "> <DebugSymbols>true</DebugSymbols> <EnableUnmanagedDebugging>false</EnableUnmanagedDebugging> </PropertyGroup> <PropertyGroup Condition=" '$(Configuration)' == 'Release' "> <DebugSymbols>true</DebugSymbols> 
<EnableUnmanagedDebugging>false</EnableUnmanagedDebugging> </PropertyGroup> <ItemGroup> <Compile Include="RDMAUpdate.py" /> </ItemGroup> <PropertyGroup> <VisualStudioVersion Condition="'$(VisualStudioVersion)' == ''">10.0</VisualStudioVersion> <PtvsTargetsFile>$(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\Python Tools\Microsoft.PythonTools.targets</PtvsTargetsFile> </PropertyGroup> <Import Condition="Exists($(PtvsTargetsFile))" Project="$(PtvsTargetsFile)" /> <Import Condition="!Exists($(PtvsTargetsFile))" Project="$(MSBuildToolsPath)\Microsoft.Common.targets" /> <!-- Uncomment the CoreCompile target to enable the Build command in Visual Studio and specify your pre- and post-build commands in the BeforeBuild and AfterBuild targets below. --> <!--<Target Name="CoreCompile" />--> <Target Name="BeforeBuild"> </Target> <Target Name="AfterBuild"> </Target> </Project> ================================================ FILE: RDMAUpdate/README.txt ================================================ ================================================ FILE: RDMAUpdate/enableit.js ================================================  ================================================ FILE: RDMAUpdate/main/CommandExecuter.py ================================================ #!/usr/bin/env python # # VMEncryption extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import subprocess import os import os.path import shlex import sys from subprocess import * class CommandExecuter(object): """description of class""" def __init__(self, logger): self.logger = logger def Execute(self, command_to_execute): self.logger.log("Executing:" + command_to_execute) args = shlex.split(command_to_execute) proc = Popen(args) returnCode = proc.wait() return returnCode def RunGetOutput(self, command_to_execute): try: output=subprocess.check_output(command_to_execute,stderr=subprocess.STDOUT,shell=True) return 0,output.decode('latin-1') except subprocess.CalledProcessError as e : self.logger.log('CalledProcessError. Error Code is ' + str(e.returncode) ) self.logger.log('CalledProcessError. Command string was ' + e.cmd ) self.logger.log('CalledProcessError. Command result was ' + (e.output[:-1]).decode('latin-1')) return e.returncode,e.output.decode('latin-1') ================================================ FILE: RDMAUpdate/main/Common.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class CommonVariables:
    """Shared constants for the RDMA update extension."""

    # Identity / packaging
    azure_path = 'main/azure'
    utils_path_name = 'Utils'
    extension_name = 'RDMAUpdateForLinux'
    extension_version = "0.1.0.8"
    extension_type = extension_name
    extension_media_link = ('https://andliu.blob.core.windows.net/extensions/%s-%s.zip'
                            % (extension_name, str(extension_version)))
    extension_label = 'Windows Azure RDMA Update Extension for Linux IaaS'
    extension_description = extension_label

    # Configuration
    wrapper_package_name = 'msft-rdma-drivers'

    # Error codes
    process_success = 0
    common_failed = 1
    install_hv_utils_failed = 2
    nd_driver_detect_error = 3
    driver_version_not_found = 4
    unknown_error = 5
    package_not_found = 6
    package_install_failed = 7

    # Log levels
    InfoLevel = 'Info'
    WarningLevel = 'Warning'
    ErrorLevel = 'Error'

    # check_rdma() results
    UpToDate = 0
    OutOfDate = 1
    DriverVersionNotFound = 3
    Unknown = -1
import os
import os.path
import sys
from Utils import HandlerUtil
from CommandExecuter import CommandExecuter
from Common import CommonVariables


class CronUtil(object):
    """Installs and refreshes the daily '-chkrdma' crontab entry."""

    def __init__(self, logger):
        self.logger = logger
        self.crontab = '/etc/crontab'
        self.cron_restart_cmd = 'service cron restart'

    def check_update_cron_config(self):
        """Rewrite /etc/crontab so it contains exactly one chkrdma job.

        Any existing line mentioning '<script_file> -chkrdma' is dropped,
        then a fresh daily-midnight entry invoking this extension's
        main/handle.py is appended; the write is atomic via waagent.
        """
        script_file_path = os.path.realpath(sys.argv[0])
        script_dir = os.path.dirname(script_file_path)
        script_file = os.path.basename(script_file_path)
        old_line_end = ' '.join([script_file, '-chkrdma'])
        new_line = ' '.join(['\n0 0 * * *', 'root cd', script_dir + "/..",
                             '&& python main/handle.py -chkrdma >/dev/null 2>&1\n'])
        HandlerUtil.waagent.ReplaceFileContentsAtomic(self.crontab, \
            '\n'.join(filter(lambda a: a and (old_line_end not in a),
                             HandlerUtil.waagent.GetFileContents(self.crontab).split('\n')))
            + new_line)

    def restart_cron(self):
        """Restart the cron service so the updated crontab takes effect."""
        commandExecuter = CommandExecuter(self.logger)
        returnCode = commandExecuter.Execute(self.cron_restart_cmd)
        if(returnCode != CommonVariables.process_success):
            # Bug fix: the original logged an empty message here, discarding
            # all diagnostic detail about the failed restart.
            self.logger.log(msg="Failed to restart cron, return code: " + str(returnCode),
                            level=CommonVariables.ErrorLevel)
import datetime


class RDMALogger(object):
    """Logging facade that stamps each message with time and level before
    delegating to the handler utility's logger."""

    def __init__(self, hutil):
        self.msg = ''      # retained for compatibility; log() does not use it
        self.hutil = hutil

    def log(self, msg, level='Info'):
        # Wire format: "<timestamp> <level> <msg>\n"
        stamped = ' '.join([str(datetime.datetime.now()), level, msg]) + '\n'
        self.hutil.log(stamped)


class RdmaException(Exception):
    """Exception whose value is one of the CommonVariables error codes."""

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return repr(self.value)
import os


class SecondStageMarkConfig(object):
    """Persists a marker file signalling that the post-reboot (second)
    stage of the update is still pending."""

    def __init__(self):
        # GUID suffix keeps the marker name unique to this extension.
        self.mark_file_path = './second_stage_mark_FD76C85E-406F-4CFA-8EB0-CF18B123365C'

    def MarkIt(self):
        """Create the marker file."""
        with open(self.mark_file_path, 'w') as marker:
            marker.write('marked')

    def IsMarked(self):
        """Return True when the marker file exists."""
        return os.path.exists(self.mark_file_path)

    def ClearIt(self):
        """Remove the marker file if it is present; otherwise do nothing."""
        if self.IsMarked():
            os.remove(self.mark_file_path)
""" JSON def: HandlerEnvironment.json [{ "name": "ExampleHandlerLinux", "seqNo": "seqNo", "version": "1.0", "handlerEnvironment": { "logFolder": "<your log folder location>", "configFolder": "<your config folder location>", "statusFolder": "<your status folder location>", "heartbeatFile": "<your heartbeat file location>", } }] Example ./config/1.settings "{"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"1BE9A13AA1321C7C515EF109746998BAB6D86FD1","protectedSettings": "MIIByAYJKoZIhvcNAQcDoIIBuTCCAbUCAQAxggFxMIIBbQIBADBVMEExPzA9BgoJkiaJk/IsZAEZFi9XaW5kb3dzIEF6dXJlIFNlcnZpY2UgTWFuYWdlbWVudCBmb3IgR+nhc6VHQTQpCiiV2zANBgkqhkiG9w0BAQEFAASCAQCKr09QKMGhwYe+O4/a8td+vpB4eTR+BQso84cV5KCAnD6iUIMcSYTrn9aveY6v6ykRLEw8GRKfri2d6tvVDggUrBqDwIgzejGTlCstcMJItWa8Je8gHZVSDfoN80AEOTws9Fp+wNXAbSuMJNb8EnpkpvigAWU2v6pGLEFvSKC0MCjDTkjpjqciGMcbe/r85RG3Zo21HLl0xNOpjDs/qqikc/ri43Y76E/Xv1vBSHEGMFprPy/Hwo3PqZCnulcbVzNnaXN3qi/kxV897xGMPPC3IrO7Nc++AT9qRLFI0841JLcLTlnoVG1okPzK9w6ttksDQmKBSHt3mfYV+skqs+EOMDsGCSqGSIb3DQEHATAUBggqhkiG9w0DBwQITgu0Nu3iFPuAGD6/QzKdtrnCI5425fIUy7LtpXJGmpWDUA==","publicSettings":{"port":"3000"}}}]}" Example HeartBeat { "version": 1.0, "heartbeat" : { "status": "ready", "code": 0, "Message": "Sample Handler running. Waiting for a new configuration from user." 
} } Example Status Report: [{"version":"1.0","timestampUTC":"2014-05-29T04:20:13Z","status":{"name":"Chef Extension Handler","operation":"chef-client-run","status":"success","code":0,"formattedMessage":{"lang":"en-US","message":"Chef-client run success"}}}] """ import base64 import os import os.path import sys import json import time import tempfile from os.path import join from Utils.WAAgentUtil import waagent from waagent import LoggerInit import logging import logging.handlers DateTimeFormat = "%Y-%m-%dT%H:%M:%SZ" class HandlerContext: def __init__(self,name): self._name = name self._version = '0.0' return class HandlerUtility: def __init__(self, log, error, short_name): self._log = log self._error = error self._short_name = short_name self.syslogger = logging.getLogger(self._short_name) self.syslogger.setLevel(logging.INFO) handler = logging.handlers.SysLogHandler(address='/dev/log') formatter = logging.Formatter('%(name)s: %(levelname)s %(message)s') handler.setFormatter(formatter) self.syslogger.addHandler(handler) def _get_log_prefix(self): return '[%s-%s]' % (self._context._name, self._context._version) def _get_current_seq_no(self, config_folder): seq_no = -1 cur_seq_no = -1 freshest_time = None for subdir, dirs, files in os.walk(config_folder): for file in files: try: cur_seq_no = int(os.path.basename(file).split('.')[0]) if(freshest_time == None): freshest_time = os.path.getmtime(join(config_folder,file)) seq_no = cur_seq_no else: current_file_m_time = os.path.getmtime(join(config_folder,file)) if(current_file_m_time > freshest_time): freshest_time = current_file_m_time seq_no = cur_seq_no except ValueError: continue return seq_no def log(self, message): self._log(self._get_log_prefix() + message) def error(self, message): self._error(self._get_log_prefix() + message) def syslog(self, level, message): if level == logging.INFO: self.syslogger.info(message) elif level == logging.WARNING: self.syslogger.warning(message) elif level == logging.ERROR: 
self.syslogger.error(message) def log_and_syslog(self, level, message): self.syslog(level, message) if level == logging.INFO: self.log(message) elif level == logging.WARNING: self.log(" ".join(["Warning:", message])) elif level == logging.ERROR: self.error(message) def _parse_config(self, ctxt): config = None try: config = json.loads(ctxt) except: self.error('JSON exception decoding ' + ctxt) if config is None: self.error("JSON error processing settings file:" + ctxt) else: handlerSettings = config['runtimeSettings'][0]['handlerSettings'] if handlerSettings.has_key('protectedSettings') and \ handlerSettings.has_key("protectedSettingsCertThumbprint") and \ handlerSettings['protectedSettings'] is not None and \ handlerSettings["protectedSettingsCertThumbprint"] is not None: protectedSettings = handlerSettings['protectedSettings'] thumb = handlerSettings['protectedSettingsCertThumbprint'] cert = waagent.LibDir + '/' + thumb + '.crt' pkey = waagent.LibDir + '/' + thumb + '.prv' unencodedSettings = base64.standard_b64decode(protectedSettings) openSSLcmd = "openssl smime -inform DER -decrypt -recip {0} -inkey {1}" cleartxt = waagent.RunSendStdin(openSSLcmd.format(cert, pkey), unencodedSettings)[1] if cleartxt is None: self.error("OpenSSL decode error using thumbprint " + thumb) self.do_exit(1, "Enable", 'error', '1', 'Failed to decrypt protectedSettings') jctxt = '' try: jctxt = json.loads(cleartxt) except: self.error('JSON exception decoding ' + cleartxt) handlerSettings['protectedSettings'] = jctxt self.log('Config decoded correctly.') return config def do_parse_context(self,operation): _context = self.try_parse_context() if not _context: self.do_exit(1,operation,'error','1', operation + ' Failed') return _context def try_parse_context(self): self._context = HandlerContext(self._short_name) handler_env = None config = None ctxt = None code = 0 # get the HandlerEnvironment.json. 
According to the extension handler # spec, it is always in the ./ directory self.log('cwd is ' + os.path.realpath(os.path.curdir)) handler_env_file = './HandlerEnvironment.json' if not os.path.isfile(handler_env_file): self.error("Unable to locate " + handler_env_file) return None ctxt = waagent.GetFileContents(handler_env_file) if ctxt == None : self.error("Unable to read " + handler_env_file) try: handler_env = json.loads(ctxt) except: pass if handler_env == None : self.log("JSON error processing " + handler_env_file) return None if type(handler_env) == list: handler_env = handler_env[0] self._context._name = handler_env['name'] self._context._version = str(handler_env['version']) self._context._config_dir = handler_env['handlerEnvironment']['configFolder'] self._context._log_dir = handler_env['handlerEnvironment']['logFolder'] self._context._log_file = os.path.join(handler_env['handlerEnvironment']['logFolder'],'extension.log') self._change_log_file() self._context._status_dir = handler_env['handlerEnvironment']['statusFolder'] self._context._heartbeat_file = handler_env['handlerEnvironment']['heartbeatFile'] self._context._seq_no = self._get_current_seq_no(self._context._config_dir) if self._context._seq_no < 0: self.error("Unable to locate a .settings file!") return None self._context._seq_no = str(self._context._seq_no) self.log('sequence number is ' + self._context._seq_no) self._context._status_file = os.path.join(self._context._status_dir, self._context._seq_no + '.status') self._context._settings_file = os.path.join(self._context._config_dir, self._context._seq_no + '.settings') self.log("setting file path is" + self._context._settings_file) ctxt = None ctxt = waagent.GetFileContents(self._context._settings_file) if ctxt == None : error_msg = 'Unable to read ' + self._context._settings_file + '. 
' self.error(error_msg) return None self.log("JSON config: " + ctxt) self._context._config = self._parse_config(ctxt) return self._context def _change_log_file(self): self.log("Change log file to " + self._context._log_file) LoggerInit(self._context._log_file,'/dev/stdout') self._log = waagent.Log self._error = waagent.Error def set_verbose_log(self, verbose): if(verbose == "1" or verbose == 1): self.log("Enable verbose log") LoggerInit(self._context._log_file, '/dev/stdout', verbose=True) else: self.log("Disable verbose log") LoggerInit(self._context._log_file, '/dev/stdout', verbose=False) def is_seq_smaller(self): return int(self._context._seq_no) <= self._get_most_recent_seq() def save_seq(self): self._set_most_recent_seq(self._context._seq_no) self.log("set most recent sequence number to " + self._context._seq_no) def exit_if_enabled(self): self.exit_if_seq_smaller() def exit_if_seq_smaller(self): if(self.is_seq_smaller()): self.log("Current sequence number, " + self._context._seq_no + ", is not greater than the sequence number of the most recent executed configuration. 
Exiting...") sys.exit(0) self.save_seq() def _get_most_recent_seq(self): if(os.path.isfile('mrseq')): seq = waagent.GetFileContents('mrseq') if(seq): return int(seq) return -1 def is_current_config_seq_greater_inused(self): return int(self._context._seq_no) > self._get_most_recent_seq() def get_inused_config_seq(self): return self._get_most_recent_seq() def set_inused_config_seq(self,seq): self._set_most_recent_seq(seq) def _set_most_recent_seq(self,seq): waagent.SetFileContents('mrseq', str(seq)) def do_status_report(self, operation, status, status_code, message): self.log("{0},{1},{2},{3}".format(operation, status, status_code, message)) tstamp = time.strftime(DateTimeFormat, time.gmtime()) stat = [{ "version" : self._context._version, "timestampUTC" : tstamp, "status" : { "name" : self._context._name, "operation" : operation, "status" : status, "code" : status_code, "formattedMessage" : { "lang" : "en-US", "message" : message } } }] stat_rept = json.dumps(stat) if self._context._status_file: with open(self._context._status_file,'w+') as f: f.write(stat_rept) def do_heartbeat_report(self, heartbeat_file,status,code,message): # heartbeat health_report = '[{"version":"1.0","heartbeat":{"status":"' + status + '","code":"' + code + '","Message":"' + message + '"}}]' if waagent.SetFileContents(heartbeat_file,health_report) == None : self.error('Unable to wite heartbeat info to ' + heartbeat_file) def do_exit(self,exit_code,operation,status,code,message): try: self.do_status_report(operation, status,code,message) except Exception as e: self.log("Can't update status: " + str(e)) sys.exit(exit_code) def get_name(self): return self._context._name def get_seq_no(self): return self._context._seq_no def get_log_dir(self): return self._context._log_dir def get_handler_settings(self): return self._context._config['runtimeSettings'][0]['handlerSettings'] def get_protected_settings(self): return self.get_handler_settings().get('protectedSettings') def get_public_settings(self): 
return self.get_handler_settings().get('publicSettings') ================================================ FILE: RDMAUpdate/main/Utils/WAAgentUtil.py ================================================ # Wrapper module for waagent # # waagent is not written as a module. This wrapper module is created # to use the waagent code as a module. # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import imp import os import os.path # # The following code will search and load waagent code and expose # it as a submodule of current module # def searchWAAgent(): agentPath = '/usr/sbin/waagent' if(os.path.isfile(agentPath)): return agentPath user_paths = os.environ['PYTHONPATH'].split(os.pathsep) for user_path in user_paths: agentPath = os.path.join(user_path, 'waagent') if(os.path.isfile(agentPath)): return agentPath return None agentPath = searchWAAgent() if(agentPath): waagent = imp.load_source('waagent', agentPath) else: raise Exception("Can't load waagent.") if not hasattr(waagent, "AddExtensionEvent"): """ If AddExtensionEvent is not defined, provide a dummy impl. 
""" def _AddExtensionEvent(*args, **kwargs): pass waagent.AddExtensionEvent = _AddExtensionEvent if not hasattr(waagent, "WALAEventOperation"): class _WALAEventOperation: HeartBeat = "HeartBeat" Provision = "Provision" Install = "Install" UnIsntall = "UnInstall" Disable = "Disable" Enable = "Enable" Download = "Download" Upgrade = "Upgrade" Update = "Update" waagent.WALAEventOperation = _WALAEventOperation __ExtensionName__ = None def InitExtensionEventLog(name): __ExtensionName__ = name def AddExtensionEvent(name=__ExtensionName__, op=waagent.WALAEventOperation.Enable, isSuccess=False, message=None): if name is not None: waagent.AddExtensionEvent(name=name, op=op, isSuccess=isSuccess, message=message) ================================================ FILE: RDMAUpdate/main/Utils/__init__.py ================================================ # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: RDMAUpdate/main/__init__.py ================================================ # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
#!/usr/bin/env python
#
# RDMA update extension handler entry point.
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import array
import base64
import os
import os.path
import re
import json
import string
import subprocess
import sys
import imp
import time
import shlex
import traceback
import httplib
import xml.parsers.expat
import datetime
from patch import *
from os.path import join
from Common import CommonVariables
from Utils import HandlerUtil
from urlparse import urlparse
from RDMALogger import RDMALogger
from CronUtil import *
from SecondStageMarkConfig import SecondStageMarkConfig


def main():
    """Initialize logging/patching and dispatch the command-line operation."""
    global logger
    global hutil
    global MyPatching
    HandlerUtil.LoggerInit('/var/log/waagent.log', '/dev/stdout')
    HandlerUtil.waagent.Log("%s started to handle." % (CommonVariables.extension_name))
    hutil = HandlerUtil.HandlerUtility(HandlerUtil.waagent.Log,
                                       HandlerUtil.waagent.Error,
                                       CommonVariables.extension_name)
    logger = RDMALogger(hutil)
    MyPatching = GetMyPatching(logger)
    hutil.patching = MyPatching
    # Operations arrive as e.g. "-enable" or "/enable"; first match wins.
    for a in sys.argv[1:]:
        if re.match("^([-/]*)(disable)", a):
            disable()
        elif re.match("^([-/]*)(uninstall)", a):
            uninstall()
        elif re.match("^([-/]*)(install)", a):
            install()
        elif re.match("^([-/]*)(enable)", a):
            enable()
        elif re.match("^([-/]*)(update)", a):
            update()
        elif re.match("^([-/]*)(rdmaupdate)", a):
            rdmaupdate()
        elif re.match("^([-/]*)(chkrdma)", a):
            chkrdma()


def chkrdma():
    """Report (always as success) whether the RDMA driver is up to date."""
    hutil.do_parse_context('Executing')
    check_result = MyPatching.check_rdma()
    if(check_result == CommonVariables.UpToDate):
        hutil.do_exit(0, 'Enable', 'success', '0', 'RDMA Driver up to date.')
    if(check_result == CommonVariables.OutOfDate):
        hutil.do_exit(0, 'Enable', 'success', '0', 'RDMA Driver out of date.')
    if(check_result == CommonVariables.DriverVersionNotFound):
        # NOTE(review): the next two messages look swapped relative to their
        # result codes (DriverVersionNotFound -> "Driver not found") — kept
        # as-is pending confirmation with the status consumers.
        hutil.do_exit(0, 'Enable', 'success', '0', 'RDMA Driver not found.')
    if(check_result == CommonVariables.Unknown):
        hutil.do_exit(0, 'Enable', 'success', '0', 'RDMA version not found.')


def rdmaupdate():
    """Run the driver update and reboot; never fail the extension itself."""
    hutil.do_parse_context('Executing')
    try:
        MyPatching.rdmaupdate()
        hutil.do_status_report('Enable', 'success', '0', 'Enable Succeeded')
        MyPatching.reboot_machine()
    except Exception as e:
        logger.log("Failed to update with error: %s, stack trace: %s"
                   % (str(e), traceback.format_exc()))
        # Deliberately exits 0/'success' so the agent does not retry forever;
        # the real failure is recorded in the extension log.
        hutil.do_exit(0, 'Enable', 'success', '0',
                      'enable failed, please take a look at the extension log.')


def start_daemon():
    """Re-launch this script with -rdmaupdate as a detached background child."""
    args = [os.path.join(os.getcwd(), __file__), "-rdmaupdate"]
    logger.log("start_daemon with args:" + str(args))
    devnull = open(os.devnull, 'w')
    child = subprocess.Popen(args, stdout=devnull, stderr=devnull)


def enable():
    """Configure the chkrdma cron job (first stage) or resume after reboot
    (second stage), then kick off the update daemon."""
    hutil.do_parse_context('Enable')
    secondStageMarkConfig = SecondStageMarkConfig()
    if(secondStageMarkConfig.IsMarked()):
        # Post-reboot second stage: clear the marker and continue updating.
        secondStageMarkConfig.ClearIt()
        start_daemon()
    else:
        hutil.exit_if_enabled()
        cronUtil = CronUtil(logger)
        cronUtil.check_update_cron_config()
        cronUtil.restart_cron()
        start_daemon()


def install():
    hutil.do_parse_context('Install')
    hutil.do_exit(0, 'Install', 'success', '0', 'Install Succeeded')


def uninstall():
    hutil.do_parse_context('Uninstall')
    hutil.do_exit(0, 'Uninstall', 'success', '0', 'Uninstall succeeded')


def disable():
    hutil.do_parse_context('Disable')
    hutil.do_exit(0, 'Disable', 'success', '0', 'Disable Succeeded')


def update():
    # Bug fix: operation name was misspelled 'Upadate' in the original.
    hutil.do_parse_context('Update')
    hutil.do_exit(0, 'Update', 'success', '0', 'Update Succeeded')


if __name__ == '__main__':
    main()
""" def __init__(self,distro_info): self.distro_info = distro_info self.base64_path = '/usr/bin/base64' self.bash_path = '/bin/bash' self.blkid_path = '/usr/bin/blkid' self.cat_path = '/bin/cat' self.cryptsetup_path = '/usr/sbin/cryptsetup' self.dd_path = '/usr/bin/dd' self.e2fsck_path = '/sbin/e2fsck' self.echo_path = '/usr/bin/echo' self.lsblk_path = '/usr/bin/lsblk' self.lsscsi_path = '/usr/bin/lsscsi' self.mkdir_path = '/usr/bin/mkdir' self.mount_path = '/usr/bin/mount' self.openssl_path = '/usr/bin/openssl' self.resize2fs_path = '/sbin/resize2fs' self.umount_path = '/usr/bin/umount' def CreateCronJob(self): pass ================================================ FILE: RDMAUpdate/main/patch/OraclePatching.py ================================================ #!/usr/bin/python # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class OraclePatching(redhatPatching):
    """Oracle Linux flavour of redhatPatching; only the tool paths differ
    from the inherited defaults."""

    def __init__(self, logger, distro_info):
        super(OraclePatching, self).__init__(distro_info)
        self.logger = logger
        # Oracle Linux keeps most userland tools under /usr/bin, with the
        # ext filesystem utilities under /sbin.
        self.base64_path = '/usr/bin/base64'
        self.bash_path = '/usr/bin/bash'
        self.blkid_path = '/usr/bin/blkid'
        self.cat_path = '/bin/cat'
        self.cryptsetup_path = '/usr/sbin/cryptsetup'
        self.dd_path = '/usr/bin/dd'
        self.e2fsck_path = '/sbin/e2fsck'
        self.echo_path = '/usr/bin/echo'
        self.lsblk_path = '/usr/bin/lsblk'
        self.lsscsi_path = '/usr/bin/lsscsi'
        self.mkdir_path = '/usr/bin/mkdir'
        self.mount_path = '/usr/bin/mount'
        self.openssl_path = '/usr/bin/openssl'
        self.resize2fs_path = '/sbin/resize2fs'
        self.umount_path = '/usr/bin/umount'
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.4+
import os
import sys
import imp
import base64
import re
import json
import platform
import shutil
import time
import traceback
import datetime
import subprocess
from AbstractPatching import AbstractPatching
from Common import *
from CommandExecuter import CommandExecuter
from RdmaException import RdmaException
from SecondStageMarkConfig import SecondStageMarkConfig


class SuSEPatching(AbstractPatching):
    """SuSE-specific RDMA driver update workflow.

    Stores absolute paths of the system tools (locations differ between
    SLES 11 and later releases) and implements the check/install/update
    cycle for the Microsoft LIS RDMA kernel module.  Several methods
    reboot the machine as a side effect.
    """
    def __init__(self,logger,distro_info):
        # distro_info[1] is the release string; SLES 11 keeps most tools
        # under /bin and /sbin, later releases under /usr/bin and /usr/sbin.
        super(SuSEPatching,self).__init__(distro_info)
        self.logger = logger
        if(distro_info[1] == "11"):
            self.base64_path = '/usr/bin/base64'
            self.bash_path = '/bin/bash'
            self.blkid_path = '/sbin/blkid'
            self.cryptsetup_path = '/sbin/cryptsetup'
            self.cat_path = '/bin/cat'
            self.dd_path = '/bin/dd'
            self.e2fsck_path = '/sbin/e2fsck'
            self.echo_path = '/bin/echo'
            self.lsblk_path = '/bin/lsblk'
            self.lsscsi_path = '/usr/bin/lsscsi'
            self.mkdir_path = '/bin/mkdir'
            self.modprobe_path = '/usr/bin/modprobe'
            self.mount_path = '/bin/mount'
            self.openssl_path = '/usr/bin/openssl'
            self.ps_path = '/bin/ps'
            self.resize2fs_path = '/sbin/resize2fs'
            self.reboot_path = '/sbin/reboot'
            self.rmmod_path = '/sbin/rmmod'
            self.service_path='/usr/sbin/service'
            self.umount_path = '/bin/umount'
            self.zypper_path = '/usr/bin/zypper'
        else:
            self.base64_path = '/usr/bin/base64'
            self.bash_path = '/bin/bash'
            self.blkid_path = '/usr/bin/blkid'
            self.cat_path = '/bin/cat'
            self.cryptsetup_path = '/usr/sbin/cryptsetup'
            self.dd_path = '/usr/bin/dd'
            self.e2fsck_path = '/sbin/e2fsck'
            self.echo_path = '/usr/bin/echo'
            self.lsblk_path = '/usr/bin/lsblk'
            self.lsscsi_path = '/usr/bin/lsscsi'
            self.mkdir_path = '/usr/bin/mkdir'
            self.modprobe_path = '/usr/sbin/modprobe'
            self.mount_path = '/usr/bin/mount'
            self.openssl_path = '/usr/bin/openssl'
            self.ps_path = '/usr/bin/ps'
            self.resize2fs_path = '/sbin/resize2fs'
            self.reboot_path = '/sbin/reboot'
            self.rmmod_path = '/usr/sbin/rmmod'
            self.service_path = '/usr/sbin/service'
            self.umount_path = '/usr/bin/umount'
            self.zypper_path = '/usr/bin/zypper'

    def rdmaupdate(self):
        """Top-level entry point: make sure the hv KVP daemon is installed,
        then compare the host ND driver version with the installed RDMA
        package and update the driver if they differ.

        Raises RdmaException on any detected failure.
        """
        check_install_result = self.check_install_hv_utils()
        if(check_install_result == CommonVariables.process_success):
            # Give the freshly (re)started KVP daemon time to populate the
            # kvp pool files before reading the ND driver version from them.
            time.sleep(40)
            check_result = self.check_rdma()
            if(check_result == CommonVariables.UpToDate):
                return
            elif(check_result == CommonVariables.OutOfDate):
                nd_driver_version = self.get_nd_driver_version()
                rdma_package_installed_version = self.get_rdma_package_version()
                # NOTE: update_rdma_driver reboots on success; its return
                # value is not used here.
                update_rdma_driver_result = self.update_rdma_driver(nd_driver_version, rdma_package_installed_version)
            elif(check_result == CommonVariables.DriverVersionNotFound):
                raise RdmaException(CommonVariables.driver_version_not_found)
            elif(check_result == CommonVariables.Unknown):
                raise RdmaException(CommonVariables.unknown_error)
        else:
            raise RdmaException(CommonVariables.install_hv_utils_failed)

    def check_rdma(self):
        """Compare host ND driver version with the installed package version.

        Returns one of CommonVariables.{UpToDate, OutOfDate,
        DriverVersionNotFound, Unknown}.
        """
        nd_driver_version = self.get_nd_driver_version()
        if(nd_driver_version is None or nd_driver_version == ""):
            return CommonVariables.DriverVersionNotFound
        package_version = self.get_rdma_package_version()
        if(package_version is None or package_version == ""):
            # No RDMA package installed at all -> treat as out of date.
            return CommonVariables.OutOfDate
        else:
            # package_version would be like this: 20150707_k3.12.28_4-3.1
            # nd_driver_version like: 140.0
            self.logger.log("nd_driver_version is " + str(nd_driver_version) + " package_version is " + str(package_version))
            if(nd_driver_version is not None):
                # NdDriverVersion should be at the end of package version.
                r = re.match(".+(%s)$" % nd_driver_version, package_version)
                if not r :
                    # The package version does not end with the host ND
                    # version -> versions differ, an update is needed.
                    return CommonVariables.OutOfDate
                else:
                    return CommonVariables.UpToDate
        # Defensive fallback; not reachable with the branches above.
        return CommonVariables.Unknown

    def reload_hv_utils(self):
        """Unload and re-load the hv_utils kernel module.

        Returns CommonVariables.process_success or common_failed.
        """
        commandExecuter = CommandExecuter(self.logger)
        # clear /run/hv_kvp_daemon folder for the service could not be restart walkaround
        error,output = commandExecuter.RunGetOutput(self.rmmod_path + " hv_utils") #find a way to force install non-prompt
        self.logger.log("rmmod hv_utils return code: " + str(error) + " output:" + str(output))
        if(error != CommonVariables.process_success):
            return CommonVariables.common_failed
        error,output = commandExecuter.RunGetOutput(self.modprobe_path + " hv_utils") #find a way to force install non-prompt
        self.logger.log("modprobe hv_utils return code: " + str(error) + " output:" + str(output))
        if(error != CommonVariables.process_success):
            return CommonVariables.common_failed
        return CommonVariables.process_success

    def restart_hv_kvp_daemon(self):
        """Reload hv_utils and restart the hv_kvp_daemon service.

        Removes the stale /run/hv_kvp_daemon runtime directory first so the
        service can start again.  Returns process_success or common_failed.
        """
        commandExecuter = CommandExecuter(self.logger)
        reload_result = self.reload_hv_utils()
        if(reload_result == CommonVariables.process_success):
            if(os.path.exists('/run/hv_kvp_daemon')):
                os.rmdir('/run/hv_kvp_daemon')
            error,output = commandExecuter.RunGetOutput(self.service_path + " hv_kvp_daemon start") #find a way to force install non-prompt
            self.logger.log("service hv_kvp_daemon start return code: " + str(error) + " output:" + str(output))
            if(error != CommonVariables.process_success):
                return CommonVariables.common_failed
            return CommonVariables.process_success
        else:
            return CommonVariables.common_failed

    def check_install_hv_utils(self):
        """Ensure the hv_kvp_daemon is running; install hyper-v and reboot if not.

        Scans the process list for hv_kvp_daemon.  When missing, installs the
        hyper-v package via zypper, writes the second-stage marker and reboots
        the machine.  Returns process_success or common_failed.
        """
        commandExecuter = CommandExecuter(self.logger)
        error, output = commandExecuter.RunGetOutput(self.ps_path + " -ef")
        if(error != CommonVariables.process_success):
            return CommonVariables.common_failed
        else:
            r = re.search("hv_kvp_daemon", output)
            if r is None :
                self.logger.log("KVP deamon is not running, install it")
                error,output = commandExecuter.RunGetOutput(self.zypper_path + " -n install --force hyper-v")
                self.logger.log("install hyper-v return code: " + str(error) + " output:" + str(output))
                if(error != CommonVariables.process_success):
                    return CommonVariables.common_failed
                # Mark that the update must continue after the reboot.
                secondStageMarkConfig = SecondStageMarkConfig()
                secondStageMarkConfig.MarkIt()
                self.reboot_machine()
                return CommonVariables.process_success
            else :
                self.logger.log("KVP deamon is running")
                return CommonVariables.process_success

    def get_nd_driver_version(self):
        """Read the host ND driver version (e.g. "142.0") from the KVP pool.

        Returns the version string, or None when the record is absent.
        If an error happens, raise a RdmaException.
        """
        try:
            with open("/var/lib/hyperv/.kvp_pool_0", "r") as f:
                lines = f.read()
            # KVP records are NUL-padded; the version follows the key name.
            r = re.search("NdDriverVersion\0+(\d\d\d\.\d)", lines)
            if r is not None:
                NdDriverVersion = r.groups()[0]
                return NdDriverVersion # e.g. NdDriverVersion = 142.0
            else :
                self.logger.log("Error: NdDriverVersion not found.")
                return None
        except Exception as e:
            errMsg = 'Failed to enable the extension with error: %s, stack trace: %s' % (str(e), traceback.format_exc())
            self.logger.log("Can't update status: " + errMsg)
            raise RdmaException(CommonVariables.nd_driver_detect_error)

    def get_rdma_package_version(self):
        """Return the installed msft-lis-rdma-kmp-default version via zypper,
        or None when the package (or its Version field) is not found.
        """
        commandExecuter = CommandExecuter(self.logger)
        error, output = commandExecuter.RunGetOutput(self.zypper_path + " info msft-lis-rdma-kmp-default")
        if(error == CommonVariables.process_success):
            r = re.search("Version: (\S+)", output)
            if r is not None:
                package_version = r.groups()[0] # e.g. "20150707_k3.12.28_4-3.1.140.0"
                return package_version
            else:
                return None
        else:
            return None

    def update_rdma_driver(self, host_version, rdma_package_installed_version):
        """Install the RDMA driver RPM matching host_version and reboot.

        Configures the SUSE driver repository if missing, refreshes it,
        re-installs the wrapper package (which drops the driver RPMs under
        /opt/microsoft/rdma) and installs the RPM whose name embeds
        host_version.  Raises RdmaException when install fails or no
        driver RPMs are present.
        """
        commandExecuter = CommandExecuter(self.logger)
        error, output = commandExecuter.RunGetOutput(self.zypper_path + " lr -u")
        rdma_pack_result = re.search("msft-rdma-pack", output)
        if rdma_pack_result is None :
            self.logger.log("rdma_pack_result is None")
            error, output = commandExecuter.RunGetOutput(self.zypper_path + " ar https://drivers.suse.com/microsoft/Microsoft-LIS-RDMA/sle-12/updates msft-rdma-pack")
            # wait for the cache build.
            time.sleep(20)
            self.logger.log("error result is " + str(error) + " output is : " + str(output))
        else:
            self.logger.log("output is: "+str(output))
            self.logger.log("msft-rdma-pack found")
        returnCode,message = commandExecuter.RunGetOutput(self.zypper_path + " --no-gpg-checks refresh")
        self.logger.log("refresh repro return code is " + str(returnCode) + " output is: " + str(message))
        # install the wrapper package, that will put the driver RPM packages under /opt/microsoft/rdma
        returnCode,message = commandExecuter.RunGetOutput(self.zypper_path + " -n remove " + CommonVariables.wrapper_package_name)
        self.logger.log("remove wrapper package return code is " + str(returnCode) + " output is: " + str(message))
        returnCode,message = commandExecuter.RunGetOutput(self.zypper_path + " --non-interactive install --force " + CommonVariables.wrapper_package_name)
        self.logger.log("install wrapper package return code is " + str(returnCode) + " output is: " + str(message))
        # NOTE(review): os.listdir returns a list (possibly empty), never
        # None, so the else branch below is effectively unreachable; an
        # empty directory silently falls through without raising.
        r = os.listdir("/opt/microsoft/rdma")
        if r is not None :
            for filename in r :
                if re.match("msft-lis-rdma-kmp-default-\d{8}\.(%s).+" % host_version, filename) :
                    error,output = commandExecuter.RunGetOutput(self.zypper_path + " --non-interactive remove msft-lis-rdma-kmp-default")
                    self.logger.log("remove msft-lis-rdma-kmp-default result is " + str(error) + " output is: " + str(output))
                    self.logger.log("Installing RPM /opt/microsoft/rdma/" + filename)
                    error,output = commandExecuter.RunGetOutput(self.zypper_path + " --non-interactive install --force /opt/microsoft/rdma/%s" % filename)
                    self.logger.log("Install msft-lis-rdma-kmp-default result is " + str(error) + " output is: " + str(output))
                    if(error == CommonVariables.process_success):
                        self.reboot_machine()
                    else:
                        raise RdmaException(CommonVariables.package_install_failed)
        else:
            self.logger.log("RDMA drivers not found in /opt/microsoft/rdma")
            raise RdmaException(CommonVariables.package_not_found)

    def reboot_machine(self):
        """Reboot the machine immediately via the reboot binary."""
        self.logger.log("rebooting machine")
        commandExecuter = CommandExecuter(self.logger)
        commandExecuter.RunGetOutput(self.reboot_path)

================================================
FILE: RDMAUpdate/main/patch/UbuntuPatching.py
================================================
#!/usr/bin/python
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.4+
import os
import sys
import imp
import base64
import re
import json
import platform
import shutil
import time
import traceback
import datetime
import subprocess
from AbstractPatching import AbstractPatching
from Common import *


class UbuntuPatching(AbstractPatching):
    """Ubuntu-specific tool paths; no RDMA update logic of its own."""
    def __init__(self,logger,distro_info):
        super(UbuntuPatching,self).__init__(distro_info)
        self.logger = logger
        self.base64_path = '/usr/bin/base64'
        self.bash_path = '/bin/bash'
        self.blkid_path = '/sbin/blkid'
        self.cat_path = '/bin/cat'
        self.cryptsetup_path = '/sbin/cryptsetup'
        self.dd_path = '/bin/dd'
        self.e2fsck_path = '/sbin/e2fsck'
        self.echo_path = '/bin/echo'
        self.lsblk_path = '/bin/lsblk'
        self.lsscsi_path = '/usr/bin/lsscsi'
        self.mkdir_path = '/bin/mkdir'
        self.mount_path = '/bin/mount'
        self.openssl_path = '/usr/bin/openssl'
        self.resize2fs_path = '/sbin/resize2fs'
        self.umount_path = '/bin/umount'

    #def install_extras(self):
    #    """
    #    install the sg_dd because the default dd do not support the sparse write
    #    """
    #    if(self.distro_info[0].lower() == "ubuntu" and self.distro_info[1] == "12.04"):
    #        common_extras = ['cryptsetup-bin','lsscsi']
    #    else:
    #        common_extras = ['cryptsetup-bin','lsscsi']
    #    for extra in common_extras:
    #        self.logger.log("installation for " + extra + 'result is ' + str(subprocess.call(['apt-get', 'install','-y', extra])))
# ================================================
# FILE: RDMAUpdate/main/patch/__init__.py
# ================================================
#!/usr/bin/python
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.4+
import os
import re
import platform
from UbuntuPatching import UbuntuPatching
from redhatPatching import redhatPatching
from centosPatching import centosPatching
from OraclePatching import OraclePatching
from SuSEPatching import SuSEPatching


# Define the function in case waagent(<2.0.4) doesn't have DistInfo()
def DistInfo():
    """Return [distro_name, release, ...] for the current system.

    Special-cases FreeBSD and Oracle Linux, otherwise delegates to the
    platform module.
    """
    if 'FreeBSD' in platform.system():
        release = re.sub('\-.*\Z', '', str(platform.release()))
        distinfo = ['FreeBSD', release]
        return distinfo
    if os.path.isfile('/etc/oracle-release'):
        release = re.sub('\-.*\Z', '', str(platform.release()))
        distinfo = ['Oracle', release]
        return distinfo
    if 'linux_distribution' in dir(platform):
        distinfo = list(platform.linux_distribution(full_distribution_name=0))
        # remove trailing whitespace in distro name
        distinfo[0] = distinfo[0].strip()
        return distinfo
    else:
        return platform.dist()


def GetMyPatching(logger):
    """
    Return MyPatching object.
    NOTE: Logging is not initialized at this point.
    """
    dist_info = DistInfo()
    if 'Linux' in platform.system():
        Distro = dist_info[0]
    else: # I know this is not Linux!
        if 'FreeBSD' in platform.system():
            Distro = platform.system()
    # NOTE(review): on a system that is neither Linux nor FreeBSD, 'Distro'
    # would be unbound here; the extension only targets those platforms.
    Distro = Distro.strip('"')
    Distro = Distro.strip(' ')
    patching_class_name = Distro + 'Patching'
    # Fixed: dict.has_key() is Python-2-only; the 'in' operator works on
    # both Python 2 and 3.
    if patching_class_name not in globals():
        # Fixed: call print as a function (single argument) so the module
        # also parses under Python 3; behavior is identical on Python 2.
        print(Distro + ' is not a supported distribution.')
        return None
    patchingInstance = globals()[patching_class_name](logger,dist_info)
    return patchingInstance
# ================================================
# FILE: RDMAUpdate/main/patch/centosPatching.py
# ================================================
#!/usr/bin/python
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.4+
import os
import sys
import imp
import base64
import re
import json
import platform
import shutil
import time
import traceback
import datetime
import subprocess
from redhatPatching import redhatPatching
from Common import *


class centosPatching(redhatPatching):
    """CentOS flavour of redhatPatching; only the tool paths differ."""

    def __init__(self, logger, distro_info):
        super(centosPatching, self).__init__(logger, distro_info)
        self.logger = logger
        # CentOS 6.5-6.8 predate the /usr merge, so several core tools
        # live directly under /bin and /sbin there; later releases use
        # /usr/bin and /usr/sbin instead.
        if distro_info[1] in ("6.8", "6.7", "6.6", "6.5"):
            tool_paths = {
                'base64_path': '/usr/bin/base64',
                'bash_path': '/bin/bash',
                'blkid_path': '/sbin/blkid',
                'cat_path': '/bin/cat',
                'cryptsetup_path': '/sbin/cryptsetup',
                'dd_path': '/bin/dd',
                'e2fsck_path': '/sbin/e2fsck',
                'echo_path': '/bin/echo',
                'lsblk_path': '/bin/lsblk',
                'lsscsi_path': '/usr/bin/lsscsi',
                'mkdir_path': '/bin/mkdir',
                'mount_path': '/bin/mount',
                'openssl_path': '/usr/bin/openssl',
                'resize2fs_path': '/sbin/resize2fs',
                'umount_path': '/bin/umount',
            }
        else:
            tool_paths = {
                'base64_path': '/usr/bin/base64',
                'bash_path': '/usr/bin/bash',
                'blkid_path': '/usr/bin/blkid',
                'cat_path': '/bin/cat',
                'cryptsetup_path': '/usr/sbin/cryptsetup',
                'dd_path': '/usr/bin/dd',
                'e2fsck_path': '/sbin/e2fsck',
                'echo_path': '/usr/bin/echo',
                'lsblk_path': '/usr/bin/lsblk',
                'lsscsi_path': '/usr/bin/lsscsi',
                'mkdir_path': '/usr/bin/mkdir',
                'mount_path': '/usr/bin/mount',
                'openssl_path': '/usr/bin/openssl',
                'resize2fs_path': '/sbin/resize2fs',
                'umount_path': '/usr/bin/umount',
            }
        for attr_name, tool_path in tool_paths.items():
            setattr(self, attr_name, tool_path)
# ================================================
# FILE: RDMAUpdate/main/patch/redhatPatching.py
# ================================================
#!/usr/bin/python
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.4+
import os
import sys
import imp
import base64
import re
import json
import platform
import shutil
import time
import traceback
import datetime
import subprocess
from AbstractPatching import AbstractPatching
from Common import *


class redhatPatching(AbstractPatching):
    """RHEL-specific tool paths; update logic comes from AbstractPatching."""

    def __init__(self, logger, distro_info):
        super(redhatPatching, self).__init__(distro_info)
        self.logger = logger
        # RHEL 6.7 predates the /usr merge, so several core tools live
        # directly under /bin and /sbin; later releases use /usr/bin
        # and /usr/sbin instead.
        if distro_info[1] == "6.7":
            tool_paths = {
                'base64_path': '/usr/bin/base64',
                'bash_path': '/bin/bash',
                'blkid_path': '/sbin/blkid',
                'cat_path': '/bin/cat',
                'cryptsetup_path': '/sbin/cryptsetup',
                'dd_path': '/bin/dd',
                'e2fsck_path': '/sbin/e2fsck',
                'echo_path': '/bin/echo',
                'getenforce_path': '/usr/sbin/getenforce',
                'setenforce_path': '/usr/sbin/setenforce',
                'lsblk_path': '/bin/lsblk',
                'lsscsi_path': '/usr/bin/lsscsi',
                'mkdir_path': '/bin/mkdir',
                'mount_path': '/bin/mount',
                'openssl_path': '/usr/bin/openssl',
                'resize2fs_path': '/sbin/resize2fs',
                'umount_path': '/bin/umount',
            }
        else:
            tool_paths = {
                'base64_path': '/usr/bin/base64',
                'bash_path': '/usr/bin/bash',
                'blkid_path': '/usr/bin/blkid',
                'cat_path': '/bin/cat',
                'cryptsetup_path': '/usr/sbin/cryptsetup',
                'dd_path': '/usr/bin/dd',
                'e2fsck_path': '/sbin/e2fsck',
                'echo_path': '/usr/bin/echo',
                'getenforce_path': '/usr/sbin/getenforce',
                'setenforce_path': '/usr/sbin/setenforce',
                'lsblk_path': '/usr/bin/lsblk',
                'lsscsi_path': '/usr/bin/lsscsi',
                'mkdir_path': '/usr/bin/mkdir',
                'mount_path': '/usr/bin/mount',
                'openssl_path': '/usr/bin/openssl',
                'resize2fs_path': '/sbin/resize2fs',
                'umount_path': '/usr/bin/umount',
            }
        for attr_name, tool_path in tool_paths.items():
            setattr(self, attr_name, tool_path)
# ================================================
# FILE: RDMAUpdate/setup.py
# ================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License").
#
# To build:
# python setup.py sdist
#
# To install:
# python setup.py install
#
# To register (only needed once):
# python setup.py register
#
# To upload:
# python setup.py sdist upload

from distutils.core import setup
import os
import json
import subprocess
from zipfile import ZipFile
from main.Common import CommonVariables

packages_array = []
main_folder = 'main'
main_entry = main_folder + '/handle.py'
packages_array.append(main_folder)
patch_folder = main_folder + '/patch'
packages_array.append(patch_folder)

"""
copy the dependency to the local
"""
"""
copy the utils lib to local
"""
target_utils_path = main_folder + '/' + CommonVariables.utils_path_name
packages_array.append(target_utils_path)

"""
generate the HandlerManifest.json file.
"""
manifest_obj = [{
    "name": CommonVariables.extension_name,
    "version": CommonVariables.extension_version,
    "handlerManifest": {
        "installCommand": main_entry + " -install",
        "uninstallCommand": main_entry + " -uninstall",
        "updateCommand": main_entry + " -update",
        "enableCommand": main_entry + " -enable",
        "disableCommand": main_entry + " -disable",
        "rebootAfterInstall": False,
        "reportHeartbeat": False
    }
}]
manifest_str = json.dumps(manifest_obj, sort_keys = True, indent = 4)
# Fixed: use a context manager so the file handle is always closed.
with open("HandlerManifest.json", "w") as manifest_file:
    manifest_file.write(manifest_str)

"""
generate the extension xml file
"""
extension_xml_file_content = """<ExtensionImage xmlns="http://schemas.microsoft.com/windowsazure">
<ProviderNameSpace>Microsoft.OSTCExtensions</ProviderNameSpace>
<Type>%s</Type>
<Version>%s</Version>
<Label>%s</Label>
<HostingResources>VmRole</HostingResources>
<MediaLink>%s</MediaLink>
<Description>%s</Description>
<IsInternalExtension>true</IsInternalExtension>
<Eula>https://github.com/Azure/azure-linux-extensions/blob/1.0/LICENSE-2_0.txt</Eula>
<PrivacyUri>https://github.com/Azure/azure-linux-extensions/blob/1.0/LICENSE-2_0.txt</PrivacyUri>
<HomepageUri>https://github.com/Azure/azure-linux-extensions</HomepageUri>
<IsJsonExtension>true</IsJsonExtension>
<CompanyName>Microsoft Open Source Technology Center</CompanyName>
</ExtensionImage>""" % (CommonVariables.extension_type,CommonVariables.extension_version,CommonVariables.extension_label,CommonVariables.extension_media_link,CommonVariables.extension_description)
# Fixed: context manager instead of an unclosed open()/write() pair.
with open(CommonVariables.extension_name + '-' + str(CommonVariables.extension_version) + '.xml', 'w') as extension_xml_file:
    extension_xml_file.write(extension_xml_file_content)

"""
setup script, to package the files up
"""
setup(name = CommonVariables.extension_name,
      version = CommonVariables.extension_version,
      description=CommonVariables.extension_description,
      license='Apache License 2.0',
      author='Microsoft Corporation',
      author_email='andliu@microsoft.com',
      url='https://github.com/Azure/azure-linux-extensions',
      classifiers = ['Development Status :: 5 - Production/Stable',
                     'Programming Language :: Python',
                     'Programming Language :: Python :: 2',
                     'Programming Language :: Python :: 2.7',
                     'Programming Language :: Python :: 3',
                     'Programming Language :: Python :: 3.3',
                     'Programming Language :: Python :: 3.4',
                     'License :: OSI Approved :: Apache Software License'],
      packages = packages_array)

"""
unzip the package files and re-package it.
"""
target_zip_file_location = './dist/'
target_folder_name = CommonVariables.extension_name + '-' + str(CommonVariables.extension_version)
target_zip_file_path = target_zip_file_location + target_folder_name + '.zip'
target_zip_file = ZipFile(target_zip_file_path)
target_zip_file.extractall(target_zip_file_location)


def dos2unix(src):
    """Normalize line endings of src in place via the dos2unix tool,
    discarding the tool's output."""
    args = ["dos2unix", src]
    with open(os.devnull, 'w') as devnull:
        child = subprocess.Popen(args, stdout=devnull, stderr=devnull)
        print('dos2unix %s ' % (src))
        child.wait()


def zip_dir(src, dst):
    """Zip the tree under src into the archive dst, converting each file's
    line endings first.

    Fixed: renamed from `zip`, which shadowed the builtin; this script is
    standalone, so no external callers are affected.
    """
    # Fixed: close the archive deterministically via a context manager.
    with ZipFile("%s" % (dst), "w") as zf:
        abs_src = os.path.abspath(src)
        for dirname, subdirs, files in os.walk(src):
            for filename in files:
                absname = os.path.abspath(os.path.join(dirname, filename))
                dos2unix(absname)
                # Archive names are relative to src.
                arcname = absname[len(abs_src) + 1:]
                print('zipping %s as %s' % (os.path.join(dirname, filename), arcname))
                zf.write(absname, arcname)


final_folder_path = target_zip_file_location + target_folder_name
zip_dir(final_folder_path, target_zip_file_path)
# ================================================
# FILE: RDMAUpdate/test/update_rdma_driver.py
# ================================================
# Manual test script mirroring SuSEPatching.rdmaupdate: checks the KVP
# daemon, reads the host ND driver version from the kvp pool and updates
# the msft-lis-rdma-kmp-default package when the versions differ.
# WARNING: running this script installs packages and reboots the machine.
import subprocess, re, os, sys


def RunGetOutput(cmd, chk_err=True):
    """Run cmd through the shell; return (exit_code, latin-1 decoded output).

    On failure the error is printed unless chk_err is False.
    """
    try:
        output = subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
    # Fixed: 'except E, e' is Python-2-only syntax; 'as' works on 2.6+ and 3.
    except subprocess.CalledProcessError as e:
        if chk_err:
            print('CalledProcessError.  Error Code is ' + str(e.returncode))
            print('CalledProcessError.  Command string was ' + e.cmd)
            print('CalledProcessError.  Command result was ' + (e.output[:-1]).decode('latin-1'))
        return e.returncode, e.output.decode('latin-1')
    return 0, output.decode('latin-1')


def InstallRDMADriver(host_version):
    """Install the RDMA driver RPM whose name embeds host_version."""
    # make sure we have the correct repo configured
    error, output = RunGetOutput("zypper lr -u")
    if not re.search("msft-rdma-pack", output):
        RunGetOutput("zypper ar https://drivers.suse.com/microsoft/Microsoft-LIS-RDMA/sle-12/updates msft-rdma-pack")
    # install the wrapper package, that will put the driver RPM packages under /opt/microsoft/rdma
    RunGetOutput("zypper --non-interactive install --force msft-rdma-drivers")
    # install the driver RPM package
    r = os.listdir("/opt/microsoft/rdma")
    if r:
        for filename in r:
            if re.match(r"msft-lis-rdma-kmp-default-\d{8}\.(%s).+" % host_version, filename):
                print("Installing RPM /opt/microsoft/rdma/" + filename)
                RunGetOutput("zypper --non-interactive install --force /opt/microsoft/rdma/%s" % filename)
                return
    print("RDMA drivers not found in /opt/microsoft/rdma")


# 1. check if kvp daemon is running, if not install it and reboot
error, output = RunGetOutput("ps -ef")
r = re.search("hv_kvp_daemon", output)
if not r:
    print("KVP deamon is not running, install it")
    RunGetOutput("zypper --non-interactive install --force hyper-v")
    RunGetOutput("reboot")
else:
    print("KVP deamon is running")

# 2. get the host ND version
# Fixed: use a context manager so the pool file is always closed.
with open("/var/lib/hyperv/.kvp_pool_0", "r") as f:
    lines = f.read()
# Fixed: was re.match, which only matches at the START of the buffer; the
# NdDriverVersion record can sit anywhere in the kvp pool (the production
# code in SuSEPatching.get_nd_driver_version uses re.search for this).
r = re.search("NdDriverVersion\0+(\d\d\d\.\d)", lines)
if r:
    NdDriverVersion = r.groups()[0]
    print("ND version = " + NdDriverVersion)  # e.g. NdDriverVersion = 142.0
else:
    print("Error: NdDriverVersion not found. Abort")
    sys.exit()

# 3. if the ND version doesn't match the RDMA driver package version, do an update
error, output = RunGetOutput("zypper --non-interactive info msft-lis-rdma-kmp-default")
r = re.search(r"Version:\s+(\S+)", output)
if r:
    package_version = r.groups()[0]  # e.g. package_version is "20151119.142.0_k3.12.28_4-1.1"
    print("msft-lis-rdma-kmp-default package version = " + package_version)
    # NdDriverVersion should be embedded after the date stamp in the package version
    r = re.match(r"\d{8}\.(%s).+" % NdDriverVersion, package_version)
    if not r:
        # host ND version does not match the package version, do an update
        print("ND and package version don't match, doing an update")
        RunGetOutput("zypper --non-interactive remove msft-lis-rdma-kmp-default")
        InstallRDMADriver(NdDriverVersion)
        RunGetOutput("reboot")
    else:
        print("ND and package version match, not doing an update")
else:
    print("msft-lis-rdma-kmp-default not found, installing new version")
    InstallRDMADriver(NdDriverVersion)
    RunGetOutput("reboot")
# Extension List

| Name | Latest Version | Description |
|:---|:---|:---|
| [Custom Script](./CustomScript) | 1.5 | Allow the owner of the Azure Virtual Machines to run customized scripts in the VM |
| [DSC](./DSC) | 2.71 | Allow the owner of the Azure Virtual Machines to configure the VM using Windows PowerShell Desired State Configuration (DSC) for Linux |
| [OS Patching](./OSPatching) | 2.0 | Allow the owner of the Azure VM to configure the Linux VM patching schedule cycle |
| [VM Access](./VMAccess) | [1.5](https://github.com/Azure/azure-linux-extensions/releases/tag/VMAccess-1.5.1) | Provide several ways to allow the owner of the VM to get the SSH access back |
| [OMS Agent](./OmsAgent) | 1.0 | Allow the owner of the Azure VM to install the omsagent and attach it to an OMS workspace |
| [Diagnostic](./Diagnostic) | 3.0.129 | Allow the owner of the Azure Virtual Machines to obtain diagnostic data for a Linux virtual machine |
| [Backup](./VMBackup) | 1.0.9124.0 | Provide application consistent backup of the virtual machine (needs to be used in conjunction with [Azure Backup](https://azure.microsoft.com/services/backup/)) |

# Contributing

Please refer to the [Contribution Guide](./docs/contribution-guide.md).

# Known Issues

1. When you run the PowerShell command "Set-AzureVMExtension" on a Linux VM, you may hit the following error: "Provision Guest Agent must be enabled on the VM object before setting IaaS VM Access Extension".

   * Root Cause: When you create the Linux VM via the portal, the value of provision guest agent on the VM is not always set to "True". If your VM is created using PowerShell or using the new Azure portal, you will not see this issue.
   * Resolution: Add the following PowerShell commands to set the ProvisionGuestAgent to "True".

     ```powershell
     $vm = Get-AzureVM -ServiceName 'MyServiceName' -Name 'MyVMName'
     $vm.GetInstance().ProvisionGuestAgent = $true
     ```

# Support

The extensions in this repository are tested against Python 2.7 and higher.
The extensions in this repository use OpenSSL 1.0 and higher. ----- This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. ================================================ FILE: SECURITY.md ================================================ <!-- BEGIN MICROSOFT SECURITY.MD V0.0.7 BLOCK --> ## Security Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. ## Reporting Security Issues **Please do not report security vulnerabilities through public GitHub issues.** Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. 
Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) * Full paths of source file(s) related to the manifestation of the issue * The location of the affected source code (tag/branch/commit or direct URL) * Any special configuration required to reproduce the issue * Step-by-step instructions to reproduce the issue * Proof-of-concept or exploit code (if possible) * Impact of the issue, including how an attacker might exploit the issue This information will help us triage your report more quickly. If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. ## Preferred Languages We prefer all communications to be in English. ## Policy Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 
<!-- END MICROSOFT SECURITY.MD BLOCK -->

================================================
FILE: SampleExtension/HandlerManifest.json
================================================
[{
    "name": "SampleExtension",
    "version": 1.0,
    "handlerManifest": {
        "installCommand": "./install.py",
        "uninstallCommand": "./uninstall.py",
        "updateCommand": "./update.py",
        "enableCommand": "./enable.py",
        "disableCommand": "./disable.py",
        "rebootAfterInstall": false,
        "reportHeartbeat": false
    }
}]

================================================
FILE: SampleExtension/disable.py
================================================
#!/usr/bin/env python
from Utils.WAAgentUtil import waagent
import Utils.HandlerUtil as Util

# Short name used in waagent log lines.
ExtensionShortName = "SampleExtension"

def main():
    # Log to both the agent log and stdout.
    waagent.LoggerInit('/var/log/waagent.log','/dev/stdout')
    waagent.Log("%s started to handle." % ExtensionShortName)
    operation = "disable"
    status = "success"
    msg = "Disabled successfully."
    hutil = parse_context(operation)
    hutil.log("Start to disable.")
    hutil.log(msg)
    # Report success status to the agent and exit with code 0.
    hutil.do_exit(0, operation, status, '0', msg)

def parse_context(operation):
    # Build a HandlerUtility wired to waagent logging and parse the
    # handler environment/settings for this operation.
    hutil = Util.HandlerUtility(waagent.Log, waagent.Error)
    hutil.do_parse_context(operation)
    return hutil

if __name__ == '__main__' :
    main()

================================================
FILE: SampleExtension/enable.py
================================================
#!/usr/bin/env python
from Utils.WAAgentUtil import waagent
import Utils.HandlerUtil as Util

ExtensionShortName = "SampleExtension"

def main():
    #Global Variables definition
    waagent.LoggerInit('/var/log/waagent.log','/dev/stdout')
    waagent.Log("%s started to handle." %(ExtensionShortName))
    operation = "enable"
    status = "success"
    msg = "Enabled successfully."
    hutil = parse_context(operation)
    hutil.log("Start to enable.")
    # Demonstrates reading the extension's public settings: greets the
    # configured "name" or logs an error when it is absent.
    public_settings = hutil.get_public_settings()
    name = public_settings.get("name")
    if name:
        hutil.log("Hello {0}".format(name))
    else:
        hutil.error("The name in public settings is not provided.")
    hutil.log(msg)
    hutil.do_exit(0, operation, status, '0', msg)

def parse_context(operation):
    # Parse the handler environment/settings for this operation.
    hutil = Util.HandlerUtility(waagent.Log, waagent.Error)
    hutil.do_parse_context(operation)
    return hutil

if __name__ == '__main__' :
    main()

================================================
FILE: SampleExtension/install.py
================================================
#!/usr/bin/env python
from Utils.WAAgentUtil import waagent
import Utils.HandlerUtil as Util

ExtensionShortName = "SampleExtension"

def main():
    #Global Variables definition
    waagent.LoggerInit('/var/log/waagent.log','/dev/stdout')
    waagent.Log("%s started to handle." %(ExtensionShortName))
    operation = "install"
    status = "success"
    msg = "Installed successfully."
    hutil = parse_context(operation)
    hutil.log("Start to install.")
    hutil.log(msg)
    # Report success status to the agent and exit with code 0.
    hutil.do_exit(0, operation, status, '0', msg)

def parse_context(operation):
    # Parse the handler environment/settings for this operation.
    hutil = Util.HandlerUtility(waagent.Log, waagent.Error)
    hutil.do_parse_context(operation)
    return hutil

if __name__ == '__main__' :
    main()

================================================
FILE: SampleExtension/references
================================================
Utils/

================================================
FILE: SampleExtension/uninstall.py
================================================
#!/usr/bin/env python
from Utils.WAAgentUtil import waagent
import Utils.HandlerUtil as Util

ExtensionShortName = "SampleExtension"

def main():
    #Global Variables definition
    waagent.LoggerInit('/var/log/waagent.log','/dev/stdout')
    waagent.Log("%s started to handle." %(ExtensionShortName))
    operation = "uninstall"
    status = "success"
    msg = "Uninstalled successfully."
    hutil = parse_context(operation)
    hutil.log("Start to uninstall.")
    hutil.log(msg)
    # Report success status to the agent and exit with code 0.
    hutil.do_exit(0, operation, status, '0', msg)

def parse_context(operation):
    # Parse the handler environment/settings for this operation.
    hutil = Util.HandlerUtility(waagent.Log, waagent.Error)
    hutil.do_parse_context(operation)
    return hutil

if __name__ == '__main__' :
    main()

================================================
FILE: SampleExtension/update.py
================================================
#!/usr/bin/env python
from Utils.WAAgentUtil import waagent
import Utils.HandlerUtil as Util

ExtensionShortName = "SampleExtension"

def main():
    #Global Variables definition
    waagent.LoggerInit('/var/log/waagent.log','/dev/stdout')
    waagent.Log("%s started to handle." %(ExtensionShortName))
    operation = "update"
    status = "success"
    msg = "Updated successfully."
    hutil = parse_context(operation)
    hutil.log("Start to update.")
    hutil.log(msg)
    # Report success status to the agent and exit with code 0.
    hutil.do_exit(0, operation, status, '0', msg)

def parse_context(operation):
    # Parse the handler environment/settings for this operation.
    hutil = Util.HandlerUtility(waagent.Log, waagent.Error)
    hutil.do_parse_context(operation)
    return hutil

if __name__ == '__main__' :
    main()

================================================
FILE: TestHandlerLinux/HandlerManifest.json
================================================
[{
    "name": "TestHandlerLinux",
    "version": 1.1,
    "handlerManifest": {
        "installCommand": "installer/install.py",
        "uninstallCommand": "installer/uninstall.py",
        "updateCommand": "bin/update.py",
        "enableCommand": "bin/enable.py",
        "disableCommand": "bin/disable.py",
        "rebootAfterInstall": false,
        "reportHeartbeat": true
    }
}]

================================================
FILE: TestHandlerLinux/bin/#heartbeat.py#
================================================
#!/usr/bin/env python
"""
Example Azure Handler script for Linux IaaS
Heartbeat example
"""
import os
import imp
import time
# Load waagent and the handler utility dynamically from their on-disk
# locations (they are not installed as importable packages).
waagent=imp.load_source('waagent','/usr/sbin/waagent')
from waagent import LoggerInit
hutil=imp.load_source('HandlerUtil','./resources/HandlerUtil.py')
LoggerInit('/var/log/waagent.log','/dev/stdout') waagent.Log("hearbeat.py starting.") logfile=waagent.Log name,seqNo,version,config_dir,log_dir,settings_file,status_file,heartbeat_file,config=hutil.doParse(logfile,'Hearbeat') LoggerInit('/var/log/'+name+'_Hearbeat.log','/dev/stdout') waagent.Log(name+" - hearbeat.py starting.") logfile=waagent.Log pid=None pidfile='./service_pid.txt' retries=5 waagent.SetFileContents('./heartbeat.pid',str(os.getpid())) while(True): if os.path.exists(pidfile): pid=waagent.GetFileContents('./service_pid.txt') if waagent.Run("ps --no-headers " + str(pid),chk_err=False) == 0: # running retries=5 waagent.Log(name+" service.py is running with PID="+pid) hutil.doHealthReport(heartbeat_file,'Ready','0','service.py is running.') time.sleep(30) continue else: # died -- retries and wait for 2 min retries-=1 waagent.Error(name+" service.py is Not running.") if retries==4: hutil.doHealthReport(heartbeat_file,'NotRunning','1','ERROR - service.py Unknown or NOT running') if retries!=0: time.sleep(120) else: break else: # dead. report not ready waagent.Error(name+" service.py is Not running.") hutil.doHealthReport(heartbeat_file,'NotReady','1','ERROR - service.py is NOT running') break waagent.Log(name+" heartbeat.py exiting. 
service.py is NOT running") ================================================ FILE: TestHandlerLinux/bin/disable.py ================================================ #!/usr/bin/env python """ Example Azure Handler script for Linux IaaS Diable example """ import os import imp import time import json waagent=imp.load_source('waagent','/usr/sbin/waagent') from waagent import LoggerInit hutil=imp.load_source('HandlerUtil','./resources/HandlerUtil.py') LoggerInit('/var/log/waagent.log','/dev/stdout') waagent.Log("disable.py starting.") logfile=waagent.Log name,seqNo,version,config_dir,log_dir,settings_file,status_file,heartbeat_file,config=hutil.doParse(logfile,'Disable') LoggerInit('/var/log/'+name+'_Disable.log','/dev/stdout') waagent.Log(name+" - disable.py starting.") logfile=waagent.Log hutil.doStatusReport(name,seqNo,version,status_file,time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),name, 'Disable', 'transitioning', '0', 'Disabling', 'Process Config', 'transitioning', '0', 'Parsing ' + settings_file) hutil.doHealthReport(heartbeat_file,'NotReady','0','Proccessing Settings') error_string='' pid=None pidfile='./service_pid.txt' if not os.path.isfile(pidfile): error_string += pidfile +" is missing." error_string = "Error: " + error_string waagent.Error(error_string) hutil.doStatusReport(name,seqNo,version,status_file,time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),name, 'Disable', 'transitioning', '0', 'Disabling', 'Process Config', 'transitioning', '0', 'Parsing ' + settings_file) else: pid = waagent.GetFileContents(pidfile) #stop service.py try: os.kill(int(pid),7) except Exception as e: pass # remove pifdile try: os.unlink(pidfile) except Exception as e: pass #Kill heartbeat.py if required. manifest = waagent.GetFileContents('./HandlerManifest.json') try: s=json.loads(manifest) except: waagent.Error('Error parsing HandlerManifest.json. 
Heath report will not be available.') hutil.doExit(name,seqNo,version,0,status_file,heartbeat_file,'Disable','NotReady','0', 'Disable service.py succeeded.' + str(pid) + ' created.', 'Exit Successfull', 'success', '0', 'Enable Completed.','NotReady','0',name+' enabled.') if s[0]['handlerManifest']['reportHeartbeat'] != True : hutil.doExit(name,seqNo,version,0,status_file,heartbeat_file,'Disable','NotReady','0', 'Disable service.py succeeded.' + str(pid) + ' created.', 'Exit Successfull', 'success', '0', 'Enable Completed.','Ready','0',name+' enabled.') try: pid = waagent.GetFileContents('./heartbeat.pid') except: waagent.Error('Error reading ./heartbeat.pid.') hutil.doExit(name,seqNo,version,0,status_file,heartbeat_file,'Disable','NotReady','0', 'Disable service.py succeeded.' + str(pid) + ' created.', 'Exit Successfull', 'success', '0', 'Enable Completed.','NotReady','0',name+' enabled.') if waagent.Run('kill '+pid)==0: waagent.Log(name+" disabled.") hutil.doExit(name,seqNo,version,0,status_file,heartbeat_file,'Disable','NotReady','0', 'Disable service Succeed. 
Health reporting stoppped.', 'Exit Successfull', 'success', '0', 'Disable Completed.','NotReady','0',name+' disabled.') ================================================ FILE: TestHandlerLinux/bin/enable.py ================================================ #!/usr/bin/env python """ Example Azure Handler script for Linux IaaS Enable example """ import os import imp import subprocess import time import json waagent=imp.load_source('waagent','/usr/sbin/waagent') from waagent import LoggerInit hutil=imp.load_source('HandlerUtil','./resources/HandlerUtil.py') LoggerInit('/var/log/waagent.log','/dev/stdout') waagent.Log("enable.py starting.") logfile=waagent.Log name,seqNo,version,config_dir,log_dir,settings_file,status_file,heartbeat_file,config=hutil.doParse(logfile,'Enable') LoggerInit('/var/log/'+name+'_Enable.log','/dev/stdout') waagent.Log(name+" - enable.py starting.") logfile=waagent.Log hutil.doStatusReport(name,seqNo,version,status_file,time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),name,'Enable', 'NotReady', '0', 'Enabling', 'Process Config', 'NotReady', '0', 'Parsing ' + settings_file) pub="" priv = "" # process the config info from public and private config try: pub = config['runtimeSettings'][0]['handlerSettings']['publicSettings'] except: waagent.Error("json threw an exception processing config PublicSettings.") try: priv = config['runtimeSettings'][0]['handlerSettings']['protectedSettings'] except: waagent.Error("json threw an exception processing config protectedSettings.") waagent.Log("PublicConfig =" + repr(pub) ) port=None if len(pub): try: port = pub['port'] except: waagent.Error("json threw an exception processing public setting: port") waagent.Log("ProtectedConfig =" + repr(priv) ) if len(priv): try: port = priv['port'] except: waagent.Error("json threw an exception processing protected setting: port") if port == None: port = "3000" error_string=None if port == None: error_string += "ServicePort 
is empty. " error_string = "Error: " + error_string waagent.Error(error_string) hutil.doExit(name,seqNo,version,1,status_file,heartbeat_file,'Install/Enable','errior','1', 'Install Failed', 'Parse Config', 'error', '1',error_string,'NotReady','1','Exiting') error_string=None waagent.SetFileContents('./resources/service_port.txt',port) error_string='' if port == None: error_string += "ServicePort is empty. " error_string = "Error: " + error_string waagent.Error(error_string) hutil.doExit(name,seqNo,version,1,status_file,heartbeat_file,'Enable','NotReady','1', 'Enable Failed', 'Read service_port.txt', 'NotReady', '1',error_string,'NotReady','1','Exiting') #if already running, kill and spawn new service.py to get current port pid=None pathdir='/usr/sbin' filepath=pathdir+'/service.py' pidfile='./service_pid.txt' if os.path.exists(pidfile): pid=waagent.GetFileContents('./service_pid.txt') try : os.kill(int(pid),7) except Exception as e: pass try: os.unlink(pidfile) except Exception as e: pass time.sleep(3) # wait for the socket to close try: pid = subprocess.Popen(filepath+' -p ' + port,shell=True,cwd=pathdir).pid except Exception as e: waagent.Error('Exception launching ' + filepath + str(e)) if pid == None or pid < 1 : waagent.Error('Error launching ' + filepath + '.') else : waagent.Log("Spawned "+ filepath + " PID = " + str(pid)) waagent.SetFileContents('./service_pid.txt',str(pid)) # report ready waagent.Log(name+" enabled.") hutil.doStatusReport(name,seqNo,version,status_file,time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),name,'Enable','Ready','0', 'Enable service Succeed.', 'Exit Successfull', 'Ready', '0', 'Enable Completed.') #Spawn heartbeat.py if required. manifest = waagent.GetFileContents('./HandlerManifest.json') s=None try: s=json.loads(manifest) except: waagent.Error('Error parsing HandlerManifest.json. 
Health reports will not be available.') hutil.doExit(name,seqNo,version,0,status_file,heartbeat_file,'Enable','Ready','0', 'Enable service Succeed. Health reports will not be available.', 'Exit Successfull', 'success', '0', 'Enable Completed.','Ready','0',name+' enabled.') if s and s[0]['handlerManifest']['reportHeartbeat'] != True : waagent.Log('No heartbeat required. Health reports will not be available.') hutil.doExit(name,seqNo,version,0,status_file,heartbeat_file,'Enable','Ready','0', 'Enable service Succeed. Health reports will not be available.', 'Exit Successfull', 'success', '0', 'Enable Completed.','Ready','0',name+' enabled.') dirpath=os.path.realpath('./') try: pid = subprocess.Popen(dirpath+'/bin/heartbeat.py',shell=True,cwd=dirpath).pid except: waagent.Error('Error launching'+dirpath+'/bin/heartbeat.py! Health reports will not be available.') hutil.doExit(name,seqNo,version,0,status_file,heartbeat_file,'Enable','Ready','0', 'Enable service Succeed. Health reports will not be available.', 'Exit Successfull', 'success', '0', 'Enable Completed.','Ready','0',name+' enabled.') waagent.Log(name+" heartbeat.py started Health reports are available.") hutil.doExit(name,seqNo,version,0,status_file,heartbeat_file,'Enable','Ready','0', 'Enable service Succeed. 
Health reports are available.', 'Exit Successfull', 'success', '0', 'Enable Completed.','Ready','0',name+' enabled.') ================================================ FILE: TestHandlerLinux/bin/heartbeat.py ================================================ #!/usr/bin/env python """ Example Azure Handler script for Linux IaaS Heartbeat example """ import os import imp import time waagent=imp.load_source('waagent','/usr/sbin/waagent') from waagent import LoggerInit hutil=imp.load_source('HandlerUtil','./resources/HandlerUtil.py') LoggerInit('/var/log/waagent.log','/dev/stdout') waagent.Log("hearbeat.py starting.") logfile=waagent.Log name,seqNo,version,config_dir,log_dir,settings_file,status_file,heartbeat_file,config=hutil.doParse(logfile,'Hearbeat') LoggerInit('/var/log/'+name+'_Hearbeat.log','/dev/stdout') waagent.Log(name+" - hearbeat.py starting.") logfile=waagent.Log pid=None pidfile='./service_pid.txt' retries=5 waagent.SetFileContents('./heartbeat.pid',str(os.getpid())) while(True): if os.path.exists(pidfile): pid=waagent.GetFileContents('./service_pid.txt') if waagent.Run("ps --no-headers " + str(pid),chk_err=False) == 0: # running retries=5 waagent.Log(name+" service.py is running with PID="+pid) hutil.doHealthReport(heartbeat_file,'Ready','0','service.py is running.') time.sleep(30) continue else: # died -- retries and wait for 2 min retries-=1 waagent.Error(name+" service.py is Not running.") if retries==4: hutil.doHealthReport(heartbeat_file,'NotRunning','1','ERROR - service.py Unknown or NOT running') if retries!=0: time.sleep(120) else: break else: # dead. report not ready waagent.Error(name+" service.py is Not running.") hutil.doHealthReport(heartbeat_file,'NotReady','1','ERROR - service.py is NOT running') break waagent.Log(name+" heartbeat.py exiting. 
service.py is NOT running") ================================================ FILE: TestHandlerLinux/bin/service.py ================================================ #!/usr/bin/env python import imp """ service example """ resources_dir = 'RESOURCES_PATH' mypydoc=imp.load_source('mypydoc',resources_dir+'/mypydoc.py') mypydoc.cli() ================================================ FILE: TestHandlerLinux/bin/update.py ================================================ #!/usr/bin/env python """ Example Azure Handler script for Linux IaaS Update example Reads port from Public Config if present. Creates service_port.txt in resources dir. Copies the service to /usr/bin and updates it with the resource path. """ import os import sys import imp import time waagent=imp.load_source('waagent','/usr/sbin/waagent') from waagent import LoggerInit hutil=imp.load_source('HandlerUtil','./resources/HandlerUtil.py') LoggerInit('/var/log/waagent.log','/dev/stdout') waagent.Log("update.py starting.") waagent.MyDistro=waagent.GetMyDistro() logfile=waagent.Log name,seqNo,version,config_dir,log_dir,settings_file,status_file,heartbeat_file,config=hutil.doParse(logfile,'Update') LoggerInit('/var/log/'+name+'_Update.log','/dev/stdout') waagent.Log(name+" - update.py starting.") logfile=waagent.Log hutil.doStatusReport(name,seqNo,version,status_file,time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),name, 'Update', 'transitioning', '0', 'Updating', 'Process Config', 'transitioning', '0', 'Parsing ' + settings_file) hutil.doHealthReport(heartbeat_file,'NotReady','0','Proccessing Settings') # capture the config info from previous installation # argv[1] is the path to the previous version. 
waagent.SetFileContents('./resources/service_port.txt',waagent.GetFileContents(sys.argv[1]+'/resources/service_port.txt')) # move the service to sbin waagent.SetFileContents('/usr/sbin/service.py',waagent.GetFileContents('./bin/service.py')) waagent.ReplaceStringInFile('/usr/sbin/service.py','RESOURCES_PATH',os.path.realpath('./resources')) os.chmod('/usr/sbin/service.py',0700) # report ready waagent.Log(name+"updating completed.") hutil.doExit(name,seqNo,version,0,status_file,heartbeat_file,'Update','success','0', 'Update Succeeded.', 'Exit Successfull', 'success', '0', 'Updating Completed.','Ready','0',name+' update completed.') ================================================ FILE: TestHandlerLinux/installer/install.py ================================================ #!/usr/bin/env python """ Example Azure Handler script for Linux IaaS Install example Reads port from Public Config if present. Creates service_port.txt in resources dir. """ import os import imp import time waagent=imp.load_source('waagent','/usr/sbin/waagent') from waagent import LoggerInit hutil=imp.load_source('HandlerUtil','./resources/HandlerUtil.py') LoggerInit('/var/log/waagent.log','/dev/stdout') waagent.Log("install.py starting.") logfile=waagent.Log name,seqNo,version,config_dir,log_dir,settings_file,status_file,heartbeat_file,config=hutil.doParse(logfile,'Install') LoggerInit('/var/log/'+name+'_Install.log','/dev/stdout') waagent.Log(name+" - install.py starting.") logfile=waagent.Log hutil.doStatusReport(name,seqNo,version,status_file,time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),name, 'Install', 'transitioning', '0', 'Installing', 'Process Config', 'transitioning', '0', 'Parsing ' + settings_file) hutil.doHealthReport(heartbeat_file,'NotReady','0','Proccessing Settings') pub="" priv = "" # process the config info from public and private config try: pub = config['runtimeSettings'][0]['handlerSettings']['publicSettings'] except: 
waagent.Error("json threw an exception processing config PublicSettings.") try: priv = config['runtimeSettings'][0]['handlerSettings']['protectedSettings'] except: waagent.Error("json threw an exception processing config protectedSettings.") waagent.Log("PublicConfig =" + repr(pub) ) port=None if len(pub): try: port = pub['port'] except: waagent.Error("json threw an exception processing public setting: port") waagent.Log("ProtectedConfig =" + repr(priv) ) if len(priv): try: port = priv['port'] except: waagent.Error("json threw an exception processing protected setting: port") if port == None: port = "3000" error_string=None if port == None: error_string += "ServicePort is empty. " error_string = "Error: " + error_string waagent.Error(error_string) hutil.doExit(name,seqNo,version,1,status_file,heartbeat_file,'Install/Enable','errior','1', 'Install Failed', 'Parse Config', 'error', '1',error_string,'NotReady','1','Exiting') error_string=None waagent.SetFileContents('./resources/service_port.txt',port) # move the service to sbin waagent.SetFileContents('/usr/sbin/service.py',waagent.GetFileContents('./bin/service.py')) waagent.ReplaceStringInFile('/usr/sbin/service.py','RESOURCES_PATH',os.path.realpath('./resources')) os.chmod('/usr/sbin/service.py',0700) # report ready waagent.Log("HandlerTestLinux installation completed.") hutil.doExit(name,seqNo,version,0,status_file,heartbeat_file,'Install','success','0', 'Install Succeeded.', 'Exit Successfull', 'success', '0', 'Installation Completed.','Ready','0',name+' installation completed.') ================================================ FILE: TestHandlerLinux/installer/uninstall.py ================================================ #!/usr/bin/env python """ Example Azure Handler script for Linux IaaS Diable example """ import os import imp import time waagent=imp.load_source('waagent','/usr/sbin/waagent') from waagent import LoggerInit hutil=imp.load_source('HandlerUtil','./resources/HandlerUtil.py') 
LoggerInit('/var/log/waagent.log','/dev/stdout') waagent.Log("uninstall.py starting.") logfile=waagent.Log name,seqNo,version,config_dir,log_dir,settings_file,status_file,heartbeat_file,config=hutil.doParse(logfile,'Uninstall') waagent.Log(name+" - uninstall.py starting.") logfile=waagent.Log hutil.doStatusReport(name,seqNo,version,status_file,time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),name, 'Uninstall', 'transitioning', '0', 'Uninstalling', 'Process Config', 'transitioning', '0', 'Parsing ' + settings_file) hutil.doHealthReport(heartbeat_file,'NotReady','0','Proccessing Settings') error_string=None servicefile='/usr/sbin/service.py' if not os.path.isfile(servicefile): error_string += servicefile +" is missing." error_string = "Error: " + error_string waagent.Error(error_string) hutil.doExit(name,seqNo,version,1,status_file,heartbeat_file,'Uninstall','error','1', 'Uninstall Failed', 'Remove service.py failed.', 'error', '1',error_string,'NotReady','1','Exiting') # remove os.unlink(servicefile) # report ready waagent.Log(name+" uninstalled.") hutil.doExit(name,seqNo,version,0,status_file,heartbeat_file,'Uninstall','success','0', 'Uninstall service.py Succeeded', 'Exit Successfull', 'success', '0', 'Uninstall Completed.','Ready','0',name+' uninstalled.') ================================================ FILE: TestHandlerLinux/manifest.xml ================================================ <?xml version='1.0' encoding='utf-8' ?> <ExtensionImage xmlns="http://schemas.microsoft.com/windowsazure"> <ProviderNameSpace>Microsoft.OSTCExtensions</ProviderNameSpace> <Type>OSTCTestHandlerLinux</Type> <Version>1.1</Version> <Label>Windows Azure Example Extension Handler for Linux Virtual Machines</Label> <HostingResources>VmRole</HostingResources> <MediaLink></MediaLink> <Description>Windows Azure Example Extension Handler for Linux Virtual Machines</Description> <IsInternalExtension>true</IsInternalExtension> 
<Eula>https://github.com/Azure/azure-linux-extensions/blob/master/LICENSE-2_0.txt</Eula> <PrivacyUri>http://www.microsoft.com/privacystatement/en-us/OnlineServices/Default.aspx</PrivacyUri> <HomepageUri>https://github.com/Azure/azure-linux-extensions</HomepageUri> <IsJsonExtension>true</IsJsonExtension> <SupportedOS>Linux</SupportedOS> <CompanyName>Microsoft</CompanyName> <!--%REGIONS%--> </ExtensionImage> ================================================ FILE: TestHandlerLinux/references ================================================ Utils/ ================================================ FILE: TestHandlerLinux/resources/HandlerUtil.py ================================================ #!/usr/bin/env python """ Handler library for Linux IaaS JSON def: HandlerEnvironment.json [{ "name": "ExampleHandlerLinux", "seqNo": "seqNo", "version": "1.0", "handlerEnvironment": { "logFolder": "<your log folder location>", "configFolder": "<your config folder location>", "statusFolder": "<your status folder location>", "heartbeatFile": "<your heartbeat file location>", } }] { "handlerSettings": { "protectedSettings": { "Password": "UserPassword" }, "publicSettings": { "UserName": "UserName", "Expiration": "Password expiration date in yyy-mm-dd" } } } Example ./config/1.settings "{"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"1BE9A13AA1321C7C515EF109746998BAB6D86FD1","protectedSettings": 
"MIIByAYJKoZIhvcNAQcDoIIBuTCCAbUCAQAxggFxMIIBbQIBADBVMEExPzA9BgoJkiaJk/IsZAEZFi9XaW5kb3dzIEF6dXJlIFNlcnZpY2UgTWFuYWdlbWVudCBmb3IgR+nhc6VHQTQpCiiV2zANBgkqhkiG9w0BAQEFAASCAQCKr09QKMGhwYe+O4/a8td+vpB4eTR+BQso84cV5KCAnD6iUIMcSYTrn9aveY6v6ykRLEw8GRKfri2d6tvVDggUrBqDwIgzejGTlCstcMJItWa8Je8gHZVSDfoN80AEOTws9Fp+wNXAbSuMJNb8EnpkpvigAWU2v6pGLEFvSKC0MCjDTkjpjqciGMcbe/r85RG3Zo21HLl0xNOpjDs/qqikc/ri43Y76E/Xv1vBSHEGMFprPy/Hwo3PqZCnulcbVzNnaXN3qi/kxV897xGMPPC3IrO7Nc++AT9qRLFI0841JLcLTlnoVG1okPzK9w6ttksDQmKBSHt3mfYV+skqs+EOMDsGCSqGSIb3DQEHATAUBggqhkiG9w0DBwQITgu0Nu3iFPuAGD6/QzKdtrnCI5425fIUy7LtpXJGmpWDUA==","publicSettings":{"port":"3000"}}}]}" Example HeartBeat { "version": 1.0, "heartbeat" : { "status": "ready", "code": 0, "Message": "Sample Handler running. Waiting for a new configuration from user." } } Status uses either non-localized 'message' or localized 'formattedMessage' but not both. { "version": 1.0, "timestampUTC": "<current utc time>", "status" : { "name": "<Handler workload name>", "operation": "<name of the operation being performed>", "configurationAppliedTime": "<UTC time indicating when the configuration was last successfully applied>", "status": "<transitioning | error | success | warning>", "code": <Valid integer status code>, "message": { "id": "id of the localized resource", "params": [ "MyParam0", "MyParam1" ] }, "formattedMessage": { "lang": "Lang[-locale]", "message": "formatted user message" } } } """ import os import sys import imp import base64 import json import time # waagent has no '.py' therefore create waagent module import manually. waagent=imp.load_source('waagent','/usr/sbin/waagent') def doParse(Log,operation): handler_env=None config=None ctxt=None code=0 # get the HandlerEnvironment.json. 
it should always be in ./ waagent.Log('cwd is ' + os.path.realpath(os.path.curdir)) handler_env_file='./HandlerEnvironment.json' if not os.path.isfile(handler_env_file): waagent.Error("Unable to locate " + handler_env_file) sys.exit(1) ctxt=waagent.GetFileContents(handler_env_file) if ctxt == None : waagent.Error("Unable to read " + handler_env_file) try: handler_env=json.loads(ctxt) except: pass if handler_env == None : waagent.Error("JSON error processing " + handler_env_file) sys.exit(1) if type(handler_env) == list: handler_env = handler_env[0] # parse the dirs name='NULL' seqNo='0' version='0.0' config_dir='./' log_dir='./' status_dir='./' heartbeat_file='NULL.log' name=handler_env['name'] seqNo=handler_env['seqNo'] version=str(handler_env['version']) config_dir=handler_env['handlerEnvironment']['configFolder'] log_dir=handler_env['handlerEnvironment']['logFolder'] status_dir=handler_env['handlerEnvironment']['statusFolder'] heartbeat_file=handler_env['handlerEnvironment']['heartbeatFile'] # always get the newest settings file code,settings_file=waagent.RunGetOutput('ls -rt ' + config_dir + '/*.settings | tail -1') if code != 0: waagent.Error("Unable to locate a .settings file!") sys.exit(1) settings_file=settings_file[:-1] # get our incarnation # from the number of the .settings file incarnation=os.path.splitext(os.path.basename(settings_file))[0] waagent.Log('Incarnation is ' + incarnation) status_file=status_dir+'/'+incarnation+'.status' waagent.Log("setting file path is" + settings_file) ctxt=None ctxt=waagent.GetFileContents(settings_file) if ctxt == None : waagent.Error('Unable to read ' + settings_file + '. ') doExit(name,seqNo,version,1,status_file,heartbeat_file,operation,'error','1', operation+' Failed', 'Read .settings', 'error', '1','Unable to read ' + settings_file + '. 
','NotReady','1','Exiting') waagent.Log("Read: " + ctxt) # parse json config = None try: config=json.loads(ctxt) except: waagent.Error('JSON exception decoding ' + ctxt) if config == None: waagent.Error("JSON error processing " + settings_file) return (name,seqNo,version,config_dir,log_dir,settings_file,status_file,heartbeat_file,config) # doExit(name,seqNo,version,1,status_file,heartbeat_file,operation,'errior','1', operation + ' Failed', 'Parse Config', 'error', '1', 'JSON error processing ' + settings_file,'NotReady','1','Exiting') # sys.exit(1) print repr(config) if config['runtimeSettings'][0]['handlerSettings'].has_key('protectedSettings') == True: thumb=config['runtimeSettings'][0]['handlerSettings']['protectedSettingsCertThumbprint'] cert=waagent.LibDir+'/'+thumb+'.crt' pkey=waagent.LibDir+'/'+thumb+'.prv' waagent.SetFileContents('/tmp/kk',config['runtimeSettings'][0]['handlerSettings']['protectedSettings']) cleartxt=None cleartxt=waagent.RunGetOutput("base64 -d /tmp/kk | openssl smime -inform DER -decrypt -recip " + cert + " -inkey " + pkey )[1] if cleartxt == None: waagent.Error("OpenSSh decode error using thumbprint " + thumb ) doExit(name,seqNo,version,1,status_file,heartbeat_file,operation,'errior','1', operation + ' Failed', 'Parse Config', 'error', '1', 'OpenSsh decode error using thumbprint ' + thumb,'NotReady','1','Exiting') sys.exit(1) jctxt='' try: jctxt=json.loads(cleartxt) except: waagent.Error('JSON exception decoding ' + cleartxt) config['runtimeSettings'][0]['handlerSettings']['protectedSettings']=jctxt waagent.Log('Config decoded correctly.') return (name,seqNo,version,config_dir,log_dir,settings_file,status_file,heartbeat_file,config) def doStatusReport(name,seqNo,version,stat_file,current_utc, started_at_utc, workload_name, operation_name, status, status_code, status_message, sub_workload_name, sub_status, sub_status_code, sub_status_message): 
#'{"handlerName":"Chef.Bootstrap.WindowsAzure.ChefClient","handlerVersion":"11.12.0.0","status":"NotReady","code":1,"formattedMessage":{"lang":"en-US","message":"Enable command of plugin (name: Chef.Bootstrap.WindowsAzure.ChefClient, version 11.12.0.0) failed with exception Command C:/Packages/Plugins/Chef.Bootstrap.WindowsAzure.ChefClient/11.12.0.0/enable.cmd of Chef.Bootstrap.WindowsAzure.ChefClient has exited with Exit code: 1"}},{"handlerName":"Microsoft.Compute.BGInfo","handlerVersion":"1.1","status":"Ready","formattedMessage":{"lang":"en-US","message":"plugin (name: Microsoft.Compute.BGInfo, version: 1.1) enabled successfully."}}' stat_rept='{"handlerName":"' + name + '","handlerVersion":"'+version+ '","status":"' +status + '","code":' + status_code + ',"formattedMessage":{"lang":"en-US","message":"' + status_message + '"}}' cur_file=stat_file+'_current' with open(cur_file,'w+') as f: f.write(stat_rept) # if inc.status exists, rename the inc.status to inc.status_sent if os.path.exists(stat_file) == True: os.rename(stat_file,stat_file+'_sent') # rename inc.status_current to inc.status os.rename(cur_file,stat_file) # remove inc.status_sent if os.path.exists(stat_file+'_sent') == True: os.unlink(stat_file+'_sent') def doHealthReport(heartbeat_file,status,code,message): # heartbeat health_report='[{"version":"1.0","heartbeat":{"status":"' + status+ '","code":"'+ code + '","Message":"' + message + '"}}]' if waagent.SetFileContents(heartbeat_file,health_report) == None : waagent.Error('Unable to wite heartbeat info to ' + heartbeat_file) def doExit(name,seqNo,version,exit_code,status_file,heartbeat_file,operation,status,code,message,sub_operation,sub_status,sub_code,sub_message,health_state,health_code,health_message): doStatusReport(name,seqNo,version,status_file,time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),name, operation,status,code,message,sub_operation,sub_status,sub_code,sub_message) 
doHealthReport(heartbeat_file,'NotReady','1','Exiting') sys.exit(exit_code) ================================================ FILE: TestHandlerLinux/resources/mypydoc.py ================================================ #! /usr/bin/python2.7 # -*- coding: latin-1 -*- """Generate Python documentation in HTML or text for interactive use. In the Python interpreter, do "from pydoc import help" to provide online help. Calling help(thing) on a Python object documents the object. Or, at the shell command line outside of Python: Run "pydoc <name>" to show documentation on something. <name> may be the name of a function, module, package, or a dotted reference to a class or function within a module or module in a package. If the argument contains a path segment delimiter (e.g. slash on Unix, backslash on Windows) it is treated as the path to a Python source file. Run "pydoc -k <keyword>" to search for a keyword in the synopsis lines of all available modules. Run "pydoc -p <port>" to start an HTTP server on a given port on the local machine to generate documentation web pages. For platforms without a command line, "pydoc -g" starts the HTTP server and also pops up a little window for controlling it. Run "pydoc -w <name>" to write out the HTML documentation for a module to a file named "<name>.html". Module docs for core modules are assumed to be in http://docs.python.org/library/ This can be overridden by setting the PYTHONDOCS environment variable to a different URL or to a local directory containing the Library Reference Manual pages. """ __author__ = "Ka-Ping Yee <ping@lfw.org>" __date__ = "26 February 2001" __version__ = "$Revision: 88564 $" __credits__ = """Guido van Rossum, for an excellent programming language. Tommy Burnette, the original creator of manpy. Paul Prescod, for all his work on onlinehelp. Richard Chamberlain, for the first implementation of textdoc. 
""" # Known bugs that can't be fixed here: # - imp.load_module() cannot be prevented from clobbering existing # loaded modules, so calling synopsis() on a binary module file # changes the contents of any existing module with the same name. # - If the __file__ attribute on a module is a relative path and # the current directory is changed with os.chdir(), an incorrect # path will be displayed. import sys, imp, os, re, types, inspect, __builtin__, pkgutil, warnings from repr import Repr from string import expandtabs, find, join, lower, split, strip, rfind, rstrip from traceback import extract_tb try: from collections import deque except ImportError: # Python 2.3 compatibility class deque(list): def popleft(self): return self.pop(0) # --------------------------------------------------------- common routines def pathdirs(): """Convert sys.path into a list of absolute, existing, unique paths.""" dirs = [] normdirs = [] for dir in sys.path: dir = os.path.abspath(dir or '.') normdir = os.path.normcase(dir) if normdir not in normdirs and os.path.isdir(dir): dirs.append(dir) normdirs.append(normdir) return dirs def getdoc(object): """Get the doc string or comments for an object.""" result = inspect.getdoc(object) or inspect.getcomments(object) return result and re.sub('^ *\n', '', rstrip(result)) or '' def splitdoc(doc): """Split a doc string into a synopsis line (if any) and the rest.""" lines = split(strip(doc), '\n') if len(lines) == 1: return lines[0], '' elif len(lines) >= 2 and not rstrip(lines[1]): return lines[0], join(lines[2:], '\n') return '', join(lines, '\n') def classname(object, modname): """Get a class name and qualify it with a module name if necessary.""" name = object.__name__ if object.__module__ != modname: name = object.__module__ + '.' 
+ name return name def isdata(object): """Check if an object is of a type that probably means it's data.""" return not (inspect.ismodule(object) or inspect.isclass(object) or inspect.isroutine(object) or inspect.isframe(object) or inspect.istraceback(object) or inspect.iscode(object)) def replace(text, *pairs): """Do a series of global replacements on a string.""" while pairs: text = join(split(text, pairs[0]), pairs[1]) pairs = pairs[2:] return text def cram(text, maxlen): """Omit part of a string if needed to make it fit in a maximum length.""" if len(text) > maxlen: pre = max(0, (maxlen-3)//2) post = max(0, maxlen-3-pre) return text[:pre] + '...' + text[len(text)-post:] return text _re_stripid = re.compile(r' at 0x[0-9a-f]{6,16}(>+)$', re.IGNORECASE) def stripid(text): """Remove the hexadecimal id from a Python object representation.""" # The behaviour of %p is implementation-dependent in terms of case. return _re_stripid.sub(r'\1', text) def _is_some_method(obj): return inspect.ismethod(obj) or inspect.ismethoddescriptor(obj) def allmethods(cl): methods = {} for key, value in inspect.getmembers(cl, _is_some_method): methods[key] = 1 for base in cl.__bases__: methods.update(allmethods(base)) # all your base are belong to us for key in methods.keys(): methods[key] = getattr(cl, key) return methods def _split_list(s, predicate): """Split sequence s via predicate, and return pair ([true], [false]). The return value is a 2-tuple of lists, ([x for x in s if predicate(x)], [x for x in s if not predicate(x)]) """ yes = [] no = [] for x in s: if predicate(x): yes.append(x) else: no.append(x) return yes, no def visiblename(name, all=None, obj=None): """Decide whether to show documentation on a variable.""" # Certain special names are redundant. _hidden_names = ('__builtins__', '__doc__', '__file__', '__path__', '__module__', '__name__', '__slots__', '__package__') if name in _hidden_names: return 0 # Private names are hidden, but special names are displayed. 
if name.startswith('__') and name.endswith('__'): return 1 # Namedtuples have public fields and methods with a single leading underscore if name.startswith('_') and hasattr(obj, '_fields'): return 1 if all is not None: # only document that which the programmer exported in __all__ return name in all else: return not name.startswith('_') def classify_class_attrs(object): """Wrap inspect.classify_class_attrs, with fixup for data descriptors.""" def fixup(data): name, kind, cls, value = data if inspect.isdatadescriptor(value): kind = 'data descriptor' return name, kind, cls, value return map(fixup, inspect.classify_class_attrs(object)) # ----------------------------------------------------- module manipulation def ispackage(path): """Guess whether a path refers to a package directory.""" if os.path.isdir(path): for ext in ('.py', '.pyc', '.pyo'): if os.path.isfile(os.path.join(path, '__init__' + ext)): return True return False def source_synopsis(file): line = file.readline() while line[:1] == '#' or not strip(line): line = file.readline() if not line: break line = strip(line) if line[:4] == 'r"""': line = line[1:] if line[:3] == '"""': line = line[3:] if line[-1:] == '\\': line = line[:-1] while not strip(line): line = file.readline() if not line: break result = strip(split(line, '"""')[0]) else: result = None return result def synopsis(filename, cache={}): """Get the one-line summary out of a module file.""" mtime = os.stat(filename).st_mtime lastupdate, result = cache.get(filename, (None, None)) if lastupdate is None or lastupdate < mtime: info = inspect.getmoduleinfo(filename) try: file = open(filename) except IOError: # module can't be opened, so skip it return None if info and 'b' in info[2]: # binary modules have to be imported try: module = imp.load_module('__temp__', file, filename, info[1:]) except: return None result = (module.__doc__ or '').splitlines()[0] del sys.modules['__temp__'] else: # text modules can be directly examined result = 
source_synopsis(file) file.close() cache[filename] = (mtime, result) return result class ErrorDuringImport(Exception): """Errors that occurred while trying to import something to document it.""" def __init__(self, filename, exc_info): exc, value, tb = exc_info self.filename = filename self.exc = exc self.value = value self.tb = tb def __str__(self): exc = self.exc if type(exc) is types.ClassType: exc = exc.__name__ return 'problem in %s - %s: %s' % (self.filename, exc, self.value) def importfile(path): """Import a Python source file or compiled file given its path.""" magic = imp.get_magic() file = open(path, 'r') if file.read(len(magic)) == magic: kind = imp.PY_COMPILED else: kind = imp.PY_SOURCE file.close() filename = os.path.basename(path) name, ext = os.path.splitext(filename) file = open(path, 'r') try: module = imp.load_module(name, file, path, (ext, 'r', kind)) except: raise ErrorDuringImport(path, sys.exc_info()) file.close() return module def safeimport(path, forceload=0, cache={}): """Import a module; handle errors; return None if the module isn't found. If the module *is* found but an exception occurs, it's wrapped in an ErrorDuringImport exception and reraised. Unlike __import__, if a package path is specified, the module at the end of the path is returned, not the package at the beginning. If the optional 'forceload' argument is 1, we reload the module from disk (unless it's a dynamic extension).""" try: # If forceload is 1 and the module has been previously loaded from # disk, we always have to reload the module. Checking the file's # mtime isn't good enough (e.g. the module could contain a class # that inherits from another module that has changed). if forceload and path in sys.modules: if path not in sys.builtin_module_names: # Avoid simply calling reload() because it leaves names in # the currently loaded module lying around if they're not # defined in the new source file. Instead, remove the # module from sys.modules and re-import. 
Also remove any # submodules because they won't appear in the newly loaded # module's namespace if they're already in sys.modules. subs = [m for m in sys.modules if m.startswith(path + '.')] for key in [path] + subs: # Prevent garbage collection. cache[key] = sys.modules[key] del sys.modules[key] module = __import__(path) except: # Did the error occur before or after the module was found? (exc, value, tb) = info = sys.exc_info() if path in sys.modules: # An error occurred while executing the imported module. raise ErrorDuringImport(sys.modules[path].__file__, info) elif exc is SyntaxError: # A SyntaxError occurred before we could execute the module. raise ErrorDuringImport(value.filename, info) elif exc is ImportError and extract_tb(tb)[-1][2]=='safeimport': # The import error occurred directly in this function, # which means there is no such module in the path. return None else: # Some other error occurred during the importing process. raise ErrorDuringImport(path, sys.exc_info()) for part in split(path, '.')[1:]: try: module = getattr(module, part) except AttributeError: return None return module # ---------------------------------------------------- formatter base class class Doc: def document(self, object, name=None, *args): """Generate documentation for an object.""" args = (object, name) + args # 'try' clause is to attempt to handle the possibility that inspect # identifies something in a way that pydoc itself has issues handling; # think 'super' and how it is a descriptor (which raises the exception # by lacking a __name__ attribute) and an instance. 
if inspect.isgetsetdescriptor(object): return self.docdata(*args) if inspect.ismemberdescriptor(object): return self.docdata(*args) try: if inspect.ismodule(object): return self.docmodule(*args) if inspect.isclass(object): return self.docclass(*args) if inspect.isroutine(object): return self.docroutine(*args) except AttributeError: pass if isinstance(object, property): return self.docproperty(*args) return self.docother(*args) def fail(self, object, name=None, *args): """Raise an exception for unimplemented types.""" message = "don't know how to document object%s of type %s" % ( name and ' ' + repr(name), type(object).__name__) raise TypeError, message docmodule = docclass = docroutine = docother = docproperty = docdata = fail def getdocloc(self, object): """Return the location of module docs or None""" try: file = inspect.getabsfile(object) except TypeError: file = '(built-in)' docloc = os.environ.get("PYTHONDOCS", "http://docs.python.org/library") basedir = os.path.join(sys.exec_prefix, "lib", "python"+sys.version[0:3]) if (isinstance(object, type(os)) and (object.__name__ in ('errno', 'exceptions', 'gc', 'imp', 'marshal', 'posix', 'signal', 'sys', 'thread', 'zipimport') or (file.startswith(basedir) and not file.startswith(os.path.join(basedir, 'dist-packages')) and not file.startswith(os.path.join(basedir, 'site-packages')))) and object.__name__ not in ('xml.etree', 'test.pydoc_mod')): if docloc.startswith("http://"): docloc = "%s/%s" % (docloc.rstrip("/"), object.__name__) else: docloc = os.path.join(docloc, object.__name__ + ".html") else: docloc = None return docloc # -------------------------------------------- HTML documentation generator class HTMLRepr(Repr): """Class for safely making an HTML representation of a Python object.""" def __init__(self): Repr.__init__(self) self.maxlist = self.maxtuple = 20 self.maxdict = 10 self.maxstring = self.maxother = 100 def escape(self, text): return replace(text, '&', '&', '<', '<', '>', '>') def repr(self, object): 
return Repr.repr(self, object) def repr1(self, x, level): if hasattr(type(x), '__name__'): methodname = 'repr_' + join(split(type(x).__name__), '_') if hasattr(self, methodname): return getattr(self, methodname)(x, level) return self.escape(cram(stripid(repr(x)), self.maxother)) def repr_string(self, x, level): test = cram(x, self.maxstring) testrepr = repr(test) if '\\' in test and '\\' not in replace(testrepr, r'\\', ''): # Backslashes are only literal in the string and are never # needed to make any special characters, so show a raw string. return 'r' + testrepr[0] + self.escape(test) + testrepr[0] return re.sub(r'((\\[\\abfnrtv\'"]|\\[0-9]..|\\x..|\\u....)+)', r'<font color="#c040c0">\1</font>', self.escape(testrepr)) repr_str = repr_string def repr_instance(self, x, level): try: return self.escape(cram(stripid(repr(x)), self.maxstring)) except: return self.escape('<%s instance>' % x.__class__.__name__) repr_unicode = repr_string class HTMLDoc(Doc): """Formatter class for HTML documentation.""" # ------------------------------------------- HTML formatting utilities _repr_instance = HTMLRepr() repr = _repr_instance.repr escape = _repr_instance.escape def page(self, title, contents): """Format an HTML page.""" return ''' <!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.0 Transitional//EN"> <html><head><title>Python: %s %s ''' % (title, contents) def heading(self, title, fgcol, bgcol, extras=''): """Format a page heading.""" return '''
 
 
%s
%s
''' % (bgcol, fgcol, title, fgcol, extras or ' ') def section(self, title, fgcol, bgcol, contents, width=6, prelude='', marginalia=None, gap=' '): """Format a section with a heading.""" if marginalia is None: marginalia = '' + ' ' * width + '' result = '''

''' % (bgcol, fgcol, title) if prelude: result = result + ''' ''' % (bgcol, marginalia, prelude, gap) else: result = result + ''' ''' % (bgcol, marginalia, gap) return result + '\n
 
%s
%s %s
%s
%s%s%s
' % contents def bigsection(self, title, *args): """Format a section with a big heading.""" title = '%s' % title return self.section(title, *args) def preformat(self, text): """Format literal preformatted text.""" text = self.escape(expandtabs(text)) return replace(text, '\n\n', '\n \n', '\n\n', '\n \n', ' ', ' ', '\n', '
\n') def multicolumn(self, list, format, cols=4): """Format a list of items into a multi-column list.""" result = '' rows = (len(list)+cols-1)//cols for col in range(cols): result = result + '' % (100//cols) for i in range(rows*col, rows*col+rows): if i < len(list): result = result + format(list[i]) + '
\n' result = result + '' return '%s
' % result def grey(self, text): return '%s' % text def namelink(self, name, *dicts): """Make a link for an identifier, given name-to-URL mappings.""" for dict in dicts: if name in dict: return '%s' % (dict[name], name) return name def classlink(self, object, modname): """Make a link for a class.""" name, module = object.__name__, sys.modules.get(object.__module__) if hasattr(module, name) and getattr(module, name) is object: return '%s' % ( module.__name__, name, classname(object, modname)) return classname(object, modname) def modulelink(self, object): """Make a link for a module.""" return '%s' % (object.__name__, object.__name__) def modpkglink(self, data): """Make a link for a module or package to display in an index.""" name, path, ispackage, shadowed = data if shadowed: return self.grey(name) if path: url = '%s.%s.html' % (path, name) else: url = '%s.html' % name if ispackage: text = '%s (package)' % name else: text = name return '%s' % (url, text) def markup(self, text, escape=None, funcs={}, classes={}, methods={}): """Mark up some plain text, given a context of symbols to look for. 
Each context dictionary maps object names to anchor names.""" escape = escape or self.escape results = [] here = 0 pattern = re.compile(r'\b((http|ftp)://\S+[\w/]|' r'RFC[- ]?(\d+)|' r'PEP[- ]?(\d+)|' r'(self\.)?(\w+))') while True: match = pattern.search(text, here) if not match: break start, end = match.span() results.append(escape(text[here:start])) all, scheme, rfc, pep, selfdot, name = match.groups() if scheme: url = escape(all).replace('"', '"') results.append('%s' % (url, url)) elif rfc: url = 'http://www.rfc-editor.org/rfc/rfc%d.txt' % int(rfc) results.append('%s' % (url, escape(all))) elif pep: url = 'http://www.python.org/dev/peps/pep-%04d/' % int(pep) results.append('%s' % (url, escape(all))) elif text[end:end+1] == '(': results.append(self.namelink(name, methods, funcs, classes)) elif selfdot: results.append('self.%s' % name) else: results.append(self.namelink(name, classes)) here = end results.append(escape(text[here:])) return join(results, '') # ---------------------------------------------- type-specific routines def formattree(self, tree, modname, parent=None): """Produce HTML for a class tree as given by inspect.getclasstree().""" result = '' for entry in tree: if type(entry) is type(()): c, bases = entry result = result + '

' result = result + self.classlink(c, modname) if bases and bases != (parent,): parents = [] for base in bases: parents.append(self.classlink(base, modname)) result = result + '(' + join(parents, ', ') + ')' result = result + '\n
' elif type(entry) is type([]): result = result + '
\n%s
\n' % self.formattree( entry, modname, c) return '
\n%s
\n' % result def docmodule(self, object, name=None, mod=None, *ignored): """Produce HTML documentation for a module object.""" name = object.__name__ # ignore the passed-in name try: all = object.__all__ except AttributeError: all = None parts = split(name, '.') links = [] for i in range(len(parts)-1): links.append( '%s' % (join(parts[:i+1], '.'), parts[i])) linkedname = join(links + parts[-1:], '.') head = '%s' % linkedname try: path = inspect.getabsfile(object) url = path if sys.platform == 'win32': import nturl2path url = nturl2path.pathname2url(path) filelink = '%s' % (url, path) except TypeError: filelink = '(built-in)' info = [] if hasattr(object, '__version__'): version = str(object.__version__) if version[:11] == '$' + 'Revision: ' and version[-1:] == '$': version = strip(version[11:-1]) info.append('version %s' % self.escape(version)) if hasattr(object, '__date__'): info.append(self.escape(str(object.__date__))) if info: head = head + ' (%s)' % join(info, ', ') docloc = self.getdocloc(object) if docloc is not None: docloc = '
Module Docs' % locals() else: docloc = '' result = self.heading( head, '#ffffff', '#7799ee', 'index
' + filelink + docloc) modules = inspect.getmembers(object, inspect.ismodule) classes, cdict = [], {} for key, value in inspect.getmembers(object, inspect.isclass): # if __all__ exists, believe it. Otherwise use old heuristic. if (all is not None or (inspect.getmodule(value) or object) is object): if visiblename(key, all, object): classes.append((key, value)) cdict[key] = cdict[value] = '#' + key for key, value in classes: for base in value.__bases__: key, modname = base.__name__, base.__module__ module = sys.modules.get(modname) if modname != name and module and hasattr(module, key): if getattr(module, key) is base: if not key in cdict: cdict[key] = cdict[base] = modname + '.html#' + key funcs, fdict = [], {} for key, value in inspect.getmembers(object, inspect.isroutine): # if __all__ exists, believe it. Otherwise use old heuristic. if (all is not None or inspect.isbuiltin(value) or inspect.getmodule(value) is object): if visiblename(key, all, object): funcs.append((key, value)) fdict[key] = '#-' + key if inspect.isfunction(value): fdict[value] = fdict[key] data = [] for key, value in inspect.getmembers(object, isdata): if visiblename(key, all, object): data.append((key, value)) doc = self.markup(getdoc(object), self.preformat, fdict, cdict) doc = doc and '%s' % doc result = result + '

%s

\n' % doc if hasattr(object, '__path__'): modpkgs = [] for importer, modname, ispkg in pkgutil.iter_modules(object.__path__): modpkgs.append((modname, name, ispkg, 0)) modpkgs.sort() contents = self.multicolumn(modpkgs, self.modpkglink) result = result + self.bigsection( 'Package Contents', '#ffffff', '#aa55cc', contents) elif modules: contents = self.multicolumn( modules, lambda key_value, s=self: s.modulelink(key_value[1])) result = result + self.bigsection( 'Modules', '#ffffff', '#aa55cc', contents) if classes: classlist = map(lambda key_value: key_value[1], classes) contents = [ self.formattree(inspect.getclasstree(classlist, 1), name)] for key, value in classes: contents.append(self.document(value, key, name, fdict, cdict)) result = result + self.bigsection( 'Classes', '#ffffff', '#ee77aa', join(contents)) if funcs: contents = [] for key, value in funcs: contents.append(self.document(value, key, name, fdict, cdict)) result = result + self.bigsection( 'Functions', '#ffffff', '#eeaa77', join(contents)) if data: contents = [] for key, value in data: contents.append(self.document(value, key)) result = result + self.bigsection( 'Data', '#ffffff', '#55aa55', join(contents, '
\n')) if hasattr(object, '__author__'): contents = self.markup(str(object.__author__), self.preformat) result = result + self.bigsection( 'Author', '#ffffff', '#7799ee', contents) if hasattr(object, '__credits__'): contents = self.markup(str(object.__credits__), self.preformat) result = result + self.bigsection( 'Credits', '#ffffff', '#7799ee', contents) return result def docclass(self, object, name=None, mod=None, funcs={}, classes={}, *ignored): """Produce HTML documentation for a class object.""" realname = object.__name__ name = name or realname bases = object.__bases__ contents = [] push = contents.append # Cute little class to pump out a horizontal rule between sections. class HorizontalRule: def __init__(self): self.needone = 0 def maybe(self): if self.needone: push('
\n') self.needone = 1 hr = HorizontalRule() # List the mro, if non-trivial. mro = deque(inspect.getmro(object)) if len(mro) > 2: hr.maybe() push('
Method resolution order:
\n') for base in mro: push('
%s
\n' % self.classlink(base, object.__module__)) push('
\n') def spill(msg, attrs, predicate): ok, attrs = _split_list(attrs, predicate) if ok: hr.maybe() push(msg) for name, kind, homecls, value in ok: try: value = getattr(object, name) except Exception: # Some descriptors may meet a failure in their __get__. # (bug #1785) push(self._docdescriptor(name, value, mod)) else: push(self.document(value, name, mod, funcs, classes, mdict, object)) push('\n') return attrs def spilldescriptors(msg, attrs, predicate): ok, attrs = _split_list(attrs, predicate) if ok: hr.maybe() push(msg) for name, kind, homecls, value in ok: push(self._docdescriptor(name, value, mod)) return attrs def spilldata(msg, attrs, predicate): ok, attrs = _split_list(attrs, predicate) if ok: hr.maybe() push(msg) for name, kind, homecls, value in ok: base = self.docother(getattr(object, name), name, mod) if (hasattr(value, '__call__') or inspect.isdatadescriptor(value)): doc = getattr(value, "__doc__", None) else: doc = None if doc is None: push('
%s
\n' % base) else: doc = self.markup(getdoc(value), self.preformat, funcs, classes, mdict) doc = '
%s' % doc push('
%s%s
\n' % (base, doc)) push('\n') return attrs attrs = filter(lambda data: visiblename(data[0], obj=object), classify_class_attrs(object)) mdict = {} for key, kind, homecls, value in attrs: mdict[key] = anchor = '#' + name + '-' + key try: value = getattr(object, name) except Exception: # Some descriptors may meet a failure in their __get__. # (bug #1785) pass try: # The value may not be hashable (e.g., a data attr with # a dict or list value). mdict[value] = anchor except TypeError: pass while attrs: if mro: thisclass = mro.popleft() else: thisclass = attrs[0][2] attrs, inherited = _split_list(attrs, lambda t: t[2] is thisclass) if thisclass is __builtin__.object: attrs = inherited continue elif thisclass is object: tag = 'defined here' else: tag = 'inherited from %s' % self.classlink(thisclass, object.__module__) tag += ':
\n' # Sort attrs by name. try: attrs.sort(key=lambda t: t[0]) except TypeError: attrs.sort(lambda t1, t2: cmp(t1[0], t2[0])) # 2.3 compat # Pump out the attrs, segregated by kind. attrs = spill('Methods %s' % tag, attrs, lambda t: t[1] == 'method') attrs = spill('Class methods %s' % tag, attrs, lambda t: t[1] == 'class method') attrs = spill('Static methods %s' % tag, attrs, lambda t: t[1] == 'static method') attrs = spilldescriptors('Data descriptors %s' % tag, attrs, lambda t: t[1] == 'data descriptor') attrs = spilldata('Data and other attributes %s' % tag, attrs, lambda t: t[1] == 'data') assert attrs == [] attrs = inherited contents = ''.join(contents) if name == realname: title = 'class %s' % ( name, realname) else: title = '%s = class %s' % ( name, name, realname) if bases: parents = [] for base in bases: parents.append(self.classlink(base, object.__module__)) title = title + '(%s)' % join(parents, ', ') doc = self.markup(getdoc(object), self.preformat, funcs, classes, mdict) doc = doc and '%s
 
' % doc return self.section(title, '#000000', '#ffc8d8', contents, 3, doc) def formatvalue(self, object): """Format an argument default value as text.""" return self.grey('=' + self.repr(object)) def docroutine(self, object, name=None, mod=None, funcs={}, classes={}, methods={}, cl=None): """Produce HTML documentation for a function or method object.""" realname = object.__name__ name = name or realname anchor = (cl and cl.__name__ or '') + '-' + name note = '' skipdocs = 0 if inspect.ismethod(object): imclass = object.im_class if cl: if imclass is not cl: note = ' from ' + self.classlink(imclass, mod) else: if object.im_self is not None: note = ' method of %s instance' % self.classlink( object.im_self.__class__, mod) else: note = ' unbound %s method' % self.classlink(imclass,mod) object = object.im_func if name == realname: title = '%s' % (anchor, realname) else: if (cl and realname in cl.__dict__ and cl.__dict__[realname] is object): reallink = '%s' % ( cl.__name__ + '-' + realname, realname) skipdocs = 1 else: reallink = realname title = '%s = %s' % ( anchor, name, reallink) if inspect.isfunction(object): args, varargs, varkw, defaults = inspect.getargspec(object) argspec = inspect.formatargspec( args, varargs, varkw, defaults, formatvalue=self.formatvalue) if realname == '': title = '%s lambda ' % name argspec = argspec[1:-1] # remove parentheses else: argspec = '(...)' decl = title + argspec + (note and self.grey( '%s' % note)) if skipdocs: return '
%s
\n' % decl else: doc = self.markup( getdoc(object), self.preformat, funcs, classes, methods) doc = doc and '
%s
' % doc return '
%s
%s
\n' % (decl, doc) def _docdescriptor(self, name, value, mod): results = [] push = results.append if name: push('
%s
\n' % name) if value.__doc__ is not None: doc = self.markup(getdoc(value), self.preformat) push('
%s
\n' % doc) push('
\n') return ''.join(results) def docproperty(self, object, name=None, mod=None, cl=None): """Produce html documentation for a property.""" return self._docdescriptor(name, object, mod) def docother(self, object, name=None, mod=None, *ignored): """Produce HTML documentation for a data object.""" lhs = name and '%s = ' % name or '' return lhs + self.repr(object) def docdata(self, object, name=None, mod=None, cl=None): """Produce html documentation for a data descriptor.""" return self._docdescriptor(name, object, mod) def index(self, dir, shadowed=None): """Generate an HTML index for a directory of modules.""" modpkgs = [] if shadowed is None: shadowed = {} for importer, name, ispkg in pkgutil.iter_modules([dir]): modpkgs.append((name, '', ispkg, name in shadowed)) shadowed[name] = 1 modpkgs.sort() contents = self.multicolumn(modpkgs, self.modpkglink) return self.bigsection(dir, '#ffffff', '#ee77aa', contents) # -------------------------------------------- text documentation generator class TextRepr(Repr): """Class for safely making a text representation of a Python object.""" def __init__(self): Repr.__init__(self) self.maxlist = self.maxtuple = 20 self.maxdict = 10 self.maxstring = self.maxother = 100 def repr1(self, x, level): if hasattr(type(x), '__name__'): methodname = 'repr_' + join(split(type(x).__name__), '_') if hasattr(self, methodname): return getattr(self, methodname)(x, level) return cram(stripid(repr(x)), self.maxother) def repr_string(self, x, level): test = cram(x, self.maxstring) testrepr = repr(test) if '\\' in test and '\\' not in replace(testrepr, r'\\', ''): # Backslashes are only literal in the string and are never # needed to make any special characters, so show a raw string. 
return 'r' + testrepr[0] + test + testrepr[0] return testrepr repr_str = repr_string def repr_instance(self, x, level): try: return cram(stripid(repr(x)), self.maxstring) except: return '<%s instance>' % x.__class__.__name__ class TextDoc(Doc): """Formatter class for text documentation.""" # ------------------------------------------- text formatting utilities _repr_instance = TextRepr() repr = _repr_instance.repr def bold(self, text): """Format a string in bold by overstriking.""" return join(map(lambda ch: ch + '\b' + ch, text), '') def indent(self, text, prefix=' '): """Indent text by prepending a given prefix to each line.""" if not text: return '' lines = split(text, '\n') lines = map(lambda line, prefix=prefix: prefix + line, lines) if lines: lines[-1] = rstrip(lines[-1]) return join(lines, '\n') def section(self, title, contents): """Format a section with a given heading.""" return self.bold(title) + '\n' + rstrip(self.indent(contents)) + '\n\n' # ---------------------------------------------- type-specific routines def formattree(self, tree, modname, parent=None, prefix=''): """Render in text a class tree as returned by inspect.getclasstree().""" result = '' for entry in tree: if type(entry) is type(()): c, bases = entry result = result + prefix + classname(c, modname) if bases and bases != (parent,): parents = map(lambda c, m=modname: classname(c, m), bases) result = result + '(%s)' % join(parents, ', ') result = result + '\n' elif type(entry) is type([]): result = result + self.formattree( entry, modname, c, prefix + ' ') return result def docmodule(self, object, name=None, mod=None): """Produce text documentation for a given module object.""" name = object.__name__ # ignore the passed-in name synop, desc = splitdoc(getdoc(object)) result = self.section('NAME', name + (synop and ' - ' + synop)) try: all = object.__all__ except AttributeError: all = None try: file = inspect.getabsfile(object) except TypeError: file = '(built-in)' result = result + 
self.section('FILE', file) docloc = self.getdocloc(object) if docloc is not None: result = result + self.section('MODULE DOCS', docloc) if desc: result = result + self.section('DESCRIPTION', desc) classes = [] for key, value in inspect.getmembers(object, inspect.isclass): # if __all__ exists, believe it. Otherwise use old heuristic. if (all is not None or (inspect.getmodule(value) or object) is object): if visiblename(key, all, object): classes.append((key, value)) funcs = [] for key, value in inspect.getmembers(object, inspect.isroutine): # if __all__ exists, believe it. Otherwise use old heuristic. if (all is not None or inspect.isbuiltin(value) or inspect.getmodule(value) is object): if visiblename(key, all, object): funcs.append((key, value)) data = [] for key, value in inspect.getmembers(object, isdata): if visiblename(key, all, object): data.append((key, value)) modpkgs = [] modpkgs_names = set() if hasattr(object, '__path__'): for importer, modname, ispkg in pkgutil.iter_modules(object.__path__): modpkgs_names.add(modname) if ispkg: modpkgs.append(modname + ' (package)') else: modpkgs.append(modname) modpkgs.sort() result = result + self.section( 'PACKAGE CONTENTS', join(modpkgs, '\n')) # Detect submodules as sometimes created by C extensions submodules = [] for key, value in inspect.getmembers(object, inspect.ismodule): if value.__name__.startswith(name + '.') and key not in modpkgs_names: submodules.append(key) if submodules: submodules.sort() result = result + self.section( 'SUBMODULES', join(submodules, '\n')) if classes: classlist = map(lambda key_value: key_value[1], classes) contents = [self.formattree( inspect.getclasstree(classlist, 1), name)] for key, value in classes: contents.append(self.document(value, key, name)) result = result + self.section('CLASSES', join(contents, '\n')) if funcs: contents = [] for key, value in funcs: contents.append(self.document(value, key, name)) result = result + self.section('FUNCTIONS', join(contents, '\n')) if 
data: contents = [] for key, value in data: contents.append(self.docother(value, key, name, maxlen=70)) result = result + self.section('DATA', join(contents, '\n')) if hasattr(object, '__version__'): version = str(object.__version__) if version[:11] == '$' + 'Revision: ' and version[-1:] == '$': version = strip(version[11:-1]) result = result + self.section('VERSION', version) if hasattr(object, '__date__'): result = result + self.section('DATE', str(object.__date__)) if hasattr(object, '__author__'): result = result + self.section('AUTHOR', str(object.__author__)) if hasattr(object, '__credits__'): result = result + self.section('CREDITS', str(object.__credits__)) return result def docclass(self, object, name=None, mod=None, *ignored): """Produce text documentation for a given class object.""" realname = object.__name__ name = name or realname bases = object.__bases__ def makename(c, m=object.__module__): return classname(c, m) if name == realname: title = 'class ' + self.bold(realname) else: title = self.bold(name) + ' = class ' + realname if bases: parents = map(makename, bases) title = title + '(%s)' % join(parents, ', ') doc = getdoc(object) contents = doc and [doc + '\n'] or [] push = contents.append # List the mro, if non-trivial. mro = deque(inspect.getmro(object)) if len(mro) > 2: push("Method resolution order:") for base in mro: push(' ' + makename(base)) push('') # Cute little class to pump out a horizontal rule between sections. class HorizontalRule: def __init__(self): self.needone = 0 def maybe(self): if self.needone: push('-' * 70) self.needone = 1 hr = HorizontalRule() def spill(msg, attrs, predicate): ok, attrs = _split_list(attrs, predicate) if ok: hr.maybe() push(msg) for name, kind, homecls, value in ok: try: value = getattr(object, name) except Exception: # Some descriptors may meet a failure in their __get__. 
# (bug #1785) push(self._docdescriptor(name, value, mod)) else: push(self.document(value, name, mod, object)) return attrs def spilldescriptors(msg, attrs, predicate): ok, attrs = _split_list(attrs, predicate) if ok: hr.maybe() push(msg) for name, kind, homecls, value in ok: push(self._docdescriptor(name, value, mod)) return attrs def spilldata(msg, attrs, predicate): ok, attrs = _split_list(attrs, predicate) if ok: hr.maybe() push(msg) for name, kind, homecls, value in ok: if (hasattr(value, '__call__') or inspect.isdatadescriptor(value)): doc = getdoc(value) else: doc = None push(self.docother(getattr(object, name), name, mod, maxlen=70, doc=doc) + '\n') return attrs attrs = filter(lambda data: visiblename(data[0], obj=object), classify_class_attrs(object)) while attrs: if mro: thisclass = mro.popleft() else: thisclass = attrs[0][2] attrs, inherited = _split_list(attrs, lambda t: t[2] is thisclass) if thisclass is __builtin__.object: attrs = inherited continue elif thisclass is object: tag = "defined here" else: tag = "inherited from %s" % classname(thisclass, object.__module__) # Sort attrs by name. attrs.sort() # Pump out the attrs, segregated by kind. 
attrs = spill("Methods %s:\n" % tag, attrs, lambda t: t[1] == 'method') attrs = spill("Class methods %s:\n" % tag, attrs, lambda t: t[1] == 'class method') attrs = spill("Static methods %s:\n" % tag, attrs, lambda t: t[1] == 'static method') attrs = spilldescriptors("Data descriptors %s:\n" % tag, attrs, lambda t: t[1] == 'data descriptor') attrs = spilldata("Data and other attributes %s:\n" % tag, attrs, lambda t: t[1] == 'data') assert attrs == [] attrs = inherited contents = '\n'.join(contents) if not contents: return title + '\n' return title + '\n' + self.indent(rstrip(contents), ' | ') + '\n' def formatvalue(self, object): """Format an argument default value as text.""" return '=' + self.repr(object) def docroutine(self, object, name=None, mod=None, cl=None): """Produce text documentation for a function or method object.""" realname = object.__name__ name = name or realname note = '' skipdocs = 0 if inspect.ismethod(object): imclass = object.im_class if cl: if imclass is not cl: note = ' from ' + classname(imclass, mod) else: if object.im_self is not None: note = ' method of %s instance' % classname( object.im_self.__class__, mod) else: note = ' unbound %s method' % classname(imclass,mod) object = object.im_func if name == realname: title = self.bold(realname) else: if (cl and realname in cl.__dict__ and cl.__dict__[realname] is object): skipdocs = 1 title = self.bold(name) + ' = ' + realname if inspect.isfunction(object): args, varargs, varkw, defaults = inspect.getargspec(object) argspec = inspect.formatargspec( args, varargs, varkw, defaults, formatvalue=self.formatvalue) if realname == '': title = self.bold(name) + ' lambda ' argspec = argspec[1:-1] # remove parentheses else: argspec = '(...)' decl = title + argspec + note if skipdocs: return decl + '\n' else: doc = getdoc(object) or '' return decl + '\n' + (doc and rstrip(self.indent(doc)) + '\n') def _docdescriptor(self, name, value, mod): results = [] push = results.append if name: 
push(self.bold(name)) push('\n') doc = getdoc(value) or '' if doc: push(self.indent(doc)) push('\n') return ''.join(results) def docproperty(self, object, name=None, mod=None, cl=None): """Produce text documentation for a property.""" return self._docdescriptor(name, object, mod) def docdata(self, object, name=None, mod=None, cl=None): """Produce text documentation for a data descriptor.""" return self._docdescriptor(name, object, mod) def docother(self, object, name=None, mod=None, parent=None, maxlen=None, doc=None): """Produce text documentation for a data object.""" repr = self.repr(object) if maxlen: line = (name and name + ' = ' or '') + repr chop = maxlen - len(line) if chop < 0: repr = repr[:chop] + '...' line = (name and self.bold(name) + ' = ' or '') + repr if doc is not None: line += '\n' + self.indent(str(doc)) return line # --------------------------------------------------------- user interfaces def pager(text): """The first time this is called, determine what kind of pager to use.""" global pager pager = getpager() pager(text) def getpager(): """Decide what method to use for paging through text.""" if type(sys.stdout) is not types.FileType: return plainpager if not sys.stdin.isatty() or not sys.stdout.isatty(): return plainpager if 'PAGER' in os.environ: if sys.platform == 'win32': # pipes completely broken in Windows return lambda text: tempfilepager(plain(text), os.environ['PAGER']) elif os.environ.get('TERM') in ('dumb', 'emacs'): return lambda text: pipepager(plain(text), os.environ['PAGER']) else: return lambda text: pipepager(text, os.environ['PAGER']) if os.environ.get('TERM') in ('dumb', 'emacs'): return plainpager if sys.platform == 'win32' or sys.platform.startswith('os2'): return lambda text: tempfilepager(plain(text), 'more <') if hasattr(os, 'system') and os.system('(less) 2>/dev/null') == 0: return lambda text: pipepager(text, 'less') import tempfile (fd, filename) = tempfile.mkstemp() os.close(fd) try: if hasattr(os, 'system') and 
os.system('more "%s"' % filename) == 0: return lambda text: pipepager(text, 'more') else: return ttypager finally: os.unlink(filename) def plain(text): """Remove boldface formatting from text.""" return re.sub('.\b', '', text) def pipepager(text, cmd): """Page through text by feeding it to another program.""" pipe = os.popen(cmd, 'w') try: pipe.write(text) pipe.close() except IOError: pass # Ignore broken pipes caused by quitting the pager program. def tempfilepager(text, cmd): """Page through text by invoking a program on a temporary file.""" import tempfile fd, filename = tempfile.mkstemp() file = open(filename, 'w') file.write(text) file.close() try: os.system(cmd + ' "' + filename + '"') finally: os.close(fd) def ttypager(text): """Page through text on a text terminal.""" lines = split(plain(text), '\n') try: import tty fd = sys.stdin.fileno() old = tty.tcgetattr(fd) tty.setcbreak(fd) getchar = lambda: sys.stdin.read(1) except (ImportError, AttributeError): tty = None getchar = lambda: sys.stdin.readline()[:-1][:1] try: r = inc = os.environ.get('LINES', 25) - 1 sys.stdout.write(join(lines[:inc], '\n') + '\n') while lines[r:]: sys.stdout.write('-- more --') sys.stdout.flush() c = getchar() if c in ('q', 'Q'): sys.stdout.write('\r \r') break elif c in ('\r', '\n'): sys.stdout.write('\r \r' + lines[r] + '\n') r = r + 1 continue if c in ('b', 'B', '\x1b'): r = r - inc - inc if r < 0: r = 0 sys.stdout.write('\n' + join(lines[r:r+inc], '\n') + '\n') r = r + inc finally: if tty: tty.tcsetattr(fd, tty.TCSAFLUSH, old) def plainpager(text): """Simply print unformatted text. 
This is the ultimate fallback."""
    sys.stdout.write(plain(text))

def describe(thing):
    """Produce a short description of the given thing."""
    if inspect.ismodule(thing):
        if thing.__name__ in sys.builtin_module_names:
            return 'built-in module ' + thing.__name__
        if hasattr(thing, '__path__'):
            return 'package ' + thing.__name__
        else:
            return 'module ' + thing.__name__
    if inspect.isbuiltin(thing):
        return 'built-in function ' + thing.__name__
    if inspect.isgetsetdescriptor(thing):
        return 'getset descriptor %s.%s.%s' % (
            thing.__objclass__.__module__, thing.__objclass__.__name__,
            thing.__name__)
    if inspect.ismemberdescriptor(thing):
        return 'member descriptor %s.%s.%s' % (
            thing.__objclass__.__module__, thing.__objclass__.__name__,
            thing.__name__)
    if inspect.isclass(thing):
        return 'class ' + thing.__name__
    if inspect.isfunction(thing):
        return 'function ' + thing.__name__
    if inspect.ismethod(thing):
        return 'method ' + thing.__name__
    if type(thing) is types.InstanceType:
        return 'instance of ' + thing.__class__.__name__
    return type(thing).__name__

def locate(path, forceload=0):
    """Locate an object by name or dotted path, importing as necessary."""
    parts = [part for part in split(path, '.') if part]
    # Import the longest importable prefix of the dotted path...
    module, n = None, 0
    while n < len(parts):
        nextmodule = safeimport(join(parts[:n+1], '.'), forceload)
        if nextmodule: module, n = nextmodule, n + 1
        else: break
    if module:
        object = module
    else:
        object = __builtin__
    # ...then resolve the remaining parts as attributes.
    for part in parts[n:]:
        try: object = getattr(object, part)
        except AttributeError: return None
    return object

# --------------------------------------- interactive interpreter interface

# Module-level formatter singletons shared by the top-level helpers below.
text = TextDoc()
html = HTMLDoc()

class _OldStyleClass: pass
# Sentinel type used to detect old-style (classic) class instances.
_OLD_INSTANCE_TYPE = type(_OldStyleClass())

def resolve(thing, forceload=0):
    """Given an object or a path to an object, get the object and its name."""
    if isinstance(thing, str):
        object = locate(thing, forceload)
        if not object:
            raise ImportError, 'no Python documentation found for %r' % thing
        return object, thing
    else:
        name = getattr(thing, '__name__', None)
        return thing, name if isinstance(name, str) else None

def render_doc(thing, title='Python Library Documentation: %s', forceload=0):
    """Render text documentation, given an object or a path to an object."""
    object, name = resolve(thing, forceload)
    desc = describe(object)
    module = inspect.getmodule(object)
    if name and '.' in name:
        desc += ' in ' + name[:name.rfind('.')]
    elif module and module is not object:
        desc += ' in module ' + module.__name__
    if type(object) is _OLD_INSTANCE_TYPE:
        # If the passed object is an instance of an old-style class,
        # document its available methods instead of its value.
        object = object.__class__
    elif not (inspect.ismodule(object) or
              inspect.isclass(object) or
              inspect.isroutine(object) or
              inspect.isgetsetdescriptor(object) or
              inspect.ismemberdescriptor(object) or
              isinstance(object, property)):
        # If the passed object is a piece of data or an instance,
        # document its available methods instead of its value.
        object = type(object)
        desc += ' object'
    return title % desc + '\n\n' + text.document(object, name)

def doc(thing, title='Python Library Documentation: %s', forceload=0):
    """Display text documentation, given an object or a path to an object."""
    try:
        pager(render_doc(thing, title, forceload))
    except (ImportError, ErrorDuringImport), value:
        print value

def writedoc(thing, forceload=0):
    """Write HTML documentation to a file in the current directory."""
    try:
        object, name = resolve(thing, forceload)
        page = html.page(describe(object), html.document(object, name))
        file = open(name + '.html', 'w')
        file.write(page)
        file.close()
        print 'wrote', name + '.html'
    except (ImportError, ErrorDuringImport), value:
        print value

def writedocs(dir, pkgpath='', done=None):
    """Write out HTML documentation for all modules in a directory tree."""
    if done is None: done = {}
    for importer, modname, ispkg in pkgutil.walk_packages([dir], pkgpath):
        writedoc(modname)
    return

class Helper:
    # These dictionaries map a topic name to either an
alias, or a tuple # (label, seealso-items). The "label" is the label of the corresponding # section in the .rst file under Doc/ and an index into the dictionary # in pydoc_data/topics.py. # # CAUTION: if you change one of these dictionaries, be sure to adapt the # list of needed labels in Doc/tools/sphinxext/pyspecific.py and # regenerate the pydoc_data/topics.py file by running # make pydoc-topics # in Doc/ and copying the output file into the Lib/ directory. keywords = { 'and': 'BOOLEAN', 'as': 'with', 'assert': ('assert', ''), 'break': ('break', 'while for'), 'class': ('class', 'CLASSES SPECIALMETHODS'), 'continue': ('continue', 'while for'), 'def': ('function', ''), 'del': ('del', 'BASICMETHODS'), 'elif': 'if', 'else': ('else', 'while for'), 'except': 'try', 'exec': ('exec', ''), 'finally': 'try', 'for': ('for', 'break continue while'), 'from': 'import', 'global': ('global', 'NAMESPACES'), 'if': ('if', 'TRUTHVALUE'), 'import': ('import', 'MODULES'), 'in': ('in', 'SEQUENCEMETHODS2'), 'is': 'COMPARISON', 'lambda': ('lambda', 'FUNCTIONS'), 'not': 'BOOLEAN', 'or': 'BOOLEAN', 'pass': ('pass', ''), 'print': ('print', ''), 'raise': ('raise', 'EXCEPTIONS'), 'return': ('return', 'FUNCTIONS'), 'try': ('try', 'EXCEPTIONS'), 'while': ('while', 'break continue if TRUTHVALUE'), 'with': ('with', 'CONTEXTMANAGERS EXCEPTIONS yield'), 'yield': ('yield', ''), } # Either add symbols to this dictionary or to the symbols dictionary # directly: Whichever is easier. They are merged later. 
_symbols_inverse = { 'STRINGS' : ("'", "'''", "r'", "u'", '"""', '"', 'r"', 'u"'), 'OPERATORS' : ('+', '-', '*', '**', '/', '//', '%', '<<', '>>', '&', '|', '^', '~', '<', '>', '<=', '>=', '==', '!=', '<>'), 'COMPARISON' : ('<', '>', '<=', '>=', '==', '!=', '<>'), 'UNARY' : ('-', '~'), 'AUGMENTEDASSIGNMENT' : ('+=', '-=', '*=', '/=', '%=', '&=', '|=', '^=', '<<=', '>>=', '**=', '//='), 'BITWISE' : ('<<', '>>', '&', '|', '^', '~'), 'COMPLEX' : ('j', 'J') } symbols = { '%': 'OPERATORS FORMATTING', '**': 'POWER', ',': 'TUPLES LISTS FUNCTIONS', '.': 'ATTRIBUTES FLOAT MODULES OBJECTS', '...': 'ELLIPSIS', ':': 'SLICINGS DICTIONARYLITERALS', '@': 'def class', '\\': 'STRINGS', '_': 'PRIVATENAMES', '__': 'PRIVATENAMES SPECIALMETHODS', '`': 'BACKQUOTES', '(': 'TUPLES FUNCTIONS CALLS', ')': 'TUPLES FUNCTIONS CALLS', '[': 'LISTS SUBSCRIPTS SLICINGS', ']': 'LISTS SUBSCRIPTS SLICINGS' } for topic, symbols_ in _symbols_inverse.iteritems(): for symbol in symbols_: topics = symbols.get(symbol, topic) if topic not in topics: topics = topics + ' ' + topic symbols[symbol] = topics topics = { 'TYPES': ('types', 'STRINGS UNICODE NUMBERS SEQUENCES MAPPINGS ' 'FUNCTIONS CLASSES MODULES FILES inspect'), 'STRINGS': ('strings', 'str UNICODE SEQUENCES STRINGMETHODS FORMATTING ' 'TYPES'), 'STRINGMETHODS': ('string-methods', 'STRINGS FORMATTING'), 'FORMATTING': ('formatstrings', 'OPERATORS'), 'UNICODE': ('strings', 'encodings unicode SEQUENCES STRINGMETHODS ' 'FORMATTING TYPES'), 'NUMBERS': ('numbers', 'INTEGER FLOAT COMPLEX TYPES'), 'INTEGER': ('integers', 'int range'), 'FLOAT': ('floating', 'float math'), 'COMPLEX': ('imaginary', 'complex cmath'), 'SEQUENCES': ('typesseq', 'STRINGMETHODS FORMATTING xrange LISTS'), 'MAPPINGS': 'DICTIONARIES', 'FUNCTIONS': ('typesfunctions', 'def TYPES'), 'METHODS': ('typesmethods', 'class def CLASSES TYPES'), 'CODEOBJECTS': ('bltin-code-objects', 'compile FUNCTIONS TYPES'), 'TYPEOBJECTS': ('bltin-type-objects', 'types TYPES'), 'FRAMEOBJECTS': 'TYPES', 
'TRACEBACKS': 'TYPES', 'NONE': ('bltin-null-object', ''), 'ELLIPSIS': ('bltin-ellipsis-object', 'SLICINGS'), 'FILES': ('bltin-file-objects', ''), 'SPECIALATTRIBUTES': ('specialattrs', ''), 'CLASSES': ('types', 'class SPECIALMETHODS PRIVATENAMES'), 'MODULES': ('typesmodules', 'import'), 'PACKAGES': 'import', 'EXPRESSIONS': ('operator-summary', 'lambda or and not in is BOOLEAN ' 'COMPARISON BITWISE SHIFTING BINARY FORMATTING POWER ' 'UNARY ATTRIBUTES SUBSCRIPTS SLICINGS CALLS TUPLES ' 'LISTS DICTIONARIES BACKQUOTES'), 'OPERATORS': 'EXPRESSIONS', 'PRECEDENCE': 'EXPRESSIONS', 'OBJECTS': ('objects', 'TYPES'), 'SPECIALMETHODS': ('specialnames', 'BASICMETHODS ATTRIBUTEMETHODS ' 'CALLABLEMETHODS SEQUENCEMETHODS1 MAPPINGMETHODS ' 'SEQUENCEMETHODS2 NUMBERMETHODS CLASSES'), 'BASICMETHODS': ('customization', 'cmp hash repr str SPECIALMETHODS'), 'ATTRIBUTEMETHODS': ('attribute-access', 'ATTRIBUTES SPECIALMETHODS'), 'CALLABLEMETHODS': ('callable-types', 'CALLS SPECIALMETHODS'), 'SEQUENCEMETHODS1': ('sequence-types', 'SEQUENCES SEQUENCEMETHODS2 ' 'SPECIALMETHODS'), 'SEQUENCEMETHODS2': ('sequence-methods', 'SEQUENCES SEQUENCEMETHODS1 ' 'SPECIALMETHODS'), 'MAPPINGMETHODS': ('sequence-types', 'MAPPINGS SPECIALMETHODS'), 'NUMBERMETHODS': ('numeric-types', 'NUMBERS AUGMENTEDASSIGNMENT ' 'SPECIALMETHODS'), 'EXECUTION': ('execmodel', 'NAMESPACES DYNAMICFEATURES EXCEPTIONS'), 'NAMESPACES': ('naming', 'global ASSIGNMENT DELETION DYNAMICFEATURES'), 'DYNAMICFEATURES': ('dynamic-features', ''), 'SCOPING': 'NAMESPACES', 'FRAMES': 'NAMESPACES', 'EXCEPTIONS': ('exceptions', 'try except finally raise'), 'COERCIONS': ('coercion-rules','CONVERSIONS'), 'CONVERSIONS': ('conversions', 'COERCIONS'), 'IDENTIFIERS': ('identifiers', 'keywords SPECIALIDENTIFIERS'), 'SPECIALIDENTIFIERS': ('id-classes', ''), 'PRIVATENAMES': ('atom-identifiers', ''), 'LITERALS': ('atom-literals', 'STRINGS BACKQUOTES NUMBERS ' 'TUPLELITERALS LISTLITERALS DICTIONARYLITERALS'), 'TUPLES': 'SEQUENCES', 'TUPLELITERALS': 
('exprlists', 'TUPLES LITERALS'), 'LISTS': ('typesseq-mutable', 'LISTLITERALS'), 'LISTLITERALS': ('lists', 'LISTS LITERALS'), 'DICTIONARIES': ('typesmapping', 'DICTIONARYLITERALS'), 'DICTIONARYLITERALS': ('dict', 'DICTIONARIES LITERALS'), 'BACKQUOTES': ('string-conversions', 'repr str STRINGS LITERALS'), 'ATTRIBUTES': ('attribute-references', 'getattr hasattr setattr ' 'ATTRIBUTEMETHODS'), 'SUBSCRIPTS': ('subscriptions', 'SEQUENCEMETHODS1'), 'SLICINGS': ('slicings', 'SEQUENCEMETHODS2'), 'CALLS': ('calls', 'EXPRESSIONS'), 'POWER': ('power', 'EXPRESSIONS'), 'UNARY': ('unary', 'EXPRESSIONS'), 'BINARY': ('binary', 'EXPRESSIONS'), 'SHIFTING': ('shifting', 'EXPRESSIONS'), 'BITWISE': ('bitwise', 'EXPRESSIONS'), 'COMPARISON': ('comparisons', 'EXPRESSIONS BASICMETHODS'), 'BOOLEAN': ('booleans', 'EXPRESSIONS TRUTHVALUE'), 'ASSERTION': 'assert', 'ASSIGNMENT': ('assignment', 'AUGMENTEDASSIGNMENT'), 'AUGMENTEDASSIGNMENT': ('augassign', 'NUMBERMETHODS'), 'DELETION': 'del', 'PRINTING': 'print', 'RETURNING': 'return', 'IMPORTING': 'import', 'CONDITIONAL': 'if', 'LOOPING': ('compound', 'for while break continue'), 'TRUTHVALUE': ('truth', 'if while and or not BASICMETHODS'), 'DEBUGGING': ('debugger', 'pdb'), 'CONTEXTMANAGERS': ('context-managers', 'with'), } def __init__(self, input=None, output=None): self._input = input self._output = output input = property(lambda self: self._input or sys.stdin) output = property(lambda self: self._output or sys.stdout) def __repr__(self): if inspect.stack()[1][3] == '?': self() return '' return '' _GoInteractive = object() def __call__(self, request=_GoInteractive): if request is not self._GoInteractive: self.help(request) else: self.intro() self.interact() self.output.write(''' You are now leaving help and returning to the Python interpreter. If you want to ask for help on a particular object directly from the interpreter, you can type "help(object)". 
Executing "help('string')" has the same effect as typing a particular string at the help> prompt. ''') def interact(self): self.output.write('\n') while True: try: request = self.getline('help> ') if not request: break except (KeyboardInterrupt, EOFError): break request = strip(replace(request, '"', '', "'", '')) if lower(request) in ('q', 'quit'): break self.help(request) def getline(self, prompt): """Read one line, using raw_input when available.""" if self.input is sys.stdin: return raw_input(prompt) else: self.output.write(prompt) self.output.flush() return self.input.readline() def help(self, request): if type(request) is type(''): request = request.strip() if request == 'help': self.intro() elif request == 'keywords': self.listkeywords() elif request == 'symbols': self.listsymbols() elif request == 'topics': self.listtopics() elif request == 'modules': self.listmodules() elif request[:8] == 'modules ': self.listmodules(split(request)[1]) elif request in self.symbols: self.showsymbol(request) elif request in self.keywords: self.showtopic(request) elif request in self.topics: self.showtopic(request) elif request: doc(request, 'Help on %s:') elif isinstance(request, Helper): self() else: doc(request, 'Help on %s:') self.output.write('\n') def intro(self): self.output.write(''' Welcome to Python %s! This is the online help utility. If this is your first time using Python, you should definitely check out the tutorial on the Internet at http://docs.python.org/%s/tutorial/. Enter the name of any module, keyword, or topic to get help on writing Python programs and using Python modules. To quit this help utility and return to the interpreter, just type "quit". To get a list of available modules, keywords, or topics, type "modules", "keywords", or "topics". Each module also comes with a one-line summary of what it does; to list the modules whose summaries contain a given word such as "spam", type "modules spam". 
''' % tuple([sys.version[:3]]*2)) def list(self, items, columns=4, width=80): items = items[:] items.sort() colw = width / columns rows = (len(items) + columns - 1) / columns for row in range(rows): for col in range(columns): i = col * rows + row if i < len(items): self.output.write(items[i]) if col < columns - 1: self.output.write(' ' + ' ' * (colw-1 - len(items[i]))) self.output.write('\n') def listkeywords(self): self.output.write(''' Here is a list of the Python keywords. Enter any keyword to get more help. ''') self.list(self.keywords.keys()) def listsymbols(self): self.output.write(''' Here is a list of the punctuation symbols which Python assigns special meaning to. Enter any symbol to get more help. ''') self.list(self.symbols.keys()) def listtopics(self): self.output.write(''' Here is a list of available topics. Enter any topic name to get more help. ''') self.list(self.topics.keys()) def showtopic(self, topic, more_xrefs=''): try: import pydoc_data.topics except ImportError: self.output.write(''' Sorry, topic and keyword documentation is not available because the module "pydoc_data.topics" could not be found. 
''') return target = self.topics.get(topic, self.keywords.get(topic)) if not target: self.output.write('no documentation found for %s\n' % repr(topic)) return if type(target) is type(''): return self.showtopic(target, more_xrefs) label, xrefs = target try: doc = pydoc_data.topics.topics[label] except KeyError: self.output.write('no documentation found for %s\n' % repr(topic)) return pager(strip(doc) + '\n') if more_xrefs: xrefs = (xrefs or '') + ' ' + more_xrefs if xrefs: import StringIO, formatter buffer = StringIO.StringIO() formatter.DumbWriter(buffer).send_flowing_data( 'Related help topics: ' + join(split(xrefs), ', ') + '\n') self.output.write('\n%s\n' % buffer.getvalue()) def showsymbol(self, symbol): target = self.symbols[symbol] topic, _, xrefs = target.partition(' ') self.showtopic(topic, xrefs) def listmodules(self, key=''): if key: self.output.write(''' Here is a list of matching modules. Enter any module name to get more help. ''') apropos(key) else: self.output.write(''' Please wait a moment while I gather a list of all available modules... ''') modules = {} def callback(path, modname, desc, modules=modules): if modname and modname[-9:] == '.__init__': modname = modname[:-9] + ' (package)' if find(modname, '.') < 0: modules[modname] = 1 def onerror(modname): callback(None, modname, None) ModuleScanner().run(callback, onerror=onerror) self.list(modules.keys()) self.output.write(''' Enter any module name to get more help. Or, type "modules spam" to search for modules whose descriptions contain the word "spam". 
''')

# Shared interactive-help singleton used by the interpreter's help().
help = Helper()

class Scanner:
    """A generic tree iterator."""
    def __init__(self, roots, children, descendp):
        # Copy roots so iteration does not mutate the caller's list.
        self.roots = roots[:]
        self.state = []
        self.children = children
        self.descendp = descendp

    def next(self):
        if not self.state:
            if not self.roots:
                return None
            root = self.roots.pop(0)
            self.state = [(root, self.children(root))]
        node, children = self.state[-1]
        if not children:
            # Current node exhausted; pop it and continue with its parent.
            self.state.pop()
            return self.next()
        child = children.pop(0)
        if self.descendp(child):
            self.state.append((child, self.children(child)))
        return child


class ModuleScanner:
    """An interruptible scanner that searches module synopses."""

    def run(self, callback, key=None, completer=None, onerror=None):
        if key: key = lower(key)
        # Setting self.quit to True from another thread stops the scan.
        self.quit = False
        seen = {}

        # First pass: built-in modules (no importer/loader involved).
        for modname in sys.builtin_module_names:
            if modname != '__main__':
                seen[modname] = 1
                if key is None:
                    callback(None, modname, '')
                else:
                    desc = split(__import__(modname).__doc__ or '', '\n')[0]
                    if find(lower(modname + ' - ' + desc), key) >= 0:
                        callback(None, modname, desc)

        # Second pass: everything reachable on sys.path.
        for importer, modname, ispkg in pkgutil.walk_packages(onerror=onerror):
            if self.quit:
                break
            if key is None:
                callback(None, modname, '')
            else:
                loader = importer.find_module(modname)
                if hasattr(loader,'get_source'):
                    # Prefer reading the synopsis from source, avoiding import.
                    import StringIO
                    desc = source_synopsis(
                        StringIO.StringIO(loader.get_source(modname))
                    ) or ''
                    if hasattr(loader,'get_filename'):
                        path = loader.get_filename(modname)
                    else:
                        path = None
                else:
                    # Fall back to importing the module for its docstring.
                    module = loader.load_module(modname)
                    # NOTE(review): ''.splitlines() is [], so a module whose
                    # __doc__ is None/empty looks like it would raise
                    # IndexError here — confirm against upstream pydoc.
                    desc = (module.__doc__ or '').splitlines()[0]
                    path = getattr(module,'__file__',None)
                if find(lower(modname + ' - ' + desc), key) >= 0:
                    callback(path, modname, desc)

        if completer: completer()

def apropos(key):
    """Print all the one-line module summaries that contain a substring."""
    def callback(path, modname, desc):
        if modname[-9:] == '.__init__':
            modname = modname[:-9] + ' (package)'
        print modname, desc and '- ' + desc
    def onerror(modname):
        pass
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore') # ignore problems during import
ModuleScanner().run(callback, key, onerror=onerror) # --------------------------------------------------- web browser interface def serve(port, callback=None, completer=None): import BaseHTTPServer, mimetools, select # Patch up mimetools.Message so it doesn't break if rfc822 is reloaded. class Message(mimetools.Message): def __init__(self, fp, seekable=1): Message = self.__class__ Message.__bases__[0].__bases__[0].__init__(self, fp, seekable) self.encodingheader = self.getheader('content-transfer-encoding') self.typeheader = self.getheader('content-type') self.parsetype() self.parseplist() class DocHandler(BaseHTTPServer.BaseHTTPRequestHandler): def send_document(self, title, contents): try: self.send_response(200) self.send_header('Content-Type', 'text/html') self.end_headers() self.wfile.write(html.page(title, contents)) except IOError: pass def do_GET(self): path = self.path if path[-5:] == '.html': path = path[:-5] if path[:1] == '/': path = path[1:] if path and path != '.': try: obj = locate(path, forceload=1) except ErrorDuringImport, value: self.send_document(path, html.escape(str(value))) return if obj: self.send_document(describe(obj), html.document(obj, path)) else: self.send_document(path, 'no Python documentation found for %s' % repr(path)) else: heading = html.heading( 'Python: Index of Modules', '#ffffff', '#7799ee') def bltinlink(name): return '%s' % (name, name) names = filter(lambda x: x != '__main__', sys.builtin_module_names) contents = html.multicolumn(names, bltinlink) indices = ['

' + html.bigsection( 'Built-in Modules', '#ffffff', '#ee77aa', contents)] seen = {} for dir in sys.path: indices.append(html.index(dir, seen)) contents = heading + join(indices) + '''

pydoc by Ka-Ping Yee <ping@lfw.org>''' self.send_document('Index of Modules', contents) def log_message(self, *args): pass class DocServer(BaseHTTPServer.HTTPServer): def __init__(self, port, callback): host = '' self.address = (host, port) self.url = 'http://%s:%d/' % (host, port) self.callback = callback self.base.__init__(self, self.address, self.handler) def serve_until_quit(self): import select self.quit = False while not self.quit: rd, wr, ex = select.select([self.socket.fileno()], [], [], 1) if rd: self.handle_request() def server_activate(self): self.base.server_activate(self) if self.callback: self.callback(self) DocServer.base = BaseHTTPServer.HTTPServer DocServer.handler = DocHandler DocHandler.MessageClass = Message try: try: DocServer(port, callback).serve_until_quit() except (KeyboardInterrupt, select.error): pass finally: if completer: completer() # ----------------------------------------------------- graphical interface def gui(): """Graphical interface (starts web server and pops up a control window).""" class GUI: def __init__(self, window, port=7464): self.window = window self.server = None self.scanner = None import Tkinter self.server_frm = Tkinter.Frame(window) self.title_lbl = Tkinter.Label(self.server_frm, text='Starting server...\n ') self.open_btn = Tkinter.Button(self.server_frm, text='open browser', command=self.open, state='disabled') self.quit_btn = Tkinter.Button(self.server_frm, text='quit serving', command=self.quit, state='disabled') self.search_frm = Tkinter.Frame(window) self.search_lbl = Tkinter.Label(self.search_frm, text='Search for') self.search_ent = Tkinter.Entry(self.search_frm) self.search_ent.bind('', self.search) self.stop_btn = Tkinter.Button(self.search_frm, text='stop', pady=0, command=self.stop, state='disabled') if sys.platform == 'win32': # Trying to hide and show this button crashes under Windows. 
self.stop_btn.pack(side='right') self.window.title('pydoc') self.window.protocol('WM_DELETE_WINDOW', self.quit) self.title_lbl.pack(side='top', fill='x') self.open_btn.pack(side='left', fill='x', expand=1) self.quit_btn.pack(side='right', fill='x', expand=1) self.server_frm.pack(side='top', fill='x') self.search_lbl.pack(side='left') self.search_ent.pack(side='right', fill='x', expand=1) self.search_frm.pack(side='top', fill='x') self.search_ent.focus_set() font = ('helvetica', sys.platform == 'win32' and 8 or 10) self.result_lst = Tkinter.Listbox(window, font=font, height=6) self.result_lst.bind('', self.select) self.result_lst.bind('', self.goto) self.result_scr = Tkinter.Scrollbar(window, orient='vertical', command=self.result_lst.yview) self.result_lst.config(yscrollcommand=self.result_scr.set) self.result_frm = Tkinter.Frame(window) self.goto_btn = Tkinter.Button(self.result_frm, text='go to selected', command=self.goto) self.hide_btn = Tkinter.Button(self.result_frm, text='hide results', command=self.hide) self.goto_btn.pack(side='left', fill='x', expand=1) self.hide_btn.pack(side='right', fill='x', expand=1) self.window.update() self.minwidth = self.window.winfo_width() self.minheight = self.window.winfo_height() self.bigminheight = (self.server_frm.winfo_reqheight() + self.search_frm.winfo_reqheight() + self.result_lst.winfo_reqheight() + self.result_frm.winfo_reqheight()) self.bigwidth, self.bigheight = self.minwidth, self.bigminheight self.expanded = 0 self.window.wm_geometry('%dx%d' % (self.minwidth, self.minheight)) self.window.wm_minsize(self.minwidth, self.minheight) self.window.tk.willdispatch() import threading threading.Thread( target=serve, args=(port, self.ready, self.quit)).start() def ready(self, server): self.server = server self.title_lbl.config( text='Python documentation server at\n' + server.url) self.open_btn.config(state='normal') self.quit_btn.config(state='normal') def open(self, event=None, url=None): url = url or self.server.url try: 
import webbrowser webbrowser.open(url) except ImportError: # pre-webbrowser.py compatibility if sys.platform == 'win32': os.system('start "%s"' % url) else: rc = os.system('netscape -remote "openURL(%s)" &' % url) if rc: os.system('netscape "%s" &' % url) def quit(self, event=None): if self.server: self.server.quit = 1 self.window.quit() def search(self, event=None): key = self.search_ent.get() self.stop_btn.pack(side='right') self.stop_btn.config(state='normal') self.search_lbl.config(text='Searching for "%s"...' % key) self.search_ent.forget() self.search_lbl.pack(side='left') self.result_lst.delete(0, 'end') self.goto_btn.config(state='disabled') self.expand() import threading if self.scanner: self.scanner.quit = 1 self.scanner = ModuleScanner() threading.Thread(target=self.scanner.run, args=(self.update, key, self.done)).start() def update(self, path, modname, desc): if modname[-9:] == '.__init__': modname = modname[:-9] + ' (package)' self.result_lst.insert('end', modname + ' - ' + (desc or '(no description)')) def stop(self, event=None): if self.scanner: self.scanner.quit = 1 self.scanner = None def done(self): self.scanner = None self.search_lbl.config(text='Search for') self.search_lbl.pack(side='left') self.search_ent.pack(side='right', fill='x', expand=1) if sys.platform != 'win32': self.stop_btn.forget() self.stop_btn.config(state='disabled') def select(self, event=None): self.goto_btn.config(state='normal') def goto(self, event=None): selection = self.result_lst.curselection() if selection: modname = split(self.result_lst.get(selection[0]))[0] self.open(url=self.server.url + modname + '.html') def collapse(self): if not self.expanded: return self.result_frm.forget() self.result_scr.forget() self.result_lst.forget() self.bigwidth = self.window.winfo_width() self.bigheight = self.window.winfo_height() self.window.wm_geometry('%dx%d' % (self.minwidth, self.minheight)) self.window.wm_minsize(self.minwidth, self.minheight) self.expanded = 0 def expand(self): 
if self.expanded: return self.result_frm.pack(side='bottom', fill='x') self.result_scr.pack(side='right', fill='y') self.result_lst.pack(side='top', fill='both', expand=1) self.window.wm_geometry('%dx%d' % (self.bigwidth, self.bigheight)) self.window.wm_minsize(self.minwidth, self.bigminheight) self.expanded = 1 def hide(self, event=None): self.stop() self.collapse() import Tkinter try: root = Tkinter.Tk() # Tk will crash if pythonw.exe has an XP .manifest # file and the root has is not destroyed explicitly. # If the problem is ever fixed in Tk, the explicit # destroy can go. try: gui = GUI(root) root.mainloop() finally: root.destroy() except KeyboardInterrupt: pass # -------------------------------------------------- command-line interface def ispath(x): return isinstance(x, str) and find(x, os.sep) >= 0 def cli(): """Command-line interface (looks at sys.argv to decide what to do).""" import getopt class BadUsage: pass # Scripts don't get the current directory in their path by default # unless they are run with the '-m' switch if '' not in sys.path: scriptdir = os.path.dirname(sys.argv[0]) if scriptdir in sys.path: sys.path.remove(scriptdir) sys.path.insert(0, '.') try: opts, args = getopt.getopt(sys.argv[1:], 'gk:p:w') writing = 0 for opt, val in opts: if opt == '-g': gui() return if opt == '-k': apropos(val) return if opt == '-p': try: port = int(val) except ValueError: raise BadUsage def ready(server): print 'pydoc server ready at %s' % server.url def stopped(): print 'pydoc server stopped' serve(port, ready, stopped) return if opt == '-w': writing = 1 if not args: raise BadUsage for arg in args: if ispath(arg) and not os.path.exists(arg): print 'file %r does not exist' % arg break try: if ispath(arg) and os.path.isfile(arg): arg = importfile(arg) if writing: if ispath(arg) and os.path.isdir(arg): writedocs(arg) else: writedoc(arg) else: help.help(arg) except ErrorDuringImport, value: print value except (getopt.error, BadUsage): cmd = 
os.path.basename(sys.argv[0]) print """pydoc - the Python documentation tool %s ... Show text documentation on something. may be the name of a Python keyword, topic, function, module, or package, or a dotted reference to a class or function within a module or module in a package. If contains a '%s', it is used as the path to a Python source file to document. If name is 'keywords', 'topics', or 'modules', a listing of these things is displayed. %s -k Search for a keyword in the synopsis lines of all available modules. %s -p Start an HTTP server on the given port on the local machine. %s -g Pop up a graphical interface for finding and serving documentation. %s -w ... Write out the HTML documentation for a module to a file in the current directory. If contains a '%s', it is treated as a filename; if it names a directory, documentation is written for all the contents. """ % (cmd, os.sep, cmd, cmd, cmd, cmd, os.sep) if __name__ == '__main__': cli() ================================================ FILE: Utils/HandlerUtil.py ================================================ # # Handler library for Linux IaaS # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
""" JSON def: HandlerEnvironment.json [{ "name": "ExampleHandlerLinux", "seqNo": "seqNo", "version": "1.0", "handlerEnvironment": { "logFolder": "", "configFolder": "", "statusFolder": "", "heartbeatFile": "", } }] Example ./config/1.settings "{"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"1BE9A13AA1321C7C515EF109746998BAB6D86FD1","protectedSettings": "MIIByAYJKoZIhvcNAQcDoIIBuTCCAbUCAQAxggFxMIIBbQIBADBVMEExPzA9BgoJkiaJk/IsZAEZFi9XaW5kb3dzIEF6dXJlIFNlcnZpY2UgTWFuYWdlbWVudCBmb3IgR+nhc6VHQTQpCiiV2zANBgkqhkiG9w0BAQEFAASCAQCKr09QKMGhwYe+O4/a8td+vpB4eTR+BQso84cV5KCAnD6iUIMcSYTrn9aveY6v6ykRLEw8GRKfri2d6tvVDggUrBqDwIgzejGTlCstcMJItWa8Je8gHZVSDfoN80AEOTws9Fp+wNXAbSuMJNb8EnpkpvigAWU2v6pGLEFvSKC0MCjDTkjpjqciGMcbe/r85RG3Zo21HLl0xNOpjDs/qqikc/ri43Y76E/Xv1vBSHEGMFprPy/Hwo3PqZCnulcbVzNnaXN3qi/kxV897xGMPPC3IrO7Nc++AT9qRLFI0841JLcLTlnoVG1okPzK9w6ttksDQmKBSHt3mfYV+skqs+EOMDsGCSqGSIb3DQEHATAUBggqhkiG9w0DBwQITgu0Nu3iFPuAGD6/QzKdtrnCI5425fIUy7LtpXJGmpWDUA==","publicSettings":{"port":"3000"}}}]}" Example HeartBeat { "version": 1.0, "heartbeat" : { "status": "ready", "code": 0, "Message": "Sample Handler running. Waiting for a new configuration from user." 
} } Example Status Report: [{"version":"1.0","timestampUTC":"2014-05-29T04:20:13Z","status":{"name":"Chef Extension Handler","operation":"chef-client-run","status":"success","code":0,"formattedMessage":{"lang":"en-US","message":"Chef-client run success"}}}] """ import os import os.path import sys import base64 import json import time import re import subprocess # imp was deprecated in python 3.12 if sys.version_info >= (3, 12): import importlib else: import imp from xml.etree import ElementTree from os.path import join from Utils.WAAgentUtil import waagent from waagent import LoggerInit DateTimeFormat = "%Y-%m-%dT%H:%M:%SZ" MANIFEST_XML = "manifest.xml" class HandlerContext: def __init__(self, name): self._name = name self._version = '0.0' self._config_dir = None self._log_dir = None self._log_file = None self._status_dir = None self._heartbeat_file = None self._seq_no = -1 self._status_file = None self._settings_file = None self._config = None return class HandlerUtility: def __init__(self, log, error, s_name=None, l_name=None, extension_version=None, logFileName='extension.log', console_logger=None, file_logger=None): self._log = log self._log_to_con = console_logger self._log_to_file = file_logger self._error = error self._logFileName = logFileName if s_name is None or l_name is None or extension_version is None: (l_name, s_name, extension_version) = self._get_extension_info() self._short_name = s_name self._extension_version = extension_version self._log_prefix = '[%s-%s] ' % (l_name, extension_version) def get_extension_version(self): return self._extension_version def _get_log_prefix(self): return self._log_prefix def _get_extension_info(self): if os.path.isfile(MANIFEST_XML): return self._get_extension_info_manifest() ext_dir = os.path.basename(os.getcwd()) (long_name, version) = ext_dir.split('-') short_name = long_name.split('.')[-1] return long_name, short_name, version def _get_extension_info_manifest(self): with open(MANIFEST_XML) as fh: doc = 
ElementTree.parse(fh) namespace = doc.find('{http://schemas.microsoft.com/windowsazure}ProviderNameSpace').text short_name = doc.find('{http://schemas.microsoft.com/windowsazure}Type').text version = doc.find('{http://schemas.microsoft.com/windowsazure}Version').text long_name = "%s.%s" % (namespace, short_name) return (long_name, short_name, version) def _get_current_seq_no(self, config_folder): seq_no = -1 cur_seq_no = -1 freshest_time = None for subdir, dirs, files in os.walk(config_folder): for file in files: try: cur_seq_no = int(os.path.basename(file).split('.')[0]) if (freshest_time == None): freshest_time = os.path.getmtime(join(config_folder, file)) seq_no = cur_seq_no else: current_file_m_time = os.path.getmtime(join(config_folder, file)) if (current_file_m_time > freshest_time): freshest_time = current_file_m_time seq_no = cur_seq_no except ValueError: continue return seq_no def log(self, message): self._log(self._get_log_prefix() + message) def log_to_console(self, message): if self._log_to_con is not None: self._log_to_con(self._get_log_prefix() + message) else: self.error("Unable to log to console, console log method not set") def log_to_file(self, message): if self._log_to_file is not None: self._log_to_file(self._get_log_prefix() + message) else: self.error("Unable to log to file, file log method not set") def error(self, message): self._error(self._get_log_prefix() + message) @staticmethod def redact_protected_settings(content): redacted_tmp = re.sub(r'"protectedSettings":\s*"[^"]+=="', '"protectedSettings": "*** REDACTED ***"', content) redacted = re.sub(r'"protectedSettingsCertThumbprint":\s*"[^"]+"', '"protectedSettingsCertThumbprint": "*** REDACTED ***"', redacted_tmp) return redacted def _parse_config(self, ctxt): config = None try: config = json.loads(ctxt) except: self.error('JSON exception decoding ' + HandlerUtility.redact_protected_settings(ctxt)) if config is None: self.error("JSON error processing settings file:" + 
HandlerUtility.redact_protected_settings(ctxt)) else: handlerSettings = config['runtimeSettings'][0]['handlerSettings'] if 'protectedSettings' in handlerSettings and \ 'protectedSettingsCertThumbprint' in handlerSettings and \ handlerSettings['protectedSettings'] is not None and \ handlerSettings["protectedSettingsCertThumbprint"] is not None: protectedSettings = handlerSettings['protectedSettings'] thumb = handlerSettings['protectedSettingsCertThumbprint'] cert = waagent.LibDir + '/' + thumb + '.crt' pkey = waagent.LibDir + '/' + thumb + '.prv' unencodedSettings = base64.standard_b64decode(protectedSettings) # FIPS 140-3: use 'openssl cms' (supports AES256 & DES_EDE3_CBC) with fallback to legacy 'openssl smime' cms_cmd = 'openssl cms -inform DER -decrypt -recip {0} -inkey {1}'.format(cert,pkey) smime_cmd = 'openssl smime -inform DER -decrypt -recip {0} -inkey {1}'.format(cert,pkey) protected_settings_str = '' for decrypt_cmd in [cms_cmd, smime_cmd]: try: # waagent.RunSendStdin returns a tuple (return code, stdout) output = waagent.RunSendStdin(decrypt_cmd, unencodedSettings) if output and output[0] == 0 and output[1]: protected_settings_str = output[1] if decrypt_cmd == cms_cmd: self.log('Decrypted protectedSettings using openssl cms.') else: self.log('Decrypted protectedSettings using openssl smime fallback.') break else: rc = output[0] if output else 'N/A' self.log('Attempt to decrypt protectedSettings with "{0}" failed (rc={1}).'.format(decrypt_cmd, rc)) except OSError: pass jctxt = '' try: jctxt = json.loads(protected_settings_str) except: self.error('JSON exception decoding ' + HandlerUtility.redact_protected_settings(protected_settings_str)) handlerSettings['protectedSettings']=jctxt self.log('Config decoded correctly.') return config def do_parse_context(self, operation): _context = self.try_parse_context() if not _context: self.do_exit(1, operation, 'error', '1', operation + ' Failed') return _context def try_parse_context(self): self._context = 
HandlerContext(self._short_name) handler_env = None config = None ctxt = None code = 0 # get the HandlerEnvironment.json. According to the extension handler spec, it is always in the ./ directory self.log('cwd is ' + os.path.realpath(os.path.curdir)) handler_env_file = './HandlerEnvironment.json' if not os.path.isfile(handler_env_file): self.error("Unable to locate " + handler_env_file) return None ctxt = waagent.GetFileContents(handler_env_file) if ctxt == None: self.error("Unable to read " + handler_env_file) try: handler_env = json.loads(ctxt) except: pass if handler_env == None: self.log("JSON error processing " + handler_env_file) return None if type(handler_env) == list: handler_env = handler_env[0] self._context._name = handler_env['name'] self._context._version = str(handler_env['version']) self._context._config_dir = handler_env['handlerEnvironment']['configFolder'] self._context._log_dir = handler_env['handlerEnvironment']['logFolder'] self._context._log_file = os.path.join(handler_env['handlerEnvironment']['logFolder'], self._logFileName) self._change_log_file() self._context._status_dir = handler_env['handlerEnvironment']['statusFolder'] self._context._heartbeat_file = handler_env['handlerEnvironment']['heartbeatFile'] self._context._seq_no = self._get_current_seq_no(self._context._config_dir) if self._context._seq_no < 0: self.error("Unable to locate a .settings file!") return None self._context._seq_no = str(self._context._seq_no) self.log('sequence number is ' + self._context._seq_no) self._context._status_file = os.path.join(self._context._status_dir, self._context._seq_no + '.status') self._context._settings_file = os.path.join(self._context._config_dir, self._context._seq_no + '.settings') self.log("setting file path is" + self._context._settings_file) ctxt = None ctxt = waagent.GetFileContents(self._context._settings_file) if ctxt == None: error_msg = 'Unable to read ' + self._context._settings_file + '. 
' self.error(error_msg) return None self.log("JSON config: " + HandlerUtility.redact_protected_settings(ctxt)) self._context._config = self._parse_config(ctxt) return self._context def _change_log_file(self): self.log("Change log file to " + self._context._log_file) LoggerInit(self._context._log_file, '/dev/stdout') self._log = waagent.Log self._error = waagent.Error def set_verbose_log(self, verbose): if (verbose == "1" or verbose == 1): self.log("Enable verbose log") LoggerInit(self._context._log_file, '/dev/stdout', verbose=True) else: self.log("Disable verbose log") LoggerInit(self._context._log_file, '/dev/stdout', verbose=False) def is_seq_smaller(self): return int(self._context._seq_no) <= self._get_most_recent_seq() def save_seq(self): self._set_most_recent_seq(self._context._seq_no) self.log("set most recent sequence number to " + self._context._seq_no) def exit_if_enabled(self, remove_protected_settings=False): self.exit_if_seq_smaller(remove_protected_settings) def exit_if_seq_smaller(self, remove_protected_settings): if(self.is_seq_smaller()): self.log("Current sequence number, " + self._context._seq_no + ", is not greater than the sequence number of the most recent executed configuration. 
Exiting...") sys.exit(0) self.save_seq() if remove_protected_settings: self.scrub_settings_file() def _get_most_recent_seq(self): if (os.path.isfile('mrseq')): seq = waagent.GetFileContents('mrseq') if (seq): return int(seq) return -1 def is_current_config_seq_greater_inused(self): return int(self._context._seq_no) > self._get_most_recent_seq() def get_inused_config_seq(self): return self._get_most_recent_seq() def set_inused_config_seq(self, seq): self._set_most_recent_seq(seq) def _set_most_recent_seq(self, seq): waagent.SetFileContents('mrseq', str(seq)) def do_status_report(self, operation, status, status_code, message): self.log("{0},{1},{2},{3}".format(operation, status, status_code, message)) tstamp = time.strftime(DateTimeFormat, time.gmtime()) stat = [{ "version": self._context._version, "timestampUTC": tstamp, "status": { "name": self._context._name, "operation": operation, "status": status, "code": status_code, "formattedMessage": { "lang": "en-US", "message": message } } }] stat_rept = json.dumps(stat) if self._context._status_file: tmp = "%s.tmp" % (self._context._status_file) with open(tmp, 'w+') as f: f.write(stat_rept) os.rename(tmp, self._context._status_file) def do_heartbeat_report(self, heartbeat_file, status, code, message): # heartbeat health_report = '[{"version":"1.0","heartbeat":{"status":"' + status + '","code":"' + code + '","Message":"' + message + '"}}]' if waagent.SetFileContents(heartbeat_file, health_report) == None: self.error('Unable to wite heartbeat info to ' + heartbeat_file) def do_exit(self, exit_code, operation, status, code, message): try: self.do_status_report(operation, status, code, message) except Exception as e: self.log("Can't update status: " + str(e)) sys.exit(exit_code) def get_name(self): return self._context._name def get_seq_no(self): return self._context._seq_no def get_log_dir(self): return self._context._log_dir def get_handler_settings(self): if (self._context._config != None): return 
self._context._config['runtimeSettings'][0]['handlerSettings'] return None def get_protected_settings(self): if (self._context._config != None): return self.get_handler_settings().get('protectedSettings') return None def get_public_settings(self): handlerSettings = self.get_handler_settings() if (handlerSettings != None): return self.get_handler_settings().get('publicSettings') return None def scrub_settings_file(self): content = waagent.GetFileContents(self._context._settings_file) redacted = HandlerUtility.redact_protected_settings(content) waagent.SetFileContents(self._context._settings_file, redacted) ================================================ FILE: Utils/LogUtil.py ================================================ # Logging utilities # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import os.path import string import sys OutputSize = 4 * 1024 def tail(log_file, output_size = OutputSize): pos = min(output_size, os.path.getsize(log_file)) with open(log_file, "r") as log: log.seek(0, os.SEEK_END) log.seek(log.tell() - pos, os.SEEK_SET) buf = log.read(output_size) buf = filter(lambda x: x in string.printable, buf) # encoding works different for between interpreter version, we are keeping separate implementation to ensure # backward compatibility if sys.version_info[0] == 3: buf = ''.join(list(buf)).encode('ascii', 'ignore').decode("ascii", "ignore") elif sys.version_info[0] == 2: buf = buf.decode("ascii", "ignore") return buf def get_formatted_log(summary, stdout, stderr): msg_format = ("{0}\n" "---stdout---\n" "{1}\n" "---errout---\n" "{2}\n") return msg_format.format(summary, stdout, stderr) ================================================ FILE: Utils/ScriptUtil.py ================================================ # Script utilities # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import os.path import time import subprocess import traceback import string import shlex import sys from Utils import LogUtil from Utils.WAAgentUtil import waagent DefaultStdoutFile = "stdout" DefaultErroutFile = "errout" def run_command(hutil, args, cwd, operation, extension_short_name, version, exit_after_run=True, interval=30, std_out_file_name=DefaultStdoutFile, std_err_file_name=DefaultErroutFile): std_out_file = os.path.join(cwd, std_out_file_name) err_out_file = os.path.join(cwd, std_err_file_name) std_out = None err_out = None try: std_out = open(std_out_file, "w") err_out = open(err_out_file, "w") start_time = time.time() child = subprocess.Popen(args, cwd=cwd, stdout=std_out, stderr=err_out) time.sleep(1) while child.poll() is None: msg = "Command is running..." msg_with_cmd_output = LogUtil.get_formatted_log(msg, LogUtil.tail(std_out_file), LogUtil.tail(err_out_file)) msg_without_cmd_output = msg + " Stdout/Stderr omitted from output." hutil.log_to_file(msg_with_cmd_output) hutil.log_to_console(msg_without_cmd_output) hutil.do_status_report(operation, 'transitioning', '0', msg_without_cmd_output) time.sleep(interval) exit_code = child.returncode if child.returncode and child.returncode != 0: msg = "Command returned an error." msg_with_cmd_output = LogUtil.get_formatted_log(msg, LogUtil.tail(std_out_file), LogUtil.tail(err_out_file)) msg_without_cmd_output = msg + " Stdout/Stderr omitted from output." hutil.error(msg_without_cmd_output) waagent.AddExtensionEvent(name=extension_short_name, op=operation, isSuccess=False, version=version, message="(01302)" + msg_without_cmd_output) else: msg = "Command is finished." msg_with_cmd_output = LogUtil.get_formatted_log(msg, LogUtil.tail(std_out_file), LogUtil.tail(err_out_file)) msg_without_cmd_output = msg + " Stdout/Stderr omitted from output." 
hutil.log_to_file(msg_with_cmd_output) hutil.log_to_console(msg_without_cmd_output) waagent.AddExtensionEvent(name=extension_short_name, op=operation, isSuccess=True, version=version, message="(01302)" + msg_without_cmd_output) end_time = time.time() waagent.AddExtensionEvent(name=extension_short_name, op=operation, isSuccess=True, version=version, message=("(01304)Command execution time: " "{0}s").format(str(end_time - start_time))) log_or_exit(hutil, exit_after_run, exit_code, operation, msg_with_cmd_output) except Exception as e: error_msg = ("Failed to launch command with error: {0}," "stacktrace: {1}").format(e, traceback.format_exc()) hutil.error(error_msg) waagent.AddExtensionEvent(name=extension_short_name, op=operation, isSuccess=False, version=version, message="(01101)" + error_msg) exit_code = 1 msg = 'Launch command failed: {0}'.format(e) log_or_exit(hutil, exit_after_run, exit_code, operation, msg) finally: if std_out: std_out.close() if err_out: err_out.close() return exit_code # do_exit calls sys.exit which raises an exception so we do not call it from the finally block def log_or_exit(hutil, exit_after_run, exit_code, operation, msg): status = 'success' if exit_code == 0 else 'failed' if exit_after_run: hutil.do_exit(exit_code, operation, status, str(exit_code), msg) else: hutil.do_status_report(operation, status, str(exit_code), msg) def parse_args(cmd): cmd = filter(lambda x: x in string.printable, cmd) # encoding works different for between interpreter version, we are keeping separate implementation to ensure # backward compatibility if sys.version_info[0] == 3: cmd = ''.join(list(cmd)).encode('ascii', 'ignore').decode("ascii", "ignore") elif sys.version_info[0] == 2: cmd = cmd.decode("ascii", "ignore") args = shlex.split(cmd) # From python 2.6 to python 2.7.2, shlex.split output UCS-4 result like # '\x00\x00a'. 
Temp workaround is to replace \x00 for idx, val in enumerate(args): if '\x00' in args[idx]: args[idx] = args[idx].replace('\x00', '') return args ================================================ FILE: Utils/WAAgentUtil.py ================================================ # Wrapper module for waagent # # waagent is not written as a module. This wrapper module is created # to use the waagent code as a module. # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import os.path import sys # imp was deprecated in python 3.12 if sys.version_info >= (3, 12): import importlib import types else: import imp # # The following code will search and load waagent code and expose # it as a submodule of current module # def searchWAAgent(): # if the extension ships waagent in its package to default to this version first pkg_agent_path = os.path.join(os.getcwd(), 'waagent') if os.path.isfile(pkg_agent_path): return pkg_agent_path agentPath = '/usr/sbin/waagent' if os.path.isfile(agentPath): return agentPath user_paths = os.environ['PYTHONPATH'].split(os.pathsep) for user_path in user_paths: agentPath = os.path.join(user_path, 'waagent') if os.path.isfile(agentPath): return agentPath return None waagent = None agentPath = searchWAAgent() if agentPath: # imp was deprecated in python 3.12 if sys.version_info >= (3, 12): # Create a module spec from the waagent python file, then create module from spec and load it loader = 
importlib.machinery.SourceFileLoader('waagent', agentPath) code = loader.get_code(loader.name) waagent = types.ModuleType(loader.name) exec(code, waagent.__dict__) # Add the module to sys.modules sys.modules['waagent'] = waagent else: waagent = imp.load_source('waagent', agentPath) else: raise Exception("Can't load waagent.") if not hasattr(waagent, "AddExtensionEvent"): """ If AddExtensionEvent is not defined, provide a dummy impl. """ def _AddExtensionEvent(*args, **kwargs): pass waagent.AddExtensionEvent = _AddExtensionEvent if not hasattr(waagent, "WALAEventOperation"): class _WALAEventOperation: HeartBeat="HeartBeat" Provision = "Provision" Install = "Install" UnIsntall = "UnInstall" Disable = "Disable" Enable = "Enable" Download = "Download" Upgrade = "Upgrade" Update = "Update" waagent.WALAEventOperation = _WALAEventOperation # Better deal with the silly waagent typo, in anticipation of a proper fix of the typo later on waagent if not hasattr(waagent.WALAEventOperation, 'Uninstall'): if hasattr(waagent.WALAEventOperation, 'UnIsntall'): waagent.WALAEventOperation.Uninstall = waagent.WALAEventOperation.UnIsntall else: # This shouldn't happen, but just in case... waagent.WALAEventOperation.Uninstall = 'Uninstall' def GetWaagentHttpProxyConfigString(): """ Get http_proxy and https_proxy from waagent config. Username and password is not supported now. This code is adopted from /usr/sbin/waagent """ host = None port = None try: waagent.Config = waagent.ConfigurationProvider(None) # Use default waagent conf file (most likely /etc/waagent.conf) host = waagent.Config.get("HttpProxy.Host") port = waagent.Config.get("HttpProxy.Port") except Exception as e: # waagent.ConfigurationProvider(None) will throw an exception on an old waagent # Has to silently swallow because logging is not yet available here # and we don't want to bring that in here. Also if the call fails, then there's # no proxy config in waagent.conf anyway, so it's safe to silently swallow. 
pass result = '' if host is not None: result = "http://" + host if port is not None: result += ":" + port return result waagent.HttpProxyConfigString = GetWaagentHttpProxyConfigString() # end: waagent http proxy config stuff __ExtensionName__ = None def InitExtensionEventLog(name): global __ExtensionName__ __ExtensionName__ = name def AddExtensionEvent(name=__ExtensionName__, op=waagent.WALAEventOperation.Enable, isSuccess=False, message=None): if name is not None: waagent.AddExtensionEvent(name=name, op=op, isSuccess=isSuccess, message=message) ================================================ FILE: Utils/__init__.py ================================================ # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: Utils/constants.py ================================================ LibDir = "/var/lib/waagent" Openssl = "openssl" os_release = "/etc/os-release" system_release = "/etc/system-release" class WALAEventOperation: HeartBeat = "HeartBeat" Provision = "Provision" Install = "Install" UnInstall = "UnInstall" Disable = "Disable" Enable = "Enable" Download = "Download" Upgrade = "Upgrade" Update = "Update" ================================================ FILE: Utils/crypt_fallback.py ================================================ """ Fallback crypt implementation using ctypes for Python 3.13+. 
This module provides crypt.crypt() functionality without requiring pip install by directly calling the system's libxcrypt/libcrypt library via ctypes. Usage: # In your code, use this import pattern: try: import crypt except ImportError: try: import crypt_r as crypt except ImportError: from Common import crypt_fallback as crypt """ import ctypes import ctypes.util import string import random __all__ = ['crypt', 'mksalt', 'METHOD_SHA512', 'METHOD_SHA256', 'methods'] # Try to load libcrypt _libcrypt = None _libcrypt_path = ctypes.util.find_library("crypt") if _libcrypt_path: try: _libcrypt = ctypes.CDLL(_libcrypt_path) _libcrypt.crypt.argtypes = [ctypes.c_char_p, ctypes.c_char_p] _libcrypt.crypt.restype = ctypes.c_char_p except (OSError, AttributeError): _libcrypt = None class _Method: """Class representing a crypt method.""" def __init__(self, name, ident, salt_chars, total_size): self.name = name self.ident = ident self.salt_chars = salt_chars self.total_size = total_size def __repr__(self): return ''.format(self.name) # Define standard methods METHOD_SHA512 = _Method('SHA512', '6', 16, 106) METHOD_SHA256 = _Method('SHA256', '5', 16, 63) METHOD_MD5 = _Method('MD5', '1', 8, 34) METHOD_CRYPT = _Method('CRYPT', None, 2, 13) methods = [METHOD_SHA512, METHOD_SHA256, METHOD_MD5, METHOD_CRYPT] def mksalt(method=None, rounds=None): """Generate a salt for the specified method. If not specified, the strongest available method (SHA512) will be used. 
""" if method is None: method = METHOD_SHA512 saltchars = string.ascii_letters + string.digits + './' if method.ident: salt = '${0}$'.format(method.ident) if method.ident in ('5', '6') and rounds is not None: if not 1000 <= rounds <= 999999999: raise ValueError('rounds out of the range 1000 to 999_999_999') salt += 'rounds={0}$'.format(rounds) else: salt = '' salt += ''.join(random.choice(saltchars) for _ in range(method.salt_chars)) return salt def crypt(word, salt=None): """Return a string representing the one-way hash of a password. If salt is not specified, the strongest available method will be used. Args: word: The password to hash salt: The salt string (e.g., '$6$rounds=5000$saltsalt$') or a METHOD_* constant Returns: The hashed password string Raises: ImportError: If libcrypt is not available on the system """ if _libcrypt is None: raise ImportError( "crypt_fallback requires libcrypt/libxcrypt. " "Install with: sudo tdnf install libxcrypt (Azure Linux) or " "sudo apt install libcrypt1 (Debian/Ubuntu)" ) # Handle METHOD_* constants passed as salt if salt is None or isinstance(salt, _Method): salt = mksalt(salt) # Encode strings to bytes for ctypes if isinstance(word, str): word = word.encode('utf-8') if isinstance(salt, str): salt = salt.encode('utf-8') result = _libcrypt.crypt(word, salt) if result is None: raise ValueError("crypt() failed - invalid salt or system error") return result.decode('utf-8') ================================================ FILE: Utils/distroutils.py ================================================ import os import pwd import random import string import hashlib import sys # crypt module was removed in Python 3.13 # For Python < 3.11: use builtin crypt # For Python >= 3.11: try crypt_r package, then ctypes fallback if sys.version_info >= (3, 11): try: import crypt_r as crypt except ImportError: try: from Utils import crypt_fallback as crypt except ImportError: crypt = None else: try: import crypt except ImportError: try: from Utils 
def get_my_distro(config, os_name=None):
    """Return the distro-specific handler object for the running system.

    Args:
        config: parsed provisioning configuration, forwarded to the
            distro class constructor.
        os_name: optional OS name string; when None it is read from
            /etc/os-release (NAME=...) or the legacy system-release file.

    Returns:
        An instance of the matching *Distro class; GenericDistro when no
        known distribution matches.
    """
    if 'FreeBSD' in platform.system():
        return FreeBSDDistro(config)
    if os_name is None:
        # Prefer /etc/os-release; fall back to the legacy release file.
        if os.path.isfile(constants.os_release):
            os_name = ext_utils.get_line_starting_with("NAME", constants.os_release)
        elif os.path.isfile(constants.system_release):
            os_name = ext_utils.get_file_contents(constants.system_release)
    if os_name is not None:
        # Ordered table: first match wins (same order as before).
        # Patterns are raw strings so escapes like \s reach the regex
        # engine verbatim instead of triggering SyntaxWarning on 3.12+.
        distro_table = (
            (r"fedora", FedoraDistro),
            (r"red\s?hat", RedhatDistro),
            (r"centos", CentOSDistro),
            (r"coreos", CoreOSDistro),
            (r"freebsd", FreeBSDDistro),
            (r"sles", SuSEDistro),
            (r"ubuntu", UbuntuDistro),
            (r"mariner", MarinerDistro),
        )
        for pattern, distro_cls in distro_table:
            if re.search(pattern, os_name, re.IGNORECASE):
                return distro_cls(config)
    return GenericDistro(config)


# noinspection PyMethodMayBeStatic
class GenericDistro(object):
    """
    GenericDistro defines a skeleton necessary for a concrete Distro class.
    Generic methods and attributes are kept here, distribution
    specific attributes and behavior are to be placed in the concrete
    child named distroDistro, where distro is the string returned by
    calling python platform.linux_distribution()[0].
    So for CentOS the derived class is called 'centosDistro'.
    """
""" self.selinux = None self.service_cmd = '/usr/sbin/service' self.ssh_service_restart_option = 'restart' self.ssh_service_name = 'ssh' self.distro_name = 'default' self.config = config def is_se_linux_system(self): """ Checks and sets self.selinux = True if SELinux is available on system. """ if self.selinux is None: if ext_utils.run(['which', 'getenforce'], chk_err=False): self.selinux = False else: self.selinux = True return self.selinux def get_home(self): """ Attempt to guess the $HOME location. Return the path string. """ home = None try: home = ext_utils.get_line_starting_with("HOME", "/etc/default/useradd").split('=')[1].strip() except (ValueError, KeyError, AttributeError, EnvironmentError): pass if (home is None) or (not home.startswith("/")): home = "/home" return home def set_se_linux_context(self, path, cn): """ Calls shell 'chcon' with 'path' and 'cn' context. Returns exit result. """ if self.is_se_linux_system(): return ext_utils.run(['chcon', cn, path]) def restart_ssh_service(self): """ Service call to re(start) the SSH service """ ssh_restart_cmd = [self.service_cmd, self.ssh_service_name, self.ssh_service_restart_option] ret_code = ext_utils.run(ssh_restart_cmd) if ret_code != 0: logger.error("Failed to restart SSH service with return code:" + str(ret_code)) return ret_code def ssh_deploy_public_key(self, fprint, path): """ Generic sshDeployPublicKey - over-ridden in some concrete Distro classes due to minor differences in openssl packages deployed """ keygen_retcode = ext_utils.run_command_and_write_stdout_to_file( ['ssh-keygen', '-i', '-m', 'PKCS8', '-f', fprint], path) if keygen_retcode: return 1 else: return 0 def change_password(self, user, password): logger.log("Change user password") crypt_id = self.config.get("Provisioning.PasswordCryptId") if crypt_id is None: crypt_id = "6" salt_len = self.config.get("Provisioning.PasswordCryptSaltLength") try: salt_len = int(salt_len) if salt_len < 0 or salt_len > 10: salt_len = 10 except (ValueError, 
TypeError): salt_len = 10 return self.chpasswd(user, password, crypt_id=crypt_id, salt_len=salt_len) def chpasswd(self, username, password, crypt_id=6, salt_len=10): passwd_hash = self.gen_password_hash(password, crypt_id, salt_len) cmd = ['usermod', '-p', passwd_hash, username] ret, output = ext_utils.run_command_get_output(cmd, log_cmd=False) if ret != 0: return "Failed to set password for {0}: {1}".format(username, output) def gen_password_hash(self, password, crypt_id, salt_len): collection = string.ascii_letters + string.digits salt = ''.join(random.choice(collection) for _ in range(salt_len)) salt = "${0}${1}".format(crypt_id, salt) return crypt.crypt(password, salt) def create_account(self, user, password, expiration, thumbprint, enable_nopasswd): """ Create a user account, with 'user', 'password', 'expiration', ssh keys and sudo permissions. Returns None if successful, error string on failure. """ user_entry = None try: user_entry = pwd.getpwnam(user) except (KeyError, EnvironmentError): pass uid_min = None try: uid_min = int(ext_utils.get_line_starting_with("UID_MIN", "/etc/login.defs").split()[1]) except (ValueError, KeyError, AttributeError, EnvironmentError): pass if uid_min is None: uid_min = 100 if user_entry is not None and user_entry[2] < uid_min: logger.error( "CreateAccount: " + user + " is a system user. Will not set password.") return "Failed to set password for system user: " + user + " (0x06)." if user_entry is None: command = ['useradd', '-m', user] if expiration is not None: command += ['-e', expiration.split('.')[0]] if ext_utils.run(command): logger.error("Failed to create user account: " + user) return "Failed to create user account: " + user + " (0x07)." else: logger.log("CreateAccount: " + user + " already exists. 
Will update password.") if password is not None: self.change_password(user, password) try: # for older distros create sudoers.d if not os.path.isdir('/etc/sudoers.d/'): # create the /etc/sudoers.d/ directory os.mkdir('/etc/sudoers.d/') # add the include of sudoers.d to the /etc/sudoers ext_utils.set_file_contents( '/etc/sudoers', ext_utils.get_file_contents('/etc/sudoers') + '\n#includedir /etc/sudoers.d\n') if password is None or enable_nopasswd: ext_utils.set_file_contents("/etc/sudoers.d/waagent", user + " ALL = (ALL) NOPASSWD: ALL\n") else: ext_utils.set_file_contents("/etc/sudoers.d/waagent", user + " ALL = (ALL) ALL\n") os.chmod("/etc/sudoers.d/waagent", 0o440) except EnvironmentError: logger.error("CreateAccount: Failed to configure sudo access for user.") return "Failed to configure sudo privileges (0x08)." home = self.get_home() if thumbprint is not None: ssh_dir = home + "/" + user + "/.ssh" ext_utils.create_dir(ssh_dir, user, 0o700) pub = ssh_dir + "/id_rsa.pub" prv = ssh_dir + "/id_rsa" ext_utils.run_command_and_write_stdout_to_file(['ssh-keygen', '-y', '-f', thumbprint + '.prv'], pub) for f in [pub, prv]: os.chmod(f, 0o600) ext_utils.change_owner(f, user) ext_utils.set_file_contents(ssh_dir + "/authorized_keys", ext_utils.get_file_contents(pub)) ext_utils.change_owner(ssh_dir + "/authorized_keys", user) logger.log("Created user account: " + user) return None def delete_account(self, user): """ Delete the 'user'. Clear utmp first, to avoid error. Removes the /etc/sudoers.d/waagent file. 
""" user_entry = None try: user_entry = pwd.getpwnam(user) except (KeyError, EnvironmentError): pass if user_entry is None: logger.error("DeleteAccount: " + user + " not found.") return uid_min = None try: uid_min = int(ext_utils.get_line_starting_with("UID_MIN", "/etc/login.defs").split()[1]) except (ValueError, KeyError, AttributeError, EnvironmentError): pass if uid_min is None: uid_min = 100 if user_entry[2] < uid_min: logger.error( "DeleteAccount: " + user + " is a system user. Will not delete account.") return ext_utils.run(['rm', '-f', '/var/run/utmp']) # Delete utmp to prevent error if we are the 'user' deleted ext_utils.run(['userdel', '-f', '-r', user]) try: os.remove("/etc/sudoers.d/waagent") except EnvironmentError: pass return class UbuntuDistro(GenericDistro): def __init__(self, config): """ Generic Attributes go here. These are based on 'majority rules'. This __init__() may be called or overriden by the child. """ super(UbuntuDistro, self).__init__(config) self.selinux = False self.ssh_service_name = 'sshd' self.sudoers_dir_base = '/usr/local/etc' self.distro_name = 'Ubuntu' def restart_ssh_service(self): """ Service call to re(start) the SSH service starting with Ubuntu 22.10, the service name is ssh not sshd, adding fallback incase sshd fails """ ssh_restart_cmd = [self.service_cmd, self.ssh_service_name, self.ssh_service_restart_option] ret_code = ext_utils.run(ssh_restart_cmd) if ret_code != 0: self.ssh_service_name = 'ssh' ssh_restart_cmd = [self.service_cmd, self.ssh_service_name, self.ssh_service_restart_option] ret_code = ext_utils.run(ssh_restart_cmd) if ret_code != 0: logger.error("Failed to restart SSH service with return code:" + str(ret_code)) return ret_code class FreeBSDDistro(GenericDistro): """ """ def __init__(self, config): """ Generic Attributes go here. These are based on 'majority rules'. This __init__() may be called or overriden by the child. 
""" super(FreeBSDDistro, self).__init__(config) self.selinux = False self.ssh_service_name = 'sshd' self.sudoers_dir_base = '/usr/local/etc' self.distro_name = 'FreeBSD' # noinspection PyMethodOverriding def chpasswd(self, user, password, **kwargs): return ext_utils.run_send_stdin(['pw', 'usermod', 'user', '-h', '0'], password.encode('utf-8'), log_cmd=False) def create_account(self, user, password, expiration, thumbprint, enable_nopasswd): """ Create a user account, with 'user', 'password', 'expiration', ssh keys and sudo permissions. Returns None if successful, error string on failure. """ userentry = None try: userentry = pwd.getpwnam(user) except (EnvironmentError, KeyError): pass uidmin = None try: if os.path.isfile("/etc/pw.conf"): uidmin = int(ext_utils.get_line_starting_with("minuid", "/etc/pw.conf").split('=')[1].strip(' "')) except (ValueError, KeyError, AttributeError, EnvironmentError): pass pass if uidmin is None: uidmin = 100 if userentry is not None and userentry[2] < uidmin: logger.error( "CreateAccount: " + user + " is a system user. Will not set password.") return "Failed to set password for system user: " + user + " (0x06)." if userentry is None: command = ['pw', 'useradd', user, '-m'] if expiration is not None: command += ['-e', expiration.split('.')[0]] if ext_utils.run(command): logger.error("Failed to create user account: " + user) return "Failed to create user account: " + user + " (0x07)." else: logger.log( "CreateAccount: " + user + " already exists. 
Will update password.") if password is not None: self.change_password(user, password) try: # for older distros create sudoers.d if not os.path.isdir(self.sudoers_dir_base + '/sudoers.d/'): # create the /etc/sudoers.d/ directory os.mkdir(self.sudoers_dir_base + '/sudoers.d') # add the include of sudoers.d to the /etc/sudoers ext_utils.set_file_contents( self.sudoers_dir_base + '/sudoers', ext_utils.get_file_contents( self.sudoers_dir_base + '/sudoers') + '\n#includedir ' + self.sudoers_dir_base + '/sudoers.d\n') if password is None or enable_nopasswd: ext_utils.set_file_contents( self.sudoers_dir_base + "/sudoers.d/waagent", user + " ALL = (ALL) NOPASSWD: ALL\n") else: ext_utils.set_file_contents(self.sudoers_dir_base + "/sudoers.d/waagent", user + " ALL = (ALL) ALL\n") os.chmod(self.sudoers_dir_base + "/sudoers.d/waagent", 0o440) except (ValueError, KeyError, AttributeError, EnvironmentError): logger.error("CreateAccount: Failed to configure sudo access for user.") return "Failed to configure sudo privileges (0x08)." home = self.get_home() if thumbprint is not None: ssh_dir = home + "/" + user + "/.ssh" ext_utils.create_dir(ssh_dir, user, 0o700) pub = ssh_dir + "/id_rsa.pub" prv = ssh_dir + "/id_rsa" ext_utils.run_command_and_write_stdout_to_file(['sh-keygen', '-y', '-f', thumbprint + '.prv'], pub) ext_utils.set_file_contents( prv, ext_utils.get_file_contents(thumbprint + ".prv")) for f in [pub, prv]: os.chmod(f, 0o600) ext_utils.change_owner(f, user) ext_utils.set_file_contents( ssh_dir + "/authorized_keys", ext_utils.get_file_contents(pub)) ext_utils.change_owner(ssh_dir + "/authorized_keys", user) logger.log("Created user account: " + user) return None def delete_account(self, user): """ Delete the 'user'. Clear utmp first, to avoid error. Removes the /etc/sudoers.d/waagent file. 
""" userentry = None try: userentry = pwd.getpwnam(user) except (EnvironmentError, KeyError): pass if userentry is None: logger.error("DeleteAccount: " + user + " not found.") return uidmin = None try: if os.path.isfile("/etc/pw.conf"): uidmin = int(ext_utils.get_line_starting_with("minuid", "/etc/pw.conf").split('=')[1].strip(' "')) except (ValueError, KeyError, AttributeError, EnvironmentError): pass if uidmin is None: uidmin = 100 if userentry[2] < uidmin: logger.error( "DeleteAccount: " + user + " is a system user. Will not delete account.") return # empty contents of utmp to prevent error if we are the 'user' deleted ext_utils.run_command_and_write_stdout_to_file(['echo'], '/var/run/utmp') ext_utils.run(['rmuser', '-y', user], chk_err=False) try: os.remove(self.sudoers_dir_base + "/sudoers.d/waagent") except EnvironmentError: pass return def get_home(self): return '/home' class CoreOSDistro(GenericDistro): """ CoreOS Distro concrete class Put CoreOS specific behavior here... """ CORE_UID = 500 def __init__(self, config): super(CoreOSDistro, self).__init__(config) self.waagent_path = '/usr/share/oem/bin' self.python_path = '/usr/share/oem/python/bin' self.distro_name = 'CoreOS' if 'PATH' in os.environ: os.environ['PATH'] = "{0}:{1}".format(os.environ['PATH'], self.python_path) else: os.environ['PATH'] = self.python_path if 'PYTHONPATH' in os.environ: os.environ['PYTHONPATH'] = "{0}:{1}".format(os.environ['PYTHONPATH'], self.waagent_path) else: os.environ['PYTHONPATH'] = self.waagent_path def restart_ssh_service(self): """ SSH is socket activated on CoreOS. No need to restart it. """ return 0 def create_account(self, user, password, expiration, thumbprint, enable_nopasswd): """ Create a user account, with 'user', 'password', 'expiration', ssh keys and sudo permissions. Returns None if successful, error string on failure. 
""" userentry = None try: userentry = pwd.getpwnam(user) except (EnvironmentError, KeyError): pass uidmin = None try: uidmin = int(ext_utils.get_line_starting_with("UID_MIN", "/etc/login.defs").split()[1]) except (ValueError, KeyError, AttributeError, EnvironmentError): pass if uidmin is None: uidmin = 100 if userentry is not None and userentry[2] < uidmin and userentry[2] != self.CORE_UID: logger.error( "CreateAccount: " + user + " is a system user. Will not set password.") return "Failed to set password for system user: " + user + " (0x06)." if userentry is None: command = ['useradd', '--create-home', '--password', '*', user] if expiration is not None: command += ['--expiredate', expiration.split('.')[0]] if ext_utils.run(command): logger.error("Failed to create user account: " + user) return "Failed to create user account: " + user + " (0x07)." else: logger.log("CreateAccount: " + user + " already exists. Will update password.") if password is not None: self.change_password(user, password) try: if password is None or enable_nopasswd: ext_utils.set_file_contents("/etc/sudoers.d/waagent", user + " ALL = (ALL) NOPASSWD: ALL\n") else: ext_utils.set_file_contents("/etc/sudoers.d/waagent", user + " ALL = (ALL) ALL\n") os.chmod("/etc/sudoers.d/waagent", 0o440) except EnvironmentError: logger.error("CreateAccount: Failed to configure sudo access for user.") return "Failed to configure sudo privileges (0x08)." 
home = self.get_home() if thumbprint is not None: ssh_dir = home + "/" + user + "/.ssh" ext_utils.create_dir(ssh_dir, user, 0o700) pub = ssh_dir + "/id_rsa.pub" prv = ssh_dir + "/id_rsa" ext_utils.run_command_and_write_stdout_to_file(['ssh-keygen', '-y', '-f', thumbprint + '.prv'], pub) ext_utils.set_file_contents(prv, ext_utils.get_file_contents(thumbprint + ".prv")) for f in [pub, prv]: os.chmod(f, 0o600) ext_utils.change_owner(f, user) ext_utils.set_file_contents(ssh_dir + "/authorized_keys", ext_utils.get_file_contents(pub)) ext_utils.change_owner(ssh_dir + "/authorized_keys", user) logger.log("Created user account: " + user) return None class RedhatDistro(GenericDistro): """ Redhat Distro concrete class Put Redhat specific behavior here... """ def __init__(self, config): super(RedhatDistro, self).__init__(config) self.service_cmd = '/sbin/service' self.ssh_service_restart_option = 'condrestart' self.ssh_service_name = 'sshd' self.distro_name = 'Red Hat' class CentOSDistro(RedhatDistro): def __init__(self, config): super(CentOSDistro, self).__init__(config) self.distro_name = "CentOS" class FedoraDistro(RedhatDistro): """ FedoraDistro concrete class Put Fedora specific behavior here... 
""" def __init__(self, config): super(FedoraDistro, self).__init__(config) self.service_cmd = '/usr/bin/systemctl' self.hostname_file_path = '/etc/hostname' self.distro_name = 'Fedora' def restart_ssh_service(self): """ Service call to re(start) the SSH service """ ssh_restart_cmd = [self.service_cmd, self.ssh_service_restart_option, self.ssh_service_name] retcode = ext_utils.run(ssh_restart_cmd) if retcode > 0: logger.error("Failed to restart SSH service with return code:" + str(retcode)) return retcode def create_account(self, user, password, expiration, thumbprint, enable_nopasswd): ext_utils.run(['/sbin/usermod', user, '-G', 'wheel']) def delete_account(self, user): ext_utils.run(['/sbin/usermod', user, '-G', '']) class SuSEDistro(GenericDistro): def __init__(self, config): super(SuSEDistro, self).__init__(config) self.ssh_service_name = 'sshd' self.distro_name = "SuSE" class MarinerDistro(GenericDistro): def __init__(self, config): super(MarinerDistro, self).__init__(config) self.ssh_service_name = 'sshd' self.service_cmd = '/usr/bin/systemctl' self.distro_name = 'Mariner' def restart_ssh_service(self): """ Service call to re(start) the SSH service """ ssh_restart_cmd = [self.service_cmd, self.ssh_service_restart_option, self.ssh_service_name] retcode = ext_utils.run(ssh_restart_cmd) if retcode > 0: logger.error("Failed to restart SSH service with return code:" + str(retcode)) return retcode ================================================ FILE: Utils/extensionutils.py ================================================ import subprocess import os import tempfile import traceback import time import sys import pwd import Utils.constants as constants import xml.sax.saxutils as xml_utils import Utils.logger as logger if not hasattr(subprocess, 'check_output'): def check_output(*popenargs, **kwargs): r"""Backport from subprocess module from python 2.7""" if 'stdout' in kwargs: raise ValueError('stdout argument not allowed, it will be overridden.') process = 
if not hasattr(subprocess, 'check_output'):
    def check_output(*popenargs, **kwargs):
        r"""Backport from subprocess module from python 2.7"""
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        output, unused_err = process.communicate()
        retcode = process.poll()
        if retcode:
            cmd = kwargs.get("args")
            if cmd is None:
                cmd = popenargs[0]
            raise subprocess.CalledProcessError(retcode, cmd, output=output)
        return output

    # Exception classes used by this module.
    class CalledProcessError(Exception):
        def __init__(self, returncode, cmd, output=None):
            self.returncode = returncode
            self.cmd = cmd
            self.output = output

        def __str__(self):
            return "Command '%s' returned non-zero exit status %d" % (self.cmd, self.returncode)

    # Monkey-patch the backports onto the subprocess module for old pythons.
    subprocess.check_output = check_output
    subprocess.CalledProcessError = CalledProcessError


def change_owner(file_path, user):
    """
    Lookup user.  Attempt chown 'filepath' to 'user'.
    """
    p = None
    try:
        p = pwd.getpwnam(user)
    except (KeyError, EnvironmentError):
        # Unknown user: silently skip the chown (best effort).
        pass
    if p is not None:
        os.chown(file_path, p[2], p[3])


def create_dir(dir_path, user, mode):
    """
    Attempt os.makedirs, catch all exceptions.
    Call ChangeOwner afterwards.
    """
    try:
        os.makedirs(dir_path, mode)
    except EnvironmentError:
        # Directory may already exist; ownership is still (re)applied below.
        pass
    change_owner(dir_path, user)


def encode_for_writing_to_file(contents):
    """Return *contents* as bytes suitable for binary-mode file writes.

    On Python 3 a str is UTF-8 encoded; bytes (and anything else) pass
    through unchanged.  On Python 2, str is already a byte string
    (utf-8 is a superset of ASCII and latin-1), so no encoding is needed.
    """
    # Fix: use isinstance() instead of an exact type() comparison so str
    # subclasses are encoded as well.
    if isinstance(contents, str) and sys.version_info[0] == 3:
        return contents.encode('utf-8')
    return contents


def set_file_contents(file_path, contents):
    """
    Write 'contents' to 'file_path'.
    Returns 0 on success, None on failure (already logged).
    """
    bytes_to_write = encode_for_writing_to_file(contents)
    try:
        with open(file_path, "wb+") as F:
            F.write(bytes_to_write)
    except EnvironmentError as e:
        logger.error_with_prefix(
            'SetFileContents', 'Writing to file ' + file_path + ' Exception is ' + str(e))
        return None
    return 0
""" bytes_to_write = encode_for_writing_to_file(contents) try: with open(file_path, "ab+") as F: F.write(bytes_to_write) except EnvironmentError as e: logger.error_with_prefix( 'AppendFileContents', 'Appending to file ' + file_path + ' Exception is ' + str(e)) return None return 0 def get_file_contents(file_path, as_bin=False): """ Read and return contents of 'file_path'. """ mode = 'r' if as_bin: mode += 'b' try: with open(file_path, mode) as F: contents = F.read() return contents except EnvironmentError as e: logger.error_with_prefix( 'GetFileContents', 'Reading from file ' + file_path + ' Exception is ' + str(e)) return None def replace_file_with_contents_atomic(filepath, contents): """ Write 'contents' to 'filepath' by creating a temp file, and replacing original. """ handle, temp = tempfile.mkstemp(dir=os.path.dirname(filepath)) bytes_to_write = encode_for_writing_to_file(contents) try: os.write(handle, bytes_to_write) except EnvironmentError as e: logger.error_with_prefix( 'ReplaceFileContentsAtomic', 'Writing to file ' + filepath + ' Exception is ' + str(e)) return None finally: os.close(handle) try: os.rename(temp, filepath) return None except EnvironmentError as e: logger.error_with_prefix( 'ReplaceFileContentsAtomic', 'Renaming ' + temp + ' to ' + filepath + ' Exception is ' + str(e) ) try: os.remove(filepath) except EnvironmentError as e: logger.error_with_prefix( 'ReplaceFileContentsAtomic', 'Removing ' + filepath + ' Exception is ' + str(e)) try: os.rename(temp, filepath) except EnvironmentError as e: logger.error_with_prefix( 'ReplaceFileContentsAtomic', 'Removing ' + filepath + ' Exception is ' + str(e)) return 1 return 0 def run_command_and_write_stdout_to_file(command, output_file): # meant to replace commands of the nature command > output_file try: p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False) stdout, stderr = p.communicate() except EnvironmentError as e: logger.error('CalledProcessError. 
def run_command_and_write_stdout_to_file(command, output_file):
    # meant to replace commands of the nature command > output_file
    # Returns 0 on success, otherwise the child's return code or errno.
    try:
        p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
        stdout, stderr = p.communicate()
    except EnvironmentError as e:
        logger.error('CalledProcessError.  Error message is ' + str(e))
        return e.errno
    if p.returncode != 0:
        logger.error('CalledProcessError.  Error Code is ' + str(p.returncode))
        logger.error('CalledProcessError.  Command string was ' + ' '.join(command))
        logger.error(
            'CalledProcessError.  Command result was stdout: ' + str(stdout) + ' stderr: ' + str(stderr))
        return p.returncode
    set_file_contents(output_file, stdout)
    return 0


def run_command_get_output(cmd, chk_err=True, log_cmd=True):
    """
    Wrapper for subprocess.check_output.
    Execute 'cmd'.  Returns return code and STDOUT,
    trapping expected exceptions.
    Reports exceptions to Error if chk_err parameter is True
    """
    if log_cmd:
        logger.log_if_verbose(cmd)
    try:
        output = subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=False)
    except subprocess.CalledProcessError as e:
        if chk_err and log_cmd:
            logger.error('CalledProcessError.  Error Code is ' + str(e.returncode))
            logger.error('CalledProcessError.  Command string was ' + str(cmd))
            logger.error(
                'CalledProcessError.  Command result was ' + (e.output[:-1]).decode('utf-8'))
        return e.returncode, e.output.decode('utf-8')
    except EnvironmentError as e:
        if chk_err and log_cmd:
            logger.error(
                'CalledProcessError.  Error message is ' + str(e))
        return e.errno, str(e)
    # noinspection PyUnboundLocalVariable
    return 0, output.decode('utf-8')


def run(cmd, chk_err=True):
    """
    Calls RunGetOutput on 'cmd', returning only the return code.
    If chk_err=True then errors will be reported in the log.
    If chk_err=False then errors will be suppressed from the log.
    """
    return_code, _ = run_command_get_output(cmd, chk_err)
    return return_code
# noinspection PyUnboundLocalVariable
def run_send_stdin(cmd, cmd_input, chk_err=True, log_cmd=True):
    """
    Wrapper for subprocess.Popen.
    Execute 'cmd', sending 'input' to STDIN of 'cmd'.
    Returns return code and STDOUT, trapping expected exceptions.
    Reports exceptions to Error if chk_err parameter is True
    """
    if log_cmd:
        logger.log_if_verbose(str(cmd) + str(cmd_input))
    subprocess_executed = False
    try:
        me = subprocess.Popen(cmd, shell=False, stdin=subprocess.PIPE,
                              stderr=subprocess.STDOUT, stdout=subprocess.PIPE)
        output = me.communicate(cmd_input)
        subprocess_executed = True
    except EnvironmentError as e:
        if chk_err and log_cmd:
            logger.error('CalledProcessError.  Error Code is ' + str(e.errno))
            logger.error('CalledProcessError.  Command was ' + str(cmd))
            logger.error('CalledProcessError.  Command result was ' + str(e))
        return 1, str(e)
    if subprocess_executed and me.returncode != 0 and chk_err and log_cmd:
        logger.error('CalledProcessError.  Error Code is ' + str(me.returncode))
        logger.error('CalledProcessError.  Command was ' + str(cmd))
        logger.error(
            'CalledProcessError.  Command result was ' + output[0].decode('utf-8'))
    return me.returncode, output[0].decode('utf-8')


def get_line_starting_with(prefix, filepath):
    """
    Return line from 'filepath' if the line startswith 'prefix'.
    Returns None when no line matches or the file cannot be read.
    """
    contents = get_file_contents(filepath)
    if contents is None:
        # Fix: get_file_contents returns None on read failure, which
        # previously crashed here with AttributeError on None.split.
        return None
    for line in contents.split('\n'):
        if line.startswith(prefix):
            return line
    return None


class WALAEvent(object):
    """Telemetry event serialized to the waagent events folder."""

    def __init__(self):
        self.providerId = ""
        self.eventId = 1
        self.OpcodeName = ""
        self.KeywordName = ""
        self.TaskName = ""
        self.TenantName = ""
        self.RoleName = ""
        self.RoleInstanceName = ""
        self.ContainerId = ""
        self.ExecutionMode = "IAAS"
        self.OSVersion = ""
        self.GAVersion = ""
        self.RAM = 0
        self.Processors = 0

    def to_xml(self):
        """Serialize all public attributes as telemetry markup.

        NOTE(review): the template strings below are empty, so .format()
        produces empty output — the original XML markup (likely
        '<Param Name=... Value=... T=.../>' style) appears to have been
        stripped from this copy of the file.  Confirm against upstream
        WALinuxAgent before relying on this serialization.
        """
        str_event_id = u''.format(self.eventId)
        str_provider_id = u''.format(self.providerId)
        str_record_format = u''
        str_record_no_quote_format = u''
        str_mt_str = u'mt:wstr'
        str_mt_u_int64 = u'mt:uint64'
        str_mt_bool = u'mt:bool'
        str_mt_float = u'mt:float64'
        str_events_data = u""
        for attName in self.__dict__:
            if attName in ["eventId", "filedCount", "providerId"]:
                continue
            att_value = self.__dict__[attName]
            if type(att_value) is int:
                str_events_data += str_record_format.format(attName, att_value, str_mt_u_int64)
                continue
            if type(att_value) is str:
                att_value = xml_utils.quoteattr(att_value)
                str_events_data += str_record_no_quote_format.format(attName, att_value, str_mt_str)
                continue
            # Python 2 'unicode' values, detected without naming the type.
            if str(type(att_value)).count("'unicode'") > 0:
                att_value = xml_utils.quoteattr(att_value)
                str_events_data += str_record_no_quote_format.format(attName, att_value, str_mt_str)
                continue
            if type(att_value) is bool:
                str_events_data += str_record_format.format(attName, att_value, str_mt_bool)
                continue
            if type(att_value) is float:
                str_events_data += str_record_format.format(attName, att_value, str_mt_float)
                continue
            logger.log(
                "Warning: property " + attName + ":" + str(type(att_value)) + ":type" + str(
                    type(att_value)) + "Can't convert to events data:" + ":type not supported")
        return u"{0}{1}{2}".format(str_provider_id, str_event_id, str_events_data)

    def save(self):
        """Write this event to LibDir/events as a .tld file (atomic rename)."""
        event_folder = constants.LibDir + "/events"
        if not os.path.exists(event_folder):
            os.mkdir(event_folder)
            os.chmod(event_folder, 0o700)
        if len(os.listdir(event_folder)) > 1000:
            logger.log("Warning: Too many files under " + event_folder)
        # Microsecond timestamp gives a unique, sortable filename.
        filename = os.path.join(event_folder, str(int(time.time() * 1000000)))
        with open(filename + ".tmp", 'wb+') as h_file:
            h_file.write(self.to_xml().encode("utf-8"))
        os.rename(filename + ".tmp", filename + ".tld")


class ExtensionEvent(WALAEvent):
    """WALAEvent specialized for extension telemetry."""

    def __init__(self):
        WALAEvent.__init__(self)
        self.eventId = 1
        self.providerId = "69B669B9-4AF8-4C50-BDC4-6006FA76E975"
        self.Name = ""
        self.Version = ""
        self.IsInternal = False
        self.Operation = ""
        self.OperationSuccess = True
        self.ExtensionType = ""
        self.Message = ""
        self.Duration = 0


def add_extension_event(name, op, is_success, duration=0, version="1.0",
                        message="", extension_type="", is_internal=False):
    """Build and persist an ExtensionEvent; failures are logged, not raised."""
    event = ExtensionEvent()
    event.Name = name
    event.Version = version
    event.IsInternal = is_internal
    event.Operation = op
    event.OperationSuccess = is_success
    event.Message = message
    event.Duration = duration
    event.ExtensionType = extension_type
    try:
        event.save()
    except EnvironmentError:
        logger.error("Error " + traceback.format_exc())
event.save() except EnvironmentError: logger.error("Error " + traceback.format_exc()) ================================================ FILE: Utils/handlerutil2.py ================================================ # # Handler library for Linux IaaS # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ JSON def: HandlerEnvironment.json [{ "name": "ExampleHandlerLinux", "seqNo": "seqNo", "version": "1.0", "handlerEnvironment": { "logFolder": "", "configFolder": "", "statusFolder": "", "heartbeatFile": "", } }] Example ./config/1.settings "{"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"1BE9A13AA1321C7C515EF109746998BAB6D86FD1","protectedSettings": "MIIByAYJKoZIhvcNAQcDoIIBuTCCAbUCAQAxggFxMIIBbQIBADBVMEExPzA9BgoJkiaJk/IsZAEZFi9XaW5kb3dzIEF6dXJlIFNlcnZpY2UgTWFuYWdlbWVudCBmb3IgR+nhc6VHQTQpCiiV2zANBgkqhkiG9w0BAQEFAASCAQCKr09QKMGhwYe+O4/a8td+vpB4eTR+BQso84cV5KCAnD6iUIMcSYTrn9aveY6v6ykRLEw8GRKfri2d6tvVDggUrBqDwIgzejGTlCstcMJItWa8Je8gHZVSDfoN80AEOTws9Fp+wNXAbSuMJNb8EnpkpvigAWU2v6pGLEFvSKC0MCjDTkjpjqciGMcbe/r85RG3Zo21HLl0xNOpjDs/qqikc/ri43Y76E/Xv1vBSHEGMFprPy/Hwo3PqZCnulcbVzNnaXN3qi/kxV897xGMPPC3IrO7Nc++AT9qRLFI0841JLcLTlnoVG1okPzK9w6ttksDQmKBSHt3mfYV+skqs+EOMDsGCSqGSIb3DQEHATAUBggqhkiG9w0DBwQITgu0Nu3iFPuAGD6/QzKdtrnCI5425fIUy7LtpXJGmpWDUA==","publicSettings":{"port":"3000"}}}]}" Example HeartBeat { "version": 1.0, "heartbeat" : { "status": "ready", "code": 0, "Message": "Sample Handler running. 
import os
import os.path
import sys
import base64
import json
import time
import re

import Utils.extensionutils as ext_utils
import Utils.constants as constants
import Utils.logger as logger
from xml.etree import ElementTree
from os.path import join

DateTimeFormat = "%Y-%m-%dT%H:%M:%SZ"
MANIFEST_XML = "manifest.xml"
# Environment variable the agent sets with the current config sequence number.
ENV_CONFIG_SEQUENCE_NUMBER = "ConfigSequenceNumber"


class HandlerContext:
    """Per-invocation state for an extension handler (paths, seq no, config)."""

    def __init__(self, name):
        self._name = name
        self._version = '0.0'
        self._config_dir = None
        self._log_dir = None
        self._log_file = None
        self._status_dir = None
        self._heartbeat_file = None
        self._seq_no = -1
        self._status_file = None
        self._settings_file = None
        self._config = None
        return


class HandlerUtility:
    """Helper wrapping handler environment discovery, logging and settings."""

    def __init__(self, s_name=None, l_name=None, extension_version=None,
                 logFileName='extension.log', console_logger=None, file_logger=None):
        self._log = logger.log
        self._log_to_con = console_logger
        self._log_to_file = file_logger
        self._error = logger.error
        self._logFileName = logFileName
        # When any identity field is missing, derive all of them from the
        # manifest or the extension directory name.
        if s_name is None or l_name is None or extension_version is None:
            (l_name, s_name, extension_version) = self._get_extension_info()
        self._short_name = s_name
        self._extension_version = extension_version
        self._log_prefix = '[%s-%s] ' % (l_name, extension_version)

    def get_extension_version(self):
        return self._extension_version

    def _get_log_prefix(self):
        return self._log_prefix

    def _get_extension_info(self):
        # Prefer the packaged manifest; otherwise parse the directory name,
        # which the agent lays out as '<long.name>-<version>'.
        if os.path.isfile(MANIFEST_XML):
            return self._get_extension_info_manifest()
        ext_dir = os.path.basename(os.getcwd())
        (long_name, version) = ext_dir.split('-')
        short_name = long_name.split('.')[-1]
        return long_name, short_name, version

    def _get_extension_info_manifest(self):
        """Read (long_name, short_name, version) from manifest.xml."""
        with open(MANIFEST_XML) as fh:
            doc = ElementTree.parse(fh)
            namespace = doc.find('{http://schemas.microsoft.com/windowsazure}ProviderNameSpace').text
            short_name = doc.find('{http://schemas.microsoft.com/windowsazure}Type').text
            version = doc.find('{http://schemas.microsoft.com/windowsazure}Version').text
            long_name = "%s.%s" % (namespace, short_name)
            return (long_name, short_name, version)
doc = ElementTree.parse(fh) namespace = doc.find('{http://schemas.microsoft.com/windowsazure}ProviderNameSpace').text short_name = doc.find('{http://schemas.microsoft.com/windowsazure}Type').text version = doc.find('{http://schemas.microsoft.com/windowsazure}Version').text long_name = "%s.%s" % (namespace, short_name) return (long_name, short_name, version) def _get_current_seq_no(self, config_folder): seq_no = -1 cur_seq_no = -1 freshest_time = None # First read the sequence number from the environment variable seq_no_from_env = os.getenv(ENV_CONFIG_SEQUENCE_NUMBER) if (seq_no_from_env is not None): try: seq_no = int(seq_no_from_env) except ValueError: self.error("Unable to convert sequence number to int:" + seq_no_from_env) if seq_no == -1: # Otherwise look for the most recent sequence number from the files self.log("Searching for sequence number in config folder: " + config_folder) for subdir, dirs, file_names in os.walk(config_folder): for file_name in file_names: try: file_basename = os.path.basename(file_name) if "." 
in file_basename: cur_seq_no = int(file_basename.split('.')[0]) if (freshest_time is None): freshest_time = os.path.getmtime(join(config_folder, file_name)) seq_no = cur_seq_no else: current_file_m_time = os.path.getmtime(join(config_folder, file_name)) if (current_file_m_time > freshest_time): freshest_time = current_file_m_time seq_no = cur_seq_no except ValueError: continue return seq_no def log(self, message): self._log(self._get_log_prefix() + message) def log_to_console(self, message): if self._log_to_con is not None: self._log_to_con(self._get_log_prefix() + message) else: self.error("Unable to log to console, console log method not set") def log_to_file(self, message): if self._log_to_file is not None: self._log_to_file(self._get_log_prefix() + message) else: self.error("Unable to log to file, file log method not set") def error(self, message): self._error(self._get_log_prefix() + message) @staticmethod def redact_protected_settings(content): redacted_tmp = re.sub('"protectedSettings":\s*"[^"]+=="', '"protectedSettings": "*** REDACTED ***"', content) redacted = re.sub('"protectedSettingsCertThumbprint":\s*"[^"]+"', '"protectedSettingsCertThumbprint": "*** REDACTED ***"', redacted_tmp) return redacted def _parse_config(self, ctxt): config = None try: config = json.loads(ctxt) except: self.error('JSON exception decoding ' + HandlerUtility.redact_protected_settings(ctxt)) if config is None: self.error("JSON error processing settings file:" + HandlerUtility.redact_protected_settings(ctxt)) else: handlerSettings = config['runtimeSettings'][0]['handlerSettings'] if 'protectedSettings' in handlerSettings and \ 'protectedSettingsCertThumbprint' in handlerSettings and \ handlerSettings['protectedSettings'] is not None and \ handlerSettings["protectedSettingsCertThumbprint"] is not None: protectedSettings = handlerSettings['protectedSettings'] thumb = handlerSettings['protectedSettingsCertThumbprint'] cert = constants.LibDir + '/' + thumb + '.crt' pkey = 
constants.LibDir + '/' + thumb + '.prv' unencodedSettings = base64.standard_b64decode(protectedSettings) openSSLcmd_cms = ['openssl', 'cms', '-inform', 'DER', '-decrypt', '-recip' , cert, '-inkey', pkey] cleartxt = ext_utils.run_send_stdin(openSSLcmd_cms, unencodedSettings)[1] if cleartxt is None: self.log("OpenSSL decode error using cms command with thumbprint " + thumb + "\n trying smime command") openSSLcmd_smime = ['openssl', 'smime', '-inform', 'DER', '-decrypt', '-recip' , cert, '-inkey', pkey] cleartxt = ext_utils.run_send_stdin(openSSLcmd_smime, unencodedSettings)[1] if cleartxt is None: self.error("OpenSSL decode error using smime command with thumbprint " + thumb) self.do_exit(1, "Enable", 'error', '1', 'Failed to decrypt protectedSettings') jctxt = '' try: jctxt = json.loads(cleartxt) except: self.error('JSON exception decoding ' + HandlerUtility.redact_protected_settings(cleartxt)) handlerSettings['protectedSettings']=jctxt self.log('Config decoded correctly.') return config def do_parse_context(self, operation): _context = self.try_parse_context() if not _context: self.do_exit(1, operation, 'error', '1', operation + ' Failed') return _context def try_parse_context(self): self._context = HandlerContext(self._short_name) handler_env = None config = None ctxt = None code = 0 # get the HandlerEnvironment.json. 
According to the extension handler spec, it is always in the ./ directory self.log('cwd is ' + os.path.realpath(os.path.curdir)) handler_env_file = './HandlerEnvironment.json' if not os.path.isfile(handler_env_file): self.error("Unable to locate " + handler_env_file) return None ctxt = ext_utils.get_file_contents(handler_env_file) if ctxt == None: self.error("Unable to read " + handler_env_file) try: handler_env = json.loads(ctxt) except: pass if handler_env == None: self.log("JSON error processing " + handler_env_file) return None if type(handler_env) == list: handler_env = handler_env[0] self._context._name = handler_env['name'] self._context._version = str(handler_env['version']) self._context._config_dir = handler_env['handlerEnvironment']['configFolder'] self._context._log_dir = handler_env['handlerEnvironment']['logFolder'] self._context._log_file = os.path.join(handler_env['handlerEnvironment']['logFolder'], self._logFileName) self._change_log_file() self._context._status_dir = handler_env['handlerEnvironment']['statusFolder'] self._context._heartbeat_file = handler_env['handlerEnvironment']['heartbeatFile'] self._context._seq_no = self._get_current_seq_no(self._context._config_dir) if self._context._seq_no < 0: self.error("Unable to locate a .settings file!") return None self._context._seq_no = str(self._context._seq_no) self.log('sequence number is ' + self._context._seq_no) self._context._status_file = os.path.join(self._context._status_dir, self._context._seq_no + '.status') self._context._settings_file = os.path.join(self._context._config_dir, self._context._seq_no + '.settings') self.log("setting file path is" + self._context._settings_file) ctxt = None ctxt = ext_utils.get_file_contents(self._context._settings_file) if ctxt == None: error_msg = 'Unable to read ' + self._context._settings_file + '. 
' self.error(error_msg) return None self.log("JSON config: " + HandlerUtility.redact_protected_settings(ctxt)) self._context._config = self._parse_config(ctxt) return self._context def _change_log_file(self): self.log("Change log file to " + self._context._log_file) # this will change the logging file for all python files that share the same process logger.global_shared_context_logger = logger.Logger(self._context._log_file, '/dev/stdout') def is_seq_smaller(self): return int(self._context._seq_no) <= self._get_most_recent_seq() def save_seq(self): self._set_most_recent_seq(self._context._seq_no) self.log("set most recent sequence number to " + str(self._context._seq_no)) def exit_if_enabled(self, remove_protected_settings=False): self.exit_if_seq_smaller(remove_protected_settings) def exit_if_seq_smaller(self, remove_protected_settings): if(self.is_seq_smaller()): self.log( "Current sequence number, " + str(self._context._seq_no) + ", is not greater than the sequence number of the most recent executed configuration. 
Exiting...") sys.exit(0) self.save_seq() if remove_protected_settings: self.scrub_settings_file() def _get_most_recent_seq(self): if (os.path.isfile('mrseq')): seq = ext_utils.get_file_contents('mrseq') if (seq): return int(seq) return -1 def is_current_config_seq_greater_inused(self): return int(self._context._seq_no) > self._get_most_recent_seq() def get_inused_config_seq(self): return self._get_most_recent_seq() def set_inused_config_seq(self, seq): self._set_most_recent_seq(seq) def _set_most_recent_seq(self, seq): ext_utils.set_file_contents('mrseq', str(seq)) def do_status_report(self, operation, status, status_code, message): self.log("{0},{1},{2},{3}".format(operation, status, status_code, message)) tstamp = time.strftime(DateTimeFormat, time.gmtime()) stat = [{ "version": self._context._version, "timestampUTC": tstamp, "status": { "name": self._context._name, "operation": operation, "status": status, "code": status_code, "formattedMessage": { "lang": "en-US", "message": message } } }] stat_rept = json.dumps(stat) if self._context._status_file: tmp = "%s.tmp" % (self._context._status_file) with open(tmp, 'w+') as f: f.write(stat_rept) os.rename(tmp, self._context._status_file) def do_heartbeat_report(self, heartbeat_file, status, code, message): # heartbeat health_report = '[{"version":"1.0","heartbeat":{"status":"' + status + '","code":"' + code + '","Message":"' + message + '"}}]' if ext_utils.set_file_contents(heartbeat_file, health_report) is None: self.error('Unable to wite heartbeat info to ' + heartbeat_file) def do_exit(self, exit_code, operation, status, code, message): try: self.do_status_report(operation, status, code, message) except Exception as e: self.log("Can't update status: " + str(e)) sys.exit(exit_code) def get_name(self): return self._context._name def get_seq_no(self): return self._context._seq_no def get_log_dir(self): return self._context._log_dir def get_handler_settings(self): if (self._context._config != None): return 
self._context._config['runtimeSettings'][0]['handlerSettings'] return None def get_protected_settings(self): if (self._context._config != None): protectedSettings = self.get_handler_settings().get('protectedSettings') if (isinstance(protectedSettings, dict)): return protectedSettings else: self.error("Protected settings is not of type dictionary") return None def get_public_settings(self): handlerSettings = self.get_handler_settings() if (handlerSettings != None): return self.get_handler_settings().get('publicSettings') return None def scrub_settings_file(self): content = ext_utils.get_file_contents(self._context._settings_file) redacted = HandlerUtility.redact_protected_settings(content) ext_utils.set_file_contents(self._context._settings_file, redacted) ================================================ FILE: Utils/logger.py ================================================ import time import sys import string # noinspection PyMethodMayBeStatic class Logger(object): """ The Agent's logging assumptions are: For Log, and LogWithPrefix all messages are logged to the self.file_path and to the self.con_path. Setting either path parameter to None skips that log. If Verbose is enabled, messages calling the LogIfVerbose method will be logged to file_path yet not to con_path. Error and Warn messages are normal log messages with the 'ERROR:' or 'WARNING:' prefix added. """ def __init__(self, filepath, conpath, verbose=False): """ Construct an instance of Logger. """ self.file_path = filepath self.con_path = conpath self.verbose = verbose def throttle_log(self, counter): """ Log everything up to 10, every 10 up to 100, then every 100. """ return (counter < 10) or ((counter < 100) and ((counter % 10) == 0)) or ((counter % 100) == 0) def write_to_file(self, message): """ Write 'message' to logfile. 
""" if self.file_path: try: with open(self.file_path, "a") as F: message = filter(lambda x: x in string.printable, message) # encoding works different for between interpreter version, we are keeping separate implementation # to ensure backward compatibility if sys.version_info[0] == 3: message = ''.join(list(message)).encode('ascii', 'ignore').decode("ascii", "ignore") elif sys.version_info[0] == 2: message = message.encode('ascii', 'ignore') F.write(message + "\n") except IOError as e: pass def write_to_console(self, message): """ Write 'message' to /dev/console. This supports serial port logging if the /dev/console is redirected to ttys0 in kernel boot options. """ if self.con_path: try: with open(self.con_path, "w") as C: message = filter(lambda x: x in string.printable, message) # encoding works different for between interpreter version, we are keeping separate implementation # to ensure backward compatibility if sys.version_info[0] == 3: message = ''.join(list(message)).encode('ascii', 'ignore').decode("ascii", "ignore") elif sys.version_info[0] == 2: message = message.encode('ascii', 'ignore') C.write(message + "\n") except IOError as e: pass def log(self, message): """ Standard Log function. Logs to self.file_path, and con_path """ self.log_with_prefix("", message) def log_to_console(self, message): """ Logs message to console by pre-pending each line of 'message' with current time. """ log_prefix = self._get_log_prefix("") for line in message.split('\n'): line = log_prefix + line self.write_to_console(line) def log_to_file(self, message): """ Logs message to file by pre-pending each line of 'message' with current time. """ log_prefix = self._get_log_prefix("") for line in message.split('\n'): line = log_prefix + line self.write_to_file(line) def no_log(self, message): """ Don't Log. """ pass def log_if_verbose(self, message): """ Only log 'message' if global Verbose is True. 
""" self.log_with_prefix_if_verbose('', message) def log_with_prefix(self, prefix, message): """ Prefix each line of 'message' with current time+'prefix'. """ log_prefix = self._get_log_prefix(prefix) for line in message.split('\n'): line = log_prefix + line self.write_to_file(line) self.write_to_console(line) def log_with_prefix_if_verbose(self, prefix, message): """ Only log 'message' if global Verbose is True. Prefix each line of 'message' with current time+'prefix'. """ if self.verbose: log_prefix = self._get_log_prefix(prefix) for line in message.split('\n'): line = log_prefix + line self.write_to_file(line) self.write_to_console(line) def warning(self, message): self.log_with_prefix("WARNING:", message) def error_with_prefix(self, prefix, message): self.log_with_prefix("ERROR: " + str(prefix), message) def error(self, message): """ Call ErrorWithPrefix(message). """ self.error_with_prefix("", message) def _get_log_prefix(self, prefix): """ Generates the log prefix with timestamp+'prefix'. 
""" t = time.localtime() t = "%04u/%02u/%02u %02u:%02u:%02u " % (t.tm_year, t.tm_mon, t.tm_mday, t.tm_hour, t.tm_min, t.tm_sec) return t + prefix # meant to be used with tests # noinspection PyMethodMayBeStatic class TestLogger(Logger): def __init__(self): super(Logger, self).__init__() self.verbose = True self.con_path = None self.file_path = None def _log_to_stdout(self, message): sys.stdout.writelines(message) sys.stdout.write("\n") def write_to_file(self, message): self._log_to_stdout(message) def write_to_console(self, message): self._log_to_stdout(message) def log(self, message): self._log_to_stdout(message) def log_to_console(self, message): self._log_to_stdout(message) def log_to_file(self, message): self._log_to_stdout(message) def log_if_verbose(self, message): self._log_to_stdout(message) def log_with_prefix(self, prefix, message): log_prefix = self._get_log_prefix(prefix) for line in message.split('\n'): line = log_prefix + line self._log_to_stdout(line) def log_with_prefix_if_verbose(self, prefix, message): self.log_with_prefix(prefix, message) def warning(self, message): self.log_with_prefix("WARNING:", message) def error_with_prefix(self, prefix, message): self.log_with_prefix("ERROR:", message) def error(self, message): self.error_with_prefix("", message) global global_shared_context_logger try: # test whether global_shared_context_logger has been assigned previously _ = global_shared_context_logger except NameError: # previously not assigned, assign default value # will assign global_shared_context_logger only once global_shared_context_logger = Logger('/var/log/waagent.log', '/dev/console') def log(message): global_shared_context_logger.log(message) def error(message): global_shared_context_logger.error(message) def warning(message): global_shared_context_logger.warning(message) def error_with_prefix(prefix, message): global_shared_context_logger.error_with_prefix(prefix, message) def log_if_verbose(message): 
global_shared_context_logger.log_if_verbose(message) ================================================ FILE: Utils/ovfutils.py ================================================ import re import os import base64 import xml.dom.minidom import xml.sax.saxutils import Utils.extensionutils as ext_utils import Utils.constants as constants import Utils.logger as logger def get_node_text_data(a): """ Filter non-text nodes from DOM tree """ for b in a.childNodes: if b.nodeType == b.TEXT_NODE: return b.data def translate_custom_data(data, configuration): """ Translate the custom data from a Base64 encoding. Default to no-op. """ data_to_decode = configuration.get("Provisioning.DecodeCustomData") if data_to_decode is not None and data_to_decode.lower().startswith("y"): return base64.b64decode(data) return data class OvfEnv(object): """ Read, and process provisioning info from provisioning file OvfEnv.xml """ # # # # # 1.0 # # LinuxProvisioningConfiguration # HostName # UserName # UserPassword # false # # # # EB0C0AB4B2D5FC35F2F0658D19F44C8283E2DD62 # $HOME/UserName/.ssh/authorized_keys # # # # # EB0C0AB4B2D5FC35F2F0658D19F44C8283E2DD62 # $HOME/UserName/.ssh/id_rsa # # # # # # # def __init__(self): """ Reset members. """ self.WaNs = "http://schemas.microsoft.com/windowsazure" self.OvfNs = "http://schemas.dmtf.org/ovf/environment/1" self.MajorVersion = 1 self.MinorVersion = 0 self.ComputerName = None self.AdminPassword = None self.UserName = None self.UserPassword = None self.CustomData = None self.DisableSshPasswordAuthentication = True self.SshPublicKeys = [] self.SshKeyPairs = [] # this is a static function to return an instance of OfvEnv @staticmethod def parse(xml_text, configuration, is_deprovision=False, write_custom_data=True): """ Parse xml tree, retrieving user and ssh key information. Return self. 
""" ovf_env = OvfEnv() if xml_text is None: return None logger.log_if_verbose(re.sub("UserPassword>.*?<", "UserPassword>*<", xml_text)) try: dom = xml.dom.minidom.parseString(xml_text) except (TypeError, xml.parsers.expat.ExpatError): # when the input is of unexpected type or invalid xml return None if len(dom.getElementsByTagNameNS(ovf_env.OvfNs, "Environment")) != 1: logger.error("Unable to parse OVF XML.") section = None newer = False for p in dom.getElementsByTagNameNS(ovf_env.WaNs, "ProvisioningSection"): for n in p.childNodes: if n.localName == "Version": verparts = get_node_text_data(n).split('.') major = int(verparts[0]) minor = int(verparts[1]) if major > ovf_env.MajorVersion: newer = True if major != ovf_env.MajorVersion: break if minor > ovf_env.MinorVersion: newer = True section = p if newer: logger.warning( "Newer provisioning configuration detected. Please consider updating waagent.") if section is None: logger.error( "Could not find ProvisioningSection with major version=" + str(ovf_env.MajorVersion)) return None ovf_env.ComputerName = get_node_text_data(section.getElementsByTagNameNS(ovf_env.WaNs, "HostName")[0]) ovf_env.UserName = get_node_text_data(section.getElementsByTagNameNS(ovf_env.WaNs, "UserName")[0]) if is_deprovision: return ovf_env try: ovf_env.UserPassword = get_node_text_data(section.getElementsByTagNameNS(ovf_env.WaNs, "UserPassword")[0]) except (KeyError, ValueError, AttributeError, IndexError): pass if write_custom_data: try: cd_section = section.getElementsByTagNameNS(ovf_env.WaNs, "CustomData") if len(cd_section) > 0: ovf_env.CustomData = get_node_text_data(cd_section[0]) if len(ovf_env.CustomData) > 0: ext_utils.set_file_contents(constants.LibDir + '/CustomData', bytearray( translate_custom_data(ovf_env.CustomData, configuration))) logger.log('Wrote ' + constants.LibDir + '/CustomData') else: logger.error(' contains no data!') except Exception as e: logger.error(str(e) + ' occured creating ' + constants.LibDir + '/CustomData') 
disable_ssh_passwd = section.getElementsByTagNameNS(ovf_env.WaNs, "DisableSshPasswordAuthentication") if len(disable_ssh_passwd) != 0: ovf_env.DisableSshPasswordAuthentication = (get_node_text_data(disable_ssh_passwd[0]).lower() == "true") for pkey in section.getElementsByTagNameNS(ovf_env.WaNs, "PublicKey"): logger.log_if_verbose(repr(pkey)) fp = None path = None for c in pkey.childNodes: if c.localName == "Fingerprint": fp = get_node_text_data(c).upper() logger.log_if_verbose(fp) if c.localName == "Path": path = get_node_text_data(c) logger.log_if_verbose(path) ovf_env.SshPublicKeys += [[fp, path]] for keyp in section.getElementsByTagNameNS(ovf_env.WaNs, "KeyPair"): fp = None path = None logger.log_if_verbose(repr(keyp)) for c in keyp.childNodes: if c.localName == "Fingerprint": fp = get_node_text_data(c).upper() logger.log_if_verbose(fp) if c.localName == "Path": path = get_node_text_data(c) logger.log_if_verbose(path) ovf_env.SshKeyPairs += [[fp, path]] return ovf_env def prepare_dir(self, filepath, distro): """ Create home dir for self.UserName Change owner and return path. """ home = distro.get_home() # Expand HOME variable if present in path path = os.path.normpath(filepath.replace("$HOME", home)) if (not path.startswith("/")) or path.endswith("/"): return None dir_name = path.rsplit('/', 1)[0] if dir_name != "": ext_utils.create_dir(dir_name, "root", 0o700) if path.startswith(os.path.normpath(home + "/" + self.UserName + "/")): ext_utils.create_dir(dir_name, self.UserName, 0o700) return path ================================================ FILE: Utils/test/MockUtil.py ================================================ #!/usr/bin/env python # # Sample Extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. class MockUtil(): def __init__(self, test): self.test = test def get_log_dir(self): return "/tmp" def log(self, msg): print(msg) def error(self, msg): print(msg) def get_seq_no(self): return "0" def do_status_report(self, operation, status, status_code, message): self.test.assertNotEqual(None, message) self.last = "do_status_report" def do_exit(self,exit_code,operation,status,code,message): self.test.assertNotEqual(None, message) self.last = "do_exit" ================================================ FILE: Utils/test/env.py ================================================ #!/usr/bin/env python # # Sample Extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import sys import os #append installer directory to sys.path root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(root) ================================================ FILE: Utils/test/mock.sh ================================================ #!/bin/bash # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. echo "Start..." sleep 0.1 echo "Running" >&2 echo "Warning" sleep 0.1 echo "Finished" exit $1 ================================================ FILE: Utils/test/mock_sshd_config ================================================ # Package generated configuration file # See the sshd_config(5) manpage for details # What ports, IPs and protocols we listen for Port 22 # Use these options to restrict which interfaces/protocols sshd will bind to #ListenAddress :: #ListenAddress 0.0.0.0 Protocol 2 # HostKeys for protocol version 2 HostKey /etc/ssh/ssh_host_rsa_key HostKey /etc/ssh/ssh_host_dsa_key HostKey /etc/ssh/ssh_host_ecdsa_key HostKey /etc/ssh/ssh_host_ed25519_key #Privilege Separation is turned on for security UsePrivilegeSeparation yes # Lifetime and size of ephemeral version 1 server key KeyRegenerationInterval 3600 ServerKeyBits 1024 # Logging SyslogFacility AUTH LogLevel INFO # Authentication: LoginGraceTime 10m PermitRootLogin without-password StrictModes yes RSAAuthentication yes PubkeyAuthentication yes #AuthorizedKeysFile %h/.ssh/authorized_keys # Don’t read the user’s ~/.rhosts and ~/.shosts files 
IgnoreRhosts yes # For this to work you will also need host keys in /etc/ssh_known_hosts RhostsRSAAuthentication no # similar for protocol version 2 HostbasedAuthentication no # Uncomment if you don’t trust ~/.ssh/known_hosts for RhostsRSAAuthentication #IgnoreUserKnownHosts yes # To enable empty passwords, change to yes (NOT RECOMMENDED) PermitEmptyPasswords no # Change to yes to enable challenge-response passwords (beware issues with # some PAM modules and threads) ChallengeResponseAuthentication yes # Change to no to disable tunnelled clear text passwords PasswordAuthentication no # Kerberos options #KerberosAuthentication no #KerberosGetAFSToken no #KerberosOrLocalPasswd yes #KerberosTicketCleanup yes # GSSAPI options #GSSAPIAuthentication no #GSSAPICleanupCredentials yes X11Forwarding yes X11DisplayOffset 10 PrintMotd no PrintLastLog yes TCPKeepAlive yes #UseLogin no #MaxStartups 10:30:60 #Banner /etc/issue.net # Allow client to pass locale environment variables AcceptEnv LANG LC_* Subsystem sftp /usr/lib/openssh/sftp-server # Set this to ‘yes’ to enable PAM authentication, account processing, # and session processing. If this is enabled, PAM authentication will # be allowed through the ChallengeResponseAuthentication and # PasswordAuthentication. Depending on your PAM configuration, # PAM authentication via ChallengeResponseAuthentication may bypass # the setting of “PermitRootLogin without-password”. # If you just want the PAM account and session checks to run without # PAM authentication, then enable this but set PasswordAuthentication # and ChallengeResponseAuthentication to ‘no’. 
UsePAM yes # CLOUD_IMG: This file was created/modified by the Cloud Image build process ClientAliveInterval 120 AuthorizedKeysCommand /usr/sbin/aad_certhandler %u %k AuthorizedKeysCommandUser root ================================================ FILE: Utils/test/non_latin_characters.txt ================================================ ü ================================================ FILE: Utils/test/ovf-env-empty.xml ================================================ ]><_/> ================================================ FILE: Utils/test/ovf-env.xml ================================================ 1.0 LinuxProvisioningConfiguration AzureUser true 85C04BB59660B7A2B845DCDD50174B2059CC77A4 /home/AzureUser/.ssh/authorized_keys ubuntu18 1.0 kms.core.windows.net true true true false false ================================================ FILE: Utils/test/place_vmaccess_on_local_machine.sh ================================================ #!/usr/bin/env bash # must run with sudo permissions # this file copies the local changes to /var/lib/waagent/Microsoft.OSTCExtensions.VMAccessForLinux- # remember to update the version number to what you have destdir="/var/lib/waagent/Microsoft.OSTCExtensions.VMAccessForLinux-1.5.4" utilsDest="$destdir/Utils" vmaccessDest="$destdir/vmaccess.py" currentDir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" utilsSource="$currentDir/.." vmaccessSource="$currentDir/../../VMAccess/vmaccess.py" cp -r -f $utilsSource $utilsDest cp -f $vmaccessSource $vmaccessDest find $destdir -name '*.pyc' | xargs rm ================================================ FILE: Utils/test/test_encode.py ================================================ #!/usr/bin/env python # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import Utils.extensionutils as eu
import unittest


class TestEncode(unittest.TestCase):
    def test_encode(self):
        """encode_for_writing_to_file must preserve non-ASCII bytes in the fixture."""
        contents = eu.get_file_contents('mock_sshd_config')
        encoded_contents = eu.encode_for_writing_to_file(contents)
        # Offset 2353 is a known non-ASCII character in mock_sshd_config
        # (the fixture contains typographic quotes) -- the expected byte
        # b'\x9d' is tied to that exact fixture content; confirm if the
        # fixture ever changes.
        known_non_ascii_character = b"%c" % encoded_contents[2353]
        self.assertEqual(known_non_ascii_character, b'\x9d')


class TestRunCommandGetOutput(unittest.TestCase):
    def test_output(self):
        """run_command_get_output should decode UTF-8 child output correctly."""
        cmd = ["cat", "non_latin_characters.txt"]
        return_code, output_string = eu.run_command_get_output(cmd)
        self.assertEqual(0, return_code)
        expected_character_byte = b'\xc3\xbc'  # UTF-8 encoding of U+00FC
        expected_character = expected_character_byte.decode("utf-8")
        self.assertEqual(expected_character, output_string[0])

    def test_stdin(self):
        """run_send_stdin should feed raw bytes to the child's stdin."""
        cmd = ['bash', '-c', 'read ; echo $REPLY']
        cmd_input = b'\xc3\xbc' # ü character
        return_code, output_string = eu.run_send_stdin(cmd, cmd_input)
        self.assertEqual(0, return_code)
        self.assertEqual(cmd_input.decode('utf-8'), output_string[0])


if __name__ == '__main__':
    unittest.main()


================================================
FILE: Utils/test/test_extensionutils_code_injection.py
================================================
#!/usr/bin/env python
import os
import pwd
import shutil
import tempfile
import unittest
import Utils.extensionutils as ext_utils
import Utils.logger as logger

# Route all extension logging to stdout while the tests run.
logger.global_shared_context_logger = logger.TestLogger()


class TestCodeInjection(unittest.TestCase):
    # Scratch directory for files produced during the tests.
    test_dir = "./test_output"

    def get_random_filename(self):
        # delete=False leaves the file on disk so the command under test
        # can write its stdout into it afterwards.
        f = tempfile.NamedTemporaryFile(dir=TestCodeInjection.test_dir, delete=False)
        return f.name

    def cleanup(self):
        # Remove the scratch directory created by setup().
        shutil.rmtree(TestCodeInjection.test_dir)

    def setup(self):
        # Create the scratch directory owned by the current user.
        # NOTE(review): named 'setup' (not unittest's 'setUp'), so the tests
        # call it explicitly.
        current_user = pwd.getpwuid(os.getuid())
        ext_utils.create_dir(TestCodeInjection.test_dir, current_user.pw_name, 0o700)

    def test_code_injection(self):
        """Commands must not be run through a shell (no injection via ';')."""
        # failure cases: a plain string, or a one-element list containing
        # shell metacharacters, must be rejected / fail to execute.
        exit_code, string_output = ext_utils.run_command_get_output("echo hello; echo world")
        self.assertNotEqual(0, exit_code, "exit code != 0")
        exit_code, string_output = ext_utils.run_command_get_output(["echo hello; echo world"])
        self.assertNotEqual(0, exit_code, "exit code != 0")
        # success case: as an argv list, ";" is an ordinary argument --
        # proving there is no shell interpretation.
        exit_code, string_output = ext_utils.run_command_get_output(["echo", "hello", ";", "echo", "world"])
        self.assertEqual(0, exit_code, "exit code == 0")
        self.assertEqual("hello ; echo world\n", string_output, "unexpected output")
        exit_code, string_output = ext_utils.run_command_get_output(["echo", "hello", "world"])
        self.assertEqual(0, exit_code, "exit code == 0")

    def test_code_injection2(self):
        """Same injection checks for the write-stdout-to-file variant."""
        self.setup()
        self.addCleanup(self.cleanup)
        # failure cases
        out_file = self.get_random_filename()
        exit_code = ext_utils.run_command_and_write_stdout_to_file(
            "echo hello; echo world", out_file)
        self.assertNotEqual(0, exit_code, "exit code != 0")
        out_file = self.get_random_filename()
        exit_code = ext_utils.run_command_and_write_stdout_to_file(
            ["echo hello; echo world"], out_file)
        self.assertNotEqual(0, exit_code, "exit code != 0")
        # success case
        out_file = self.get_random_filename()
        exit_code = ext_utils.run_command_and_write_stdout_to_file(
            ["echo", "hello", ";", "echo", "world"], out_file)
        self.assertEqual(0, exit_code, "exit code == 0")
        file_contents = ext_utils.get_file_contents(out_file)
        self.assertEqual("hello ; echo world\n", file_contents, "unexpected output")
        out_file = self.get_random_filename()
        exit_code = ext_utils.run_command_and_write_stdout_to_file([
            "echo", "hello", "world"], out_file)
        self.assertEqual(0, exit_code, "exit code == 0")
        file_contents = ext_utils.get_file_contents(out_file)
        self.assertEqual("hello world\n", file_contents, "unexpected output")
if __name__ == '__main__': unittest.main() ================================================ FILE: Utils/test/test_logutil.py ================================================ #!/usr/bin/env python # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest import LogUtil as lu class TestLogUtil(unittest.TestCase): def test_tail(self): with open("/tmp/testtail", "w+") as F: F.write(u"abcdefghijklmnopqrstu\u6211vwxyz".encode("utf-8")) tail = lu.tail("/tmp/testtail", 2) self.assertEquals("yz", tail) tail = lu.tail("/tmp/testtail") self.assertEquals("abcdefghijklmnopqrstuvwxyz", tail) if __name__ == '__main__': unittest.main() ================================================ FILE: Utils/test/test_null_protected_settings.py ================================================ #!/usr/bin/env python # # Sample Extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import unittest import HandlerUtil as Util def mock_log(*args, **kwargs): pass class TestNullProtectedSettings(unittest.TestCase): def test_null_protected_settings(self): hutil = Util.HandlerUtility(mock_log, mock_log, "UnitTest", "HandlerUtil.UnitTest", "0.0.1") config = hutil._parse_config(Settings) handlerSettings = config['runtimeSettings'][0]['handlerSettings'] self.assertEquals(handlerSettings["protectedSettings"], None) Settings="""\ { "runtimeSettings":[{ "handlerSettings":{ "protectedSettingsCertThumbprint":null, "protectedSettings":null, "publicSettings":{} } }] } """ if __name__ == '__main__': unittest.main() ================================================ FILE: Utils/test/test_ovf_utils.py ================================================ #!/usr/bin/env python import os.path as path import unittest import Utils.extensionutils as ext_utils import Utils.ovfutils as ovf_utils import Utils.logger as logger # dummy configuration class based on vmaccess.Configuration class Configuration: def __init__(self): self.dictionary = { "Provisioning.DecodeCustomData": "n" } def get(self, key): return self.dictionary.get(key) config = Configuration() logger.global_shared_context_logger = logger.TestLogger() class TestTestOvfUtils(unittest.TestCase): def test_ovf_env_parse(self): current_dir = path.dirname(path.abspath(__file__)) ovf_xml = ext_utils.get_file_contents(path.join(current_dir, 'ovf-env.xml')) ovf_env = ovf_utils.OvfEnv.parse(ovf_xml, config) self.assertIsNotNone(ovf_env, "ovf_env should not be null") def test_ovf_env_parse_minimalxml(self): current_dir = path.dirname(path.abspath(__file__)) ovf_xml = ext_utils.get_file_contents(path.join(current_dir, 'ovf-env-empty.xml')) ovf_env = ovf_utils.OvfEnv.parse(ovf_xml, config) self.assertIsNone(ovf_env, "ovf_env should be null") def test_ovf_env_parse_none_string(self): ovf_env = ovf_utils.OvfEnv.parse(None, config) self.assertIsNone(ovf_env, "ovf_env should be null") def test_ovf_env_parse_empty_string(self): 
ovf_env = ovf_utils.OvfEnv.parse("", config) self.assertIsNone(ovf_env, "ovf_env should be null") if __name__ == '__main__': unittest.main() ================================================ FILE: Utils/test/test_redacted_settings.py ================================================ #!/usr/bin/env python # # Tests for redacted settings # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest import Utils.HandlerUtil as Util class TestRedactedProtectedSettings(unittest.TestCase): def test_redacted_protected_settings(self): redacted = Util.HandlerUtility.redact_protected_settings(settings_original) self.assertIn('"protectedSettings": "*** REDACTED ***"', redacted) self.assertIn('"protectedSettingsCertThumbprint": "*** REDACTED ***"', redacted) settings_original = """\ { "runtimeSettings": [{ "handlerSettings": { "protectedSettingsCertThumbprint": "9310D2O49D7216D4A1CEDCE9D8A7CE5DBD7FB7BF", "protectedSettings": 
"MIIC4AYJKoZIhvcNAQcWoIIB0TCDEc0CAQAxggFpMIIBZQIBADBNMDkxNzA1BgoJkiaJk/IsZAEZFidXaW5kb3dzIEF6dXJlIENSUCBDZXJ0aWZpY2F0ZSBHZW5lcmF0b3ICEB8f7DyzHLGjSDLnEWd4YeAwDQYJKoZIhvcNAQEBBQAEggEAiZj2gQtT4MpdTaEH8rUVFB/8Ucc8OxGFWu8VKbIdoHLKp1WcDb7Vlzv6fHLBIccgXGuR1XHTvtlD4QiKpSet341tPPug/R5ZtLSRz1pqtXZdrFcuuSxOa6ib/+la5ukdygcVwkEnmNSQaiipPKyqPH2JsuhmGCdXFiKwCSTrgGE6GyCBtaK9KOf48V/tYXHnDGrS9q5a1gRF5KVI2B26UYSO7V7pXjzYCd/Sp9yGj7Rw3Kqf9Lpix/sPuqWjV6e2XFlD3YxaHSeHVnLI/Bkz2E6Ri8yfPYus52r/mECXPL2YXqY9dGyrlKKIaD9AuzMyvvy1A74a9VBq7zxQQ4adEzBbBgkqhkiG9w0BBwEwFAYIKoZIhvcNAwcECDyEf4mRrmWJgDhW4j2nRNTJU4yXxocQm/PhAr39Um7n0pgI2Cn28AabYtsHWjKqr8Al9LX6bKm8cnmnLjqTntphCw==", "publicSettings": {} } }] } """ if __name__ == '__main__': unittest.main() ================================================ FILE: Utils/test/test_scriptutil.py ================================================ #!/usr/bin/env python # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import os.path import env import ScriptUtil as su import unittest from MockUtil import MockUtil class TestScriptUtil(unittest.TestCase): def test_parse_args(self): print(__file__) cmd = u'sh foo.bar.sh -af bar --foo=bar | more \u6211' args = su.parse_args(cmd.encode('utf-8')) self.assertNotEquals(None, args) self.assertNotEquals(0, len(args)) print(args) def test_run_command(self): hutil = MockUtil(self) test_script = "mock.sh" os.chdir(os.path.join(env.root, "test")) exit_code = su.run_command(hutil, ["sh", test_script, "0"], os.getcwd(), 'RunScript-0', 'TestExtension', '1.0', True, 0.1) self.assertEquals(0, exit_code) self.assertEquals("do_exit", hutil.last) exit_code = su.run_command(hutil, ["sh", test_script, "75"], os.getcwd(), 'RunScript-1', 'TestExtension', '1.0', False, 0.1) self.assertEquals(75, exit_code) self.assertEquals("do_status_report", hutil.last) def test_log_or_exit(self): hutil = MockUtil(self) su.log_or_exit(hutil, True, 0, 'LogOrExit-0', 'Message1') self.assertEquals("do_exit", hutil.last) su.log_or_exit(hutil, False, 0, 'LogOrExit-1', 'Message2') self.assertEquals("do_status_report", hutil.last) if __name__ == '__main__': unittest.main() ================================================ FILE: VMAccess/CHANGELOG.md ================================================ ## 1.5.10 (2020-09-09) - VMAccess Linux is now more robust to the absence of ovf-env.xml file ## 1.5.6 - 1.5.9 - several bug-fixes ## 1.5.5 (2020-07-20) - Created new python modules under Utils that are meant to be python 3 compatible and are supposed to be used instead of importing waagent python file through waagentloader.py - Fixed code injection vulnerability through the username ## 1.5.1 (2018-10-31) - Support for Python3. Changing VMAccess to work for both Python 2 and Python 3 interpreter. ## 1.4.6.0 (2016-09-16) - Forcibly reset ChallengeAuthenticationResponse to no. This value was inadvertently set in previous releases, and is forcibly reset. 
## 1.4.5.0 (2016-09-07) - Check for None before checking the length of a user's password. This is fallout from allowing and rejecting empty passwords. ## 1.4.4.0 (2016-09-06) - Do not set ChallengeResponseAuthenticaiton. This value should not be changed by VMAccess. ## 1.4.3.0 (2016-09-05) - Reject zero length passwords. ## 1.4.2.0 (2016-08-25) - Ensure expiration (if specified) is used when creating an account - Backup sshd_config before any edits are made. - Ensure sshd_config is restarted when edits are made. ## 1.4.1.0 (2016-07-27) - Install operation posts incorrect status [#206] - Misspelling of resources/debian_default ================================================ FILE: VMAccess/HandlerManifest.json ================================================ [ { "version": 1.0, "handlerManifest": { "disableCommand": "extension_noop.sh", "enableCommand": "extension_shim.sh -c ./vmaccess.py -e", "installCommand": "extension_noop.sh", "uninstallCommand": "extension_noop.sh", "updateCommand": "extension_noop.sh", "rebootAfterInstall": false, "reportHeartbeat": false, "continueOnUpdateFailure": true } } ] ================================================ FILE: VMAccess/README.md ================================================ # VMAccess Extension Provide several ways to allow owner of the VM to get the SSH access back and perform additional VM disk check tasks. Current version is [1.5](https://github.com/Azure/azure-linux-extensions/releases/tag/VMAccess-1.5.18). You can read the User Guide below. 
* [Manage administrative users, SSH, and check or repair disks on Linux VMs by using the VMAccess extension](https://learn.microsoft.com/en-us/azure/virtual-machines/extensions/vmaccess) VMAccess Extension can: * Reset the password of the original sudo user * Create a new sudo user with the password specified * Set the public host key with the key given * Reset the public host key provided during VM provisioning if host key not provided * Open the SSH port(22) and reset the sshd_config if reset_ssh is set to true * Remove the existing user * Check disks * Repair added disk * Remove prior public keys when a new public key is provided * Restore the original backup sshd_config if restore_backup_ssh is set to true # Security Notes: * VMAccess Extension is designed for regaining access to a VM in the event that access is lost. * Based on this principle, it will grant sudo permission to the account specified in the username field. * Do not specify a user in the username field if you do not wish that user to gain sudo permissions. * Instead, login to the VM and use built-in tools (e.g. usermod, chage, etc) to manage unprivileged users. # User Guide ## 1. Configuration schema ### 1.1. Public configuration Schema for the public configuration file looks like: * `check_disk`: (optional, boolean) whether or not to check disk * `repair_disk`: (optional, boolean) whether or not to repair disk * `disk_name`: (boolean) name of disk to repair (required when repair_disk is true) ```json { "check_disk": "true", "repair_disk": "true", "disk_name": "" } ``` ### 1.2. 
Protected configuration Schema for the protected configuration file looks like this: * `username`: (required, string) the name of the user * `password`: (optional, string) the password of the user * `ssh_key`: (optional, string) the public key of the user * `reset_ssh`: (optional, boolean) whether or not reset the ssh * `remove_user`: (optional, string) the user name to remove * `expiration`: (optional, string) expiration of the account, defaults to never, e.g. 2016-01-01. * `remove_prior_keys`: (optional, boolean) whether or not to remove old SSH keys when adding a new one * `restore_backup_ssh`: (optional, boolean) whether or not to restore original backed-up sshd config ```json { "username": "", "password": "", "ssh_key": "", "reset_ssh": true, "remove_user": "", "expiration": "", "remove_prior_keys": true, "restore_backup_ssh": true } ``` `ssh_key` supports `ssh-rsa`, `ssh-ed25519` and `.pem` formats. * If your public key is in `ssh-rsa` format, for example, `ssh-rsa XXXXXXXX`, you can use: ``` "ssh_key": "ssh-rsa XXXXXXXX" ``` * If your public key is in `ssh-ed25519` format, for example, `ssh-ed25519 XXXXXXXX`, you can use: ``` "ssh_key": "ssh-ed25519 XXXXXXXX" ``` * If your public key is in `.pem` format, use the following UNIX command to convert the .pem file to a value that can be passed in a JSON string: ``` awk 'NF {sub(/\r/, ""); printf "%s\\n",$0;}' myCert.pem ``` You can use: ``` "ssh_key": "-----BEGIN CERTIFICATE-----\nXXXXXXXXXXXXXXXXXXXXXXXX\n-----END CERTIFICATE-----" ``` ## 2. Deploying the Extension to a VM You can deploy it using Azure CLI, Azure Powershell and ARM template. ### 2.1. 
Using [**Azure CLI**][azure-cli] Create a `settings.json` (optional) and a `protected_settings.json` and run: ``` $ azure vm extension set \ --resource-group \ --vm-name \ --name VMAccessForLinux \ --publisher Microsoft.OSTCExtensions \ --version 1.5 \ --settings settings.json --protected-settings protected_settings.json ``` To retrieve the deployment state of extensions for a given VM, run: ``` $ azure vm extension list \ --resource-group \ --vm-name -o table ``` ### 2.2. Using [**Azure Powershell**][azure-powershell] You can deploying VMAccess Extension by running: ```powershell $username = "" $sshKey = "" $settings = @{"check_disk" = $true}; $protectedSettings = @{"username" = $username; "ssh_key" = $sshKey}; Set-AzVMExtension -ResourceGroupName "" -VMName "" -Location "" ` -Publisher "Microsoft.OSTCExtensions" -ExtensionType "VMAccessForLinux" -Name "VMAccessForLinux" ` -TypeHandlerVersion "1.5" -Settings $settings -ProtectedSettings $protectedSettings ``` You can provide and modify extension settings by using strings: ```powershell $username = "" $sshKey = "" $settingsString = '{"check_disk":true}'; $protectedSettingsString = '{"username":"' + $username + '","ssh_key":"' + $sshKey + '"}'; Set-AzVMExtension -ResourceGroupName "" -VMName "" -Location "" ` -Publisher "Microsoft.OSTCExtensions" -ExtensionType "VMAccessForLinux" -Name "VMAccessForLinux" ` -TypeHandlerVersion "1.5" -SettingString $settingsString -ProtectedSettingString $protectedSettingsString ``` ### 2.3. 
Using [**ARM Template**][arm-template] ```json { "type": "Microsoft.Compute/virtualMachines/extensions", "name": "", "apiVersion": "", "location": "", "dependsOn": [ "[concat('Microsoft.Compute/virtualMachines/', )]" ], "properties": { "publisher": "Microsoft.OSTCExtensions", "type": "VMAccessForLinux", "typeHandlerVersion": "1.5", "autoUpgradeMinorVersion": true, "settings": {}, "protectedSettings": { "username": "", "password": "", "reset_ssh": true, "ssh_key": "", "remove_user": "" } } } ``` Refer to the following sample [ARM template](https://github.com/azure/azure-quickstart-templates/tree/master/demos/vmaccess-on-ubuntu). For more details about ARM template, please visit [Authoring Azure Resource Manager templates](https://azure.microsoft.com/en-us/documentation/articles/resource-group-authoring-templates/). ## 3. Scenarios ### 3.1 Resetting the password in the Public Settings ```json { "check_disk": "false" } ``` > VMAccessForLinux resets and restarts the SSH server if a password is specified. This is necessary if the VM was deployed with public key authentication because the SSH server is not configured to accept passwords. For this reason, the SSH server's configuration is reset to allow password authentication, and restarted to accept this new configuration. This behavior can be disabled by setting the reset_ssh value to false. in the Protected Settings ```json { "username": "currentusername", "password": "newpassword", "reset_ssh": "false" } ``` ### 3.2 Resetting the SSH key ```json { "username": "currentusername", "ssh_key": "contentofsshkey" } ``` ### 3.3 Resetting the password and the SSH key ```json { "username": "currentusername", "ssh_key": "contentofsshkey", "password": "newpassword", } ``` ### 3.4 Creating a new sudo user account with the password ```json { "username": "newusername", "password": "newpassword" } ``` #### 3.4.1 Creating a new sudo user account with a password and expiration date. 
```json { "username": "newusername", "password": "newpassword", "expiration": "2016-12-31" } ``` ### 3.5 Creating a new sudo user account with the SSH key ```json { "username": "newusername", "ssh_key": "contentofsshkey" } ``` #### 3.5.1 Creating a new sudo user account with the SSH key ```json { "username": "newusername", "ssh_key": "contentofsshkey", "expiration": "2016-12-31" } ``` ### 3.6 Resetting the SSH configuration ```json { "reset_ssh": true } ``` ### 3.7 Removing an existing user ```json { "remove_user": "usertoberemoveed", } ``` ### 3.8 Checking added disks on VM ```json { "check_disk": "true" } ``` ### 3.9 Fix added disks on a VM ```json { "repair_disk": "true", "disk_name": "userdisktofix" } ``` ### 3.10 Removing prior SSH keys (only when provided a new one) ```json { "username": "newusername", "ssh_key": "contentofsshkey", "remove_prior_keys": true } ``` ### 3.11 Restoring original SSH configuration ```json { "restore_backup_ssh": true } ``` ## Supported Linux Distributions - Ubuntu 12.04 and higher - CentOS 6.5 and higher - Oracle Linux 6.4.0.0.0 and higher - openSUSE 13.1 and higher - SUSE Linux Enterprise Server 11 SP3 and higher ## Debug * The status of the extension is reported back to Azure so that user can see the status on Azure Portal * The operation log of the extension is `/var/log/azure///extension.log` file. 
[azure-powershell]: https://azure.microsoft.com/en-us/documentation/articles/powershell-install-configure/ [azure-cli]: https://azure.microsoft.com/en-us/documentation/articles/xplat-cli/ [arm-template]: http://azure.microsoft.com/en-us/documentation/templates/ [arm-overview]: https://azure.microsoft.com/en-us/documentation/articles/resource-group-overview/ ================================================ FILE: VMAccess/extension_noop.sh ================================================ #!/usr/bin/env bash # There is no need to write a status file for commands other than Enable exit 0 ================================================ FILE: VMAccess/extension_shim.sh ================================================ #!/usr/bin/env bash # Keeping the default command COMMAND="" PYTHON="" USAGE="$(basename "$0") [-h] [-i|--install] [-u|--uninstall] [-d|--disable] [-e|--enable] [-p|--update] Program to find the installed python on the box and invoke a Python extension script. where: -h|--help show this help text -i|--install install the extension -u|--uninstall uninstall the extension -d|--disable disable the extension -e|--enable enable the extension -p|--update update the extension -c|--command command to run example: # Install usage $ bash extension_shim.sh -i python ./vmaccess.py -install # Custom executable python file $ bash extension_shim.sh -c ""hello.py"" -i python hello.py -install # Custom executable python file with arguments $ bash extension_shim.sh -c ""hello.py --install"" python hello.py --install " function find_python(){ local python_exec_command=$1 # Check if there is python defined. if command -v python >/dev/null 2>&1 ; then eval ${python_exec_command}="python" else # Python was not found. Searching for Python3 now. 
if command -v python3 >/dev/null 2>&1 ; then eval ${python_exec_command}="python3" fi fi } # Transform long options to short ones for getopts support (getopts doesn't support long args) for arg in "$@"; do shift case "$arg" in "--help") set -- "$@" "-h" ;; "--install") set -- "$@" "-i" ;; "--update") set -- "$@" "-p" ;; "--enable") set -- "$@" "-e" ;; "--disable") set -- "$@" "-d" ;; "--uninstall") set -- "$@" "-u" ;; *) set -- "$@" "$arg" esac done if [ -z "$arg" ] then echo "$USAGE" >&2 exit 1 fi # Get the arguments while getopts "iudephc:?" o; do case "${o}" in h|\?) echo "$USAGE" exit 0 ;; i) operation="-install" ;; u) operation="-uninstall" ;; d) operation="-disable" ;; e) operation="-enable" ;; p) operation="-update" ;; c) COMMAND="$OPTARG" ;; *) echo "$USAGE" >&2 exit 1 ;; esac done shift $((OPTIND-1)) # If find_python is not able to find a python installed, $PYTHON will be null. find_python PYTHON if [ -z "$PYTHON" ]; then echo "No Python interpreter found on the box" >&2 exit 51 # Not Supported else echo `${PYTHON} --version` fi ${PYTHON} ${COMMAND} ${operation} # DONE ================================================ FILE: VMAccess/manifest.xml ================================================ Microsoft.OSTCExtensions VMAccessForLinux 1.5.23 VmRole Microsoft Azure VM Access Extension for Linux Virtual Machines true https://github.com/Azure/azure-linux-extensions/blob/master/LICENSE-2_0.txt http://www.microsoft.com/privacystatement/en-us/OnlineServices/Default.aspx https://github.com/Azure/azure-linux-extensions true Linux Microsoft ================================================ FILE: VMAccess/references ================================================ Utils/ Common/WALinuxAgent-2.0.16/waagent Common/waagentloader.py ================================================ FILE: VMAccess/resources/SuSE_default ================================================ # $OpenBSD: sshd_config,v 1.89 2013/02/06 00:20:42 dtucker Exp $ # This is the sshd server system-wide 
configuration file. See # sshd_config(5) for more information. # This sshd was compiled with PATH=/usr/bin:/bin:/usr/sbin:/sbin # The strategy used for options in the default sshd_config shipped with # OpenSSH is to specify options with their default value where # possible, but leave them commented. Uncommented options override the # default value. #Port 22 #AddressFamily any #ListenAddress 0.0.0.0 #ListenAddress :: # The default requires explicit activation of protocol 1 #Protocol 2 # HostKey for protocol version 1 #HostKey /etc/ssh/ssh_host_key # HostKeys for protocol version 2 #HostKey /etc/ssh/ssh_host_rsa_key #HostKey /etc/ssh/ssh_host_dsa_key #HostKey /etc/ssh/ssh_host_ecdsa_key # Lifetime and size of ephemeral version 1 server key #KeyRegenerationInterval 1h #ServerKeyBits 1024 # Logging # obsoletes QuietMode and FascistLogging #SyslogFacility AUTH #LogLevel INFO # Authentication: #LoginGraceTime 2m #PermitRootLogin yes #StrictModes yes #MaxAuthTries 6 #MaxSessions 10 #RSAAuthentication yes #PubkeyAuthentication yes # The default is to check both .ssh/authorized_keys and .ssh/authorized_keys2 # but this is overridden so installations will only check .ssh/authorized_keys AuthorizedKeysFile .ssh/authorized_keys #AuthorizedPrincipalsFile none #AuthorizedKeysCommand none #AuthorizedKeysCommandUser nobody # For this to work you will also need host keys in /etc/ssh/ssh_known_hosts #RhostsRSAAuthentication no # similar for protocol version 2 #HostbasedAuthentication no # Change to yes if you don't trust ~/.ssh/known_hosts for # RhostsRSAAuthentication and HostbasedAuthentication #IgnoreUserKnownHosts no # Don't read the user's ~/.rhosts and ~/.shosts files #IgnoreRhosts yes # To disable tunneled clear text passwords, change to no here! 
PasswordAuthentication no #PermitEmptyPasswords no # Change to no to disable s/key passwords #ChallengeResponseAuthentication yes # Kerberos options #KerberosAuthentication no #KerberosOrLocalPasswd yes #KerberosTicketCleanup yes #KerberosGetAFSToken no # GSSAPI options #GSSAPIAuthentication no #GSSAPICleanupCredentials yes # Set this to 'yes' to enable support for the deprecated 'gssapi' authentication # mechanism to OpenSSH 3.8p1. The newer 'gssapi-with-mic' mechanism is included # in this release. The use of 'gssapi' is deprecated due to the presence of # potential man-in-the-middle attacks, which 'gssapi-with-mic' is not susceptible to. #GSSAPIEnableMITMAttack no # Set this to 'yes' to enable PAM authentication, account processing, # and session processing. If this is enabled, PAM authentication will # be allowed through the ChallengeResponseAuthentication and # PasswordAuthentication. Depending on your PAM configuration, # PAM authentication via ChallengeResponseAuthentication may bypass # the setting of "PermitRootLogin without-password". # If you just want the PAM account and session checks to run without # PAM authentication, then enable this but set PasswordAuthentication # and ChallengeResponseAuthentication to 'no'. UsePAM yes #AllowAgentForwarding yes #AllowTcpForwarding yes #GatewayPorts no X11Forwarding yes #X11DisplayOffset 10 #X11UseLocalhost yes #PrintMotd yes #PrintLastLog yes #TCPKeepAlive yes #UseLogin no UsePrivilegeSeparation sandbox # Default for new installations. #PermitUserEnvironment no #Compression delayed #ClientAliveInterval 0 #ClientAliveCountMax 3 #UseDNS yes #PidFile /run/sshd.pid #MaxStartups 10:30:100 #PermitTunnel no #ChrootDirectory none #VersionAddendum none # no default banner path #Banner none # override default of no subsystems Subsystem sftp /usr/lib/ssh/sftp-server # This enables accepting locale enviroment variables LC_* LANG, see sshd_config(5). 
AcceptEnv LANG LC_CTYPE LC_NUMERIC LC_TIME LC_COLLATE LC_MONETARY LC_MESSAGES AcceptEnv LC_PAPER LC_NAME LC_ADDRESS LC_TELEPHONE LC_MEASUREMENT AcceptEnv LC_IDENTIFICATION LC_ALL # Example of overriding settings on a per-user basis #Match User anoncvs # X11Forwarding no # AllowTcpForwarding no # ForceCommand cvs server ClientAliveInterval 180 ================================================ FILE: VMAccess/resources/Ubuntu_default ================================================ # Package generated configuration file # See the sshd_config(5) manpage for details # What ports, IPs and protocols we listen for Port 22 # Use these options to restrict which interfaces/protocols sshd will bind to #ListenAddress :: #ListenAddress 0.0.0.0 Protocol 2 # HostKeys for protocol version 2 HostKey /etc/ssh/ssh_host_rsa_key HostKey /etc/ssh/ssh_host_dsa_key HostKey /etc/ssh/ssh_host_ecdsa_key HostKey /etc/ssh/ssh_host_ed25519_key #Privilege Separation is turned on for security UsePrivilegeSeparation yes # Lifetime and size of ephemeral version 1 server key KeyRegenerationInterval 3600 ServerKeyBits 1024 # Logging SyslogFacility AUTH LogLevel INFO # Authentication: LoginGraceTime 120 PermitRootLogin without-password StrictModes yes RSAAuthentication yes PubkeyAuthentication yes #AuthorizedKeysFile %h/.ssh/authorized_keys # Don't read the user's ~/.rhosts and ~/.shosts files IgnoreRhosts yes # For this to work you will also need host keys in /etc/ssh_known_hosts RhostsRSAAuthentication no # similar for protocol version 2 HostbasedAuthentication no # Uncomment if you don't trust ~/.ssh/known_hosts for RhostsRSAAuthentication #IgnoreUserKnownHosts yes # To enable empty passwords, change to yes (NOT RECOMMENDED) PermitEmptyPasswords no # Change to yes to enable challenge-response passwords (beware issues with # some PAM modules and threads) ChallengeResponseAuthentication no # Change to no to disable tunnelled clear text passwords PasswordAuthentication yes # Kerberos options 
#KerberosAuthentication no #KerberosGetAFSToken no #KerberosOrLocalPasswd yes #KerberosTicketCleanup yes # GSSAPI options #GSSAPIAuthentication no #GSSAPICleanupCredentials yes X11Forwarding yes X11DisplayOffset 10 PrintMotd no PrintLastLog yes TCPKeepAlive yes #UseLogin no #MaxStartups 10:30:60 #Banner /etc/issue.net # Allow client to pass locale environment variables AcceptEnv LANG LC_* Subsystem sftp /usr/lib/openssh/sftp-server # Set this to 'yes' to enable PAM authentication, account processing, # and session processing. If this is enabled, PAM authentication will # be allowed through the ChallengeResponseAuthentication and # PasswordAuthentication. Depending on your PAM configuration, # PAM authentication via ChallengeResponseAuthentication may bypass # the setting of "PermitRootLogin without-password". # If you just want the PAM account and session checks to run without # PAM authentication, then enable this but set PasswordAuthentication # and ChallengeResponseAuthentication to 'no'. UsePAM yes # CLOUD_IMG: This file was created/modified by the Cloud Image build process ClientAliveInterval 120 ================================================ FILE: VMAccess/resources/centos_default ================================================ # $OpenBSD: sshd_config,v 1.80 2008/07/02 02:24:18 djm Exp $ # This is the sshd server system-wide configuration file. See # sshd_config(5) for more information. # This sshd was compiled with PATH=/usr/local/bin:/bin:/usr/bin # The strategy used for options in the default sshd_config shipped with # OpenSSH is to specify options with their default value where # possible, but leave them commented. Uncommented options change a # default value. #Port 22 #AddressFamily any #ListenAddress 0.0.0.0 #ListenAddress :: # Disable legacy (protocol version 1) support in the server for new # installations. 
In future the default will change to require explicit # activation of protocol 1 Protocol 2 # HostKey for protocol version 1 #HostKey /etc/ssh/ssh_host_key # HostKeys for protocol version 2 #HostKey /etc/ssh/ssh_host_rsa_key #HostKey /etc/ssh/ssh_host_dsa_key # Lifetime and size of ephemeral version 1 server key #KeyRegenerationInterval 1h #ServerKeyBits 1024 # Logging # obsoletes QuietMode and FascistLogging #SyslogFacility AUTH SyslogFacility AUTHPRIV #LogLevel INFO # Authentication: #LoginGraceTime 2m #PermitRootLogin yes #StrictModes yes #MaxAuthTries 6 #MaxSessions 10 #RSAAuthentication yes #PubkeyAuthentication yes #AuthorizedKeysFile .ssh/authorized_keys #AuthorizedKeysCommand none #AuthorizedKeysCommandRunAs nobody # For this to work you will also need host keys in /etc/ssh/ssh_known_hosts #RhostsRSAAuthentication no # similar for protocol version 2 #HostbasedAuthentication no # Change to yes if you don't trust ~/.ssh/known_hosts for # RhostsRSAAuthentication and HostbasedAuthentication #IgnoreUserKnownHosts no # Don't read the user's ~/.rhosts and ~/.shosts files #IgnoreRhosts yes # To disable tunneled clear text passwords, change to no here! #PasswordAuthentication yes #PermitEmptyPasswords no PasswordAuthentication yes # Change to no to disable s/key passwords #ChallengeResponseAuthentication yes ChallengeResponseAuthentication no # Kerberos options #KerberosAuthentication no #KerberosOrLocalPasswd yes #KerberosTicketCleanup yes #KerberosGetAFSToken no #KerberosUseKuserok yes # GSSAPI options #GSSAPIAuthentication no GSSAPIAuthentication yes #GSSAPICleanupCredentials yes GSSAPICleanupCredentials yes #GSSAPIStrictAcceptorCheck yes #GSSAPIKeyExchange no # Set this to 'yes' to enable PAM authentication, account processing, # and session processing. If this is enabled, PAM authentication will # be allowed through the ChallengeResponseAuthentication and # PasswordAuthentication. 
Depending on your PAM configuration, # PAM authentication via ChallengeResponseAuthentication may bypass # the setting of "PermitRootLogin without-password". # If you just want the PAM account and session checks to run without # PAM authentication, then enable this but set PasswordAuthentication # and ChallengeResponseAuthentication to 'no'. #UsePAM no UsePAM yes # Accept locale-related environment variables AcceptEnv LANG LC_CTYPE LC_NUMERIC LC_TIME LC_COLLATE LC_MONETARY LC_MESSAGES AcceptEnv LC_PAPER LC_NAME LC_ADDRESS LC_TELEPHONE LC_MEASUREMENT AcceptEnv LC_IDENTIFICATION LC_ALL LANGUAGE AcceptEnv XMODIFIERS #AllowAgentForwarding yes #AllowTcpForwarding yes #GatewayPorts no #X11Forwarding no X11Forwarding yes #X11DisplayOffset 10 #X11UseLocalhost yes #PrintMotd yes #PrintLastLog yes #TCPKeepAlive yes #UseLogin no #UsePrivilegeSeparation yes #PermitUserEnvironment no #Compression delayed ClientAliveInterval 180 #ClientAliveCountMax 3 #ShowPatchLevel no #UseDNS yes #PidFile /var/run/sshd.pid #MaxStartups 10:30:100 #PermitTunnel no #ChrootDirectory none # no default banner path #Banner none # override default of no subsystems Subsystem sftp /usr/libexec/openssh/sftp-server # Example of overriding settings on a per-user basis #Match User anoncvs # X11Forwarding no # AllowTcpForwarding no # ForceCommand cvs server ================================================ FILE: VMAccess/resources/debian_default ================================================ # Package generated configuration file # See the sshd_config(5) manpage for details # What ports, IPs and protocols we listen for Port 22 # Use these options to restrict which interfaces/protocols sshd will bind to #ListenAddress :: #ListenAddress 0.0.0.0 Protocol 2 # HostKeys for protocol version 2 HostKey /etc/ssh/ssh_host_rsa_key HostKey /etc/ssh/ssh_host_dsa_key HostKey /etc/ssh/ssh_host_ecdsa_key HostKey /etc/ssh/ssh_host_ed25519_key #Privilege Separation is turned on for security UsePrivilegeSeparation yes # 
Lifetime and size of ephemeral version 1 server key KeyRegenerationInterval 3600 ServerKeyBits 1024 # Logging SyslogFacility AUTH LogLevel INFO # Authentication: LoginGraceTime 120 PermitRootLogin without-password StrictModes yes RSAAuthentication yes PubkeyAuthentication yes #AuthorizedKeysFile %h/.ssh/authorized_keys # Don't read the user's ~/.rhosts and ~/.shosts files IgnoreRhosts yes # For this to work you will also need host keys in /etc/ssh_known_hosts RhostsRSAAuthentication no # similar for protocol version 2 HostbasedAuthentication no # Uncomment if you don't trust ~/.ssh/known_hosts for RhostsRSAAuthentication #IgnoreUserKnownHosts yes # To enable empty passwords, change to yes (NOT RECOMMENDED) PermitEmptyPasswords no # Change to yes to enable challenge-response passwords (beware issues with # some PAM modules and threads) ChallengeResponseAuthentication no # Change to no to disable tunnelled clear text passwords PasswordAuthentication yes # Kerberos options #KerberosAuthentication no #KerberosGetAFSToken no #KerberosOrLocalPasswd yes #KerberosTicketCleanup yes # GSSAPI options #GSSAPIAuthentication no #GSSAPICleanupCredentials yes X11Forwarding yes X11DisplayOffset 10 PrintMotd no PrintLastLog yes TCPKeepAlive yes #UseLogin no #MaxStartups 10:30:60 #Banner /etc/issue.net # Allow client to pass locale environment variables AcceptEnv LANG LC_* Subsystem sftp /usr/lib/openssh/sftp-server # Set this to 'yes' to enable PAM authentication, account processing, # and session processing. If this is enabled, PAM authentication will # be allowed through the ChallengeResponseAuthentication and # PasswordAuthentication. Depending on your PAM configuration, # PAM authentication via ChallengeResponseAuthentication may bypass # the setting of "PermitRootLogin without-password". # If you just want the PAM account and session checks to run without # PAM authentication, then enable this but set PasswordAuthentication # and ChallengeResponseAuthentication to 'no'. 
UsePAM yes # CLOUD_IMG: This file was created/modified by the Cloud Image build process ClientAliveInterval 120 ================================================ FILE: VMAccess/resources/default ================================================ #Default sshd_config # Package generated configuration file # See the sshd_config(5) manpage for details # What ports, IPs and protocols we listen for Port 22 # Use these options to restrict which interfaces/protocols sshd will bind to #ListenAddress :: #ListenAddress 0.0.0.0 Protocol 2 # HostKeys for protocol version 2 HostKey /etc/ssh/ssh_host_rsa_key HostKey /etc/ssh/ssh_host_dsa_key HostKey /etc/ssh/ssh_host_ecdsa_key HostKey /etc/ssh/ssh_host_ed25519_key #Privilege Separation is turned on for security UsePrivilegeSeparation yes # Lifetime and size of ephemeral version 1 server key KeyRegenerationInterval 3600 ServerKeyBits 1024 # Logging SyslogFacility AUTH LogLevel INFO # Authentication: LoginGraceTime 120 PermitRootLogin without-password StrictModes yes RSAAuthentication yes PubkeyAuthentication yes #AuthorizedKeysFile %h/.ssh/authorized_keys # Don't read the user's ~/.rhosts and ~/.shosts files IgnoreRhosts yes # For this to work you will also need host keys in /etc/ssh_known_hosts RhostsRSAAuthentication no # similar for protocol version 2 HostbasedAuthentication no # Uncomment if you don't trust ~/.ssh/known_hosts for RhostsRSAAuthentication #IgnoreUserKnownHosts yes # To enable empty passwords, change to yes (NOT RECOMMENDED) PermitEmptyPasswords no # Change to yes to enable challenge-response passwords (beware issues with # some PAM modules and threads) ChallengeResponseAuthentication no # Change to no to disable tunnelled clear text passwords PasswordAuthentication yes # Kerberos options #KerberosAuthentication no #KerberosGetAFSToken no #KerberosOrLocalPasswd yes #KerberosTicketCleanup yes # GSSAPI options #GSSAPIAuthentication no #GSSAPICleanupCredentials yes X11Forwarding yes X11DisplayOffset 10 PrintMotd no 
PrintLastLog yes TCPKeepAlive yes #UseLogin no #MaxStartups 10:30:60 #Banner /etc/issue.net # Allow client to pass locale environment variables AcceptEnv LANG LC_* Subsystem sftp /usr/lib/openssh/sftp-server # Set this to 'yes' to enable PAM authentication, account processing, # and session processing. If this is enabled, PAM authentication will # be allowed through the ChallengeResponseAuthentication and # PasswordAuthentication. Depending on your PAM configuration, # PAM authentication via ChallengeResponseAuthentication may bypass # the setting of "PermitRootLogin without-password". # If you just want the PAM account and session checks to run without # PAM authentication, then enable this but set PasswordAuthentication # and ChallengeResponseAuthentication to 'no'. UsePAM yes # CLOUD_IMG: This file was created/modified by the Cloud Image build process ClientAliveInterval 120 ================================================ FILE: VMAccess/resources/fedora_default ================================================ # $OpenBSD: sshd_config,v 1.80 2008/07/02 02:24:18 djm Exp $ # This is the sshd server system-wide configuration file. See # sshd_config(5) for more information. # This sshd was compiled with PATH=/usr/local/bin:/bin:/usr/bin # The strategy used for options in the default sshd_config shipped with # OpenSSH is to specify options with their default value where # possible, but leave them commented. Uncommented options change a # default value. #Port 22 #AddressFamily any #ListenAddress 0.0.0.0 #ListenAddress :: # Disable legacy (protocol version 1) support in the server for new # installations. 
In future the default will change to require explicit # activation of protocol 1 Protocol 2 # HostKey for protocol version 1 #HostKey /etc/ssh/ssh_host_key # HostKeys for protocol version 2 #HostKey /etc/ssh/ssh_host_rsa_key #HostKey /etc/ssh/ssh_host_dsa_key # Lifetime and size of ephemeral version 1 server key #KeyRegenerationInterval 1h #ServerKeyBits 1024 # Logging # obsoletes QuietMode and FascistLogging #SyslogFacility AUTH SyslogFacility AUTHPRIV #LogLevel INFO # Authentication: #LoginGraceTime 2m #PermitRootLogin yes #StrictModes yes #MaxAuthTries 6 #MaxSessions 10 #RSAAuthentication yes #PubkeyAuthentication yes #AuthorizedKeysFile .ssh/authorized_keys #AuthorizedKeysCommand none #AuthorizedKeysCommandRunAs nobody # For this to work you will also need host keys in /etc/ssh/ssh_known_hosts #RhostsRSAAuthentication no # similar for protocol version 2 #HostbasedAuthentication no # Change to yes if you don't trust ~/.ssh/known_hosts for # RhostsRSAAuthentication and HostbasedAuthentication #IgnoreUserKnownHosts no # Don't read the user's ~/.rhosts and ~/.shosts files #IgnoreRhosts yes # To disable tunneled clear text passwords, change to no here! #PasswordAuthentication yes #PermitEmptyPasswords no PasswordAuthentication yes # Change to no to disable s/key passwords #ChallengeResponseAuthentication yes ChallengeResponseAuthentication no # Kerberos options #KerberosAuthentication no #KerberosOrLocalPasswd yes #KerberosTicketCleanup yes #KerberosGetAFSToken no #KerberosUseKuserok yes # GSSAPI options #GSSAPIAuthentication no GSSAPIAuthentication yes #GSSAPICleanupCredentials yes GSSAPICleanupCredentials yes #GSSAPIStrictAcceptorCheck yes #GSSAPIKeyExchange no # Set this to 'yes' to enable PAM authentication, account processing, # and session processing. If this is enabled, PAM authentication will # be allowed through the ChallengeResponseAuthentication and # PasswordAuthentication. 
Depending on your PAM configuration, # PAM authentication via ChallengeResponseAuthentication may bypass # the setting of "PermitRootLogin without-password". # If you just want the PAM account and session checks to run without # PAM authentication, then enable this but set PasswordAuthentication # and ChallengeResponseAuthentication to 'no'. #UsePAM no UsePAM yes # Accept locale-related environment variables AcceptEnv LANG LC_CTYPE LC_NUMERIC LC_TIME LC_COLLATE LC_MONETARY LC_MESSAGES AcceptEnv LC_PAPER LC_NAME LC_ADDRESS LC_TELEPHONE LC_MEASUREMENT AcceptEnv LC_IDENTIFICATION LC_ALL LANGUAGE AcceptEnv XMODIFIERS #AllowAgentForwarding yes #AllowTcpForwarding yes #GatewayPorts no #X11Forwarding no X11Forwarding yes #X11DisplayOffset 10 #X11UseLocalhost yes #PrintMotd yes #PrintLastLog yes #TCPKeepAlive yes #UseLogin no #UsePrivilegeSeparation yes #PermitUserEnvironment no #Compression delayed ClientAliveInterval 180 #ClientAliveCountMax 3 #ShowPatchLevel no #UseDNS yes #PidFile /var/run/sshd.pid #MaxStartups 10:30:100 #PermitTunnel no #ChrootDirectory none # no default banner path #Banner none # override default of no subsystems Subsystem sftp /usr/libexec/openssh/sftp-server # Example of overriding settings on a per-user basis #Match User anoncvs # X11Forwarding no # AllowTcpForwarding no # ForceCommand cvs server ================================================ FILE: VMAccess/resources/redhat_default ================================================ # $OpenBSD: sshd_config,v 1.80 2008/07/02 02:24:18 djm Exp $ # This is the sshd server system-wide configuration file. See # sshd_config(5) for more information. # This sshd was compiled with PATH=/usr/local/bin:/bin:/usr/bin # The strategy used for options in the default sshd_config shipped with # OpenSSH is to specify options with their default value where # possible, but leave them commented. Uncommented options change a # default value. 
#Port 22 #AddressFamily any #ListenAddress 0.0.0.0 #ListenAddress :: # Disable legacy (protocol version 1) support in the server for new # installations. In future the default will change to require explicit # activation of protocol 1 Protocol 2 # HostKey for protocol version 1 #HostKey /etc/ssh/ssh_host_key # HostKeys for protocol version 2 #HostKey /etc/ssh/ssh_host_rsa_key #HostKey /etc/ssh/ssh_host_dsa_key # Lifetime and size of ephemeral version 1 server key #KeyRegenerationInterval 1h #ServerKeyBits 1024 # Logging # obsoletes QuietMode and FascistLogging #SyslogFacility AUTH SyslogFacility AUTHPRIV #LogLevel INFO # Authentication: #LoginGraceTime 2m #PermitRootLogin yes #StrictModes yes #MaxAuthTries 6 #MaxSessions 10 #RSAAuthentication yes #PubkeyAuthentication yes #AuthorizedKeysFile .ssh/authorized_keys #AuthorizedKeysCommand none #AuthorizedKeysCommandRunAs nobody # For this to work you will also need host keys in /etc/ssh/ssh_known_hosts #RhostsRSAAuthentication no # similar for protocol version 2 #HostbasedAuthentication no # Change to yes if you don't trust ~/.ssh/known_hosts for # RhostsRSAAuthentication and HostbasedAuthentication #IgnoreUserKnownHosts no # Don't read the user's ~/.rhosts and ~/.shosts files #IgnoreRhosts yes # To disable tunneled clear text passwords, change to no here! #PasswordAuthentication yes #PermitEmptyPasswords no PasswordAuthentication yes # Change to no to disable s/key passwords #ChallengeResponseAuthentication yes ChallengeResponseAuthentication no # Kerberos options #KerberosAuthentication no #KerberosOrLocalPasswd yes #KerberosTicketCleanup yes #KerberosGetAFSToken no #KerberosUseKuserok yes # GSSAPI options #GSSAPIAuthentication no GSSAPIAuthentication yes #GSSAPICleanupCredentials yes GSSAPICleanupCredentials yes #GSSAPIStrictAcceptorCheck yes #GSSAPIKeyExchange no # Set this to 'yes' to enable PAM authentication, account processing, # and session processing. 
If this is enabled, PAM authentication will # be allowed through the ChallengeResponseAuthentication and # PasswordAuthentication. Depending on your PAM configuration, # PAM authentication via ChallengeResponseAuthentication may bypass # the setting of "PermitRootLogin without-password". # If you just want the PAM account and session checks to run without # PAM authentication, then enable this but set PasswordAuthentication # and ChallengeResponseAuthentication to 'no'. #UsePAM no UsePAM yes # Accept locale-related environment variables AcceptEnv LANG LC_CTYPE LC_NUMERIC LC_TIME LC_COLLATE LC_MONETARY LC_MESSAGES AcceptEnv LC_PAPER LC_NAME LC_ADDRESS LC_TELEPHONE LC_MEASUREMENT AcceptEnv LC_IDENTIFICATION LC_ALL LANGUAGE AcceptEnv XMODIFIERS #AllowAgentForwarding yes #AllowTcpForwarding yes #GatewayPorts no #X11Forwarding no X11Forwarding yes #X11DisplayOffset 10 #X11UseLocalhost yes #PrintMotd yes #PrintLastLog yes #TCPKeepAlive yes #UseLogin no #UsePrivilegeSeparation yes #PermitUserEnvironment no #Compression delayed ClientAliveInterval 180 #ClientAliveCountMax 3 #ShowPatchLevel no #UseDNS yes #PidFile /var/run/sshd.pid #MaxStartups 10:30:100 #PermitTunnel no #ChrootDirectory none # no default banner path #Banner none # override default of no subsystems Subsystem sftp /usr/libexec/openssh/sftp-server # Example of overriding settings on a per-user basis #Match User anoncvs # X11Forwarding no # AllowTcpForwarding no # ForceCommand cvs server ================================================ FILE: VMAccess/test/env.py ================================================ #!/usr/bin/env python # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import os #append installer directory to sys.path root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(root) manifestFile = os.path.join(root, 'HandlerManifest.json') if os.path.exists(manifestFile): import json jsonData = open(manifestFile) manifest = json.load(jsonData) jsonData.close() extName="{0}-{1}".format("VMAccess", manifest[0]["version"]) print("Start test: %s" % extName) extDir=os.path.join("/var/lib/waagent", extName) if(os.path.isdir(extDir)): os.chdir(extDir) print("Switching to dir: %s" % os.getcwd()) ================================================ FILE: VMAccess/test/test_iptable_rules.py ================================================ #!/usr/bin/env python # #CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import unittest
from VMAccess.test import env
from VMAccess import vmaccess
import os
from Utils.WAAgentUtil import waagent

# Route test logging to a scratch file; discard console output.
waagent.LoggerInit('/tmp/test.log','/dev/null')


class TestIPhablesRule(unittest.TestCase):
    """System-level tests for the iptables helpers in vmaccess.

    NOTE(review): these tests must run as root on a host with iptables
    available, and they temporarily mutate the live firewall (on an
    otherwise-unused port). They pass the rule as a single string, while
    the helpers in the current vmaccess.py appear to take an argv list
    (e.g. ['INPUT', '-p', 'tcp', ...]) -- confirm the tests are still in
    sync with the implementation.
    """

    def test_insert_rule_if_not_exists(self):
        # Insert a throw-away DROP rule on an unused port, then verify it
        # shows up in iptables-save output. RunGetOutput returns
        # (exit_code, output); grep exits 0 on a match.
        rule = 'INPUT -p tcp -m tcp --dport 9998 -j DROP'
        vmaccess._insert_rule_if_not_exists(rule)
        cmd_result = waagent.RunGetOutput("iptables-save | grep '%s'" %rule)
        self.assertEqual(cmd_result[0], 0)
        # Clean up so the test leaves the firewall unchanged.
        waagent.Run("iptables -D %s" %rule)

    def test_del_rule_if_exists(self):
        # Insert the rule manually, delete it through the helper, and
        # verify grep no longer finds it (non-zero exit code).
        rule = 'INPUT -p tcp -m tcp --dport 9998 -j DROP'
        waagent.Run("iptables -I %s" %rule)
        vmaccess._del_rule_if_exists(rule)
        cmd_result = waagent.RunGetOutput("iptables-save | grep '%s'" %rule)
        self.assertNotEqual(cmd_result[0], 0)


if __name__ == '__main__':
    unittest.main()

================================================
FILE: VMAccess/test/test_reset_account.py
================================================
#!/usr/bin/env python
#
#CustomScript extension
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import unittest from VMAccess import vmaccess from Utils.WAAgentUtil import waagent waagent.LoggerInit('/tmp/test.log','/dev/stdout') waagent.MyDistro = waagent.GetMyDistro() class Dummy(object): pass hutil = Dummy() hutil.log = waagent.Log class TestCreateNewAccount(unittest.TestCase): def test_creat_newuser(self): settings={} settings['username'] = 'NewUser' settings['password'] = 'User@123' waagent.Run('userdel %s' %settings['username']) vmaccess._set_user_account_pub_key(settings, hutil) waagent.Run("echo 'exit' > /tmp/exit.sh") cmd_result = waagent.RunGetOutput("sshpass -p 'User@123' ssh -o StrictHostKeyChecking=no" + " %s@localhost < /tmp/exit.sh" %settings['username']) self.assertEqual(cmd_result[0], 0) waagent.Run("rm exit.sh -f") waagent.Run('userdel %s' %settings['username']) expected_cert_str = """\ -----BEGIN CERTIFICATE----- MIICOTCCAaICCQD7F0nb+GtpcTANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJh YjELMAkGA1UECAwCYWIxCzAJBgNVBAcMAmFiMQswCQYDVQQKDAJhYjELMAkGA1UE CwwCYWIxCzAJBgNVBAMMAmFiMREwDwYJKoZIhvcNAQkBFgJhYjAeFw0xNDA4MDUw ODIwNDZaFw0xNTA4MDUwODIwNDZaMGExCzAJBgNVBAYTAmFiMQswCQYDVQQIDAJh YjELMAkGA1UEBwwCYWIxCzAJBgNVBAoMAmFiMQswCQYDVQQLDAJhYjELMAkGA1UE AwwCYWIxETAPBgkqhkiG9w0BCQEWAmFiMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB iQKBgQC4Vugyj4uAKGYHW/D1eAg1DmLAv01e+9I0zIi8HzJxP87MXmS8EdG5SEzR N6tfQQie76JBSTYI4ngTaVCKx5dVT93LiWxLV193Q3vs/HtwwH1fLq0rAKUhREQ6 +CsRGNyeVfJkNsxAvNvQkectnYuOtcDxX5n/25eWAofobxVbSQIDAQABMA0GCSqG SIb3DQEBCwUAA4GBAF20gkq/DeUSXkZA+jjmmbCPioB3KL63GpoTXfP65d6yU4xZ TlMoLkqGKe3WoXmhjaTOssulgDAGA24IeWy/u7luH+oHdZEmEufFhj4M7tQ1pAhN CT8JCL2dI3F76HD6ZutTOkwRar3PYk5q7RsSJdAemtnwVpgp+RBMtbmct7MQ -----END CERTIFICATE----- """ class TestSaveCertFile(unittest.TestCase): def test_save_cert_Str_as_file(self): cert_str = waagent.GetFileContents(os.path.join(waagent.LibDir, 'TEST.crt')) vmaccess._save_cert_str_as_file(cert_str, '/tmp/tmp.crt') saved_cert_str = waagent.GetFileContents('/tmp/tmp.crt') self.assertEqual(saved_cert_str, expected_cert_str) class 
TestResetSshKey(unittest.TestCase): def test_reset_ssh_key(self): settings={} settings['username'] = 'NewUser' settings['ssh_key'] = waagent.GetFileContents(os.path.join(waagent.LibDir, 'TEST.crt')) vmaccess._set_user_account_pub_key(settings, hutil) waagent.Run("echo 'exit' > /tmp/exit.sh") cmd_result = waagent.RunGetOutput("ssh -o StrictHostKeyChecking=no -i %s" %os.path.join(waagent.LibDir, 'TEST.prv') + " %s@localhost < /tmp/exit.sh" %settings['username']) self.assertEqual(cmd_result[0], 0) waagent.Run("rm exit.sh -f") waagent.Run('userdel %s' %settings['username']) class TestResetExistingUser(unittest.TestCase): def test_reset_existing_user(self): settings={} settings['username'] = 'ExistingUser' settings['password'] = 'User@123' waagent.Run('userdel %s' %settings['username']) waagent.Run('useradd %s' %settings['username']) waagent.MyDistro.changePass(settings['username'], "Quattro!") vmaccess._set_user_account_pub_key(settings, hutil) waagent.Run("echo 'exit' > /tmp/exit.sh") cmd_result = waagent.RunGetOutput("sshpass -p 'User@123' ssh -o StrictHostKeyChecking=no" + " %s@localhost < /tmp/exit.sh" %settings['username']) self.assertEqual(cmd_result[0], 0) waagent.Run("rm exit.sh -f") waagent.Run('userdel %s' %settings['username']) if __name__ == '__main__': unittest.main() ================================================ FILE: VMAccess/test/test_reset_sshd_config.py ================================================ #!/usr/bin/env python # #CustomScript extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import unittest from VMAccess.test import env from VMAccess import vmaccess import os from Utils.WAAgentUtil import waagent import shutil waagent.LoggerInit('/tmp/test.log','/dev/stdout') waagent.MyDistro = waagent.GetMyDistro() class Dummy(object): pass hutil = Dummy() hutil.log = waagent.Log class TestResetSshdConfig(unittest.TestCase): def test_reset_sshd_config(self): path = '/tmp/sshd_config' resources=os.path.join(env.root, 'resources') if(os.path.exists(path)): os.remove(path) if(os.path.isdir('resources')): shutil.rmtree('resources') shutil.copytree(resources, 'resources') vmaccess._reset_sshd_config(path) self.assertTrue(os.path.exists(path)) config = waagent.GetFileContents(path) self.assertFalse(config.startswith("#Default sshd_config")) os.remove(path) def test_backup_sshd_config(self): test_dir = '/tmp/test_vmaccess' path = os.path.join(test_dir, "old_sshd_config") if(not os.path.isdir(test_dir)): os.mkdir(test_dir) if(not os.path.exists(path)): waagent.Run("echo > %s" %path) vmaccess._backup_sshd_config(path) os.remove(path) files = os.listdir(test_dir) self.assertNotEqual(len(files), 0) shutil.rmtree(test_dir) if __name__ == '__main__': unittest.main() ================================================ FILE: VMAccess/vmaccess.py ================================================ #!/usr/bin/env python # # VMAccess extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import os import re import shutil import sys import tempfile import time import traceback import Utils.handlerutil2 as handler_util import Utils.logger as logger import Utils.extensionutils as ext_utils import Utils.distroutils as dist_utils import Utils.constants as constants import Utils.ovfutils as ovf_utils # Define global variables ExtensionShortName = 'VMAccess' BeginCertificateTag = '-----BEGIN CERTIFICATE-----' EndCertificateTag = '-----END CERTIFICATE-----' BeginSSHTag = '---- BEGIN SSH2 PUBLIC KEY ----' OutputSplitter = ';' SshdConfigPath = '/etc/ssh/sshd_config' SshdConfigBackupPath = '/var/cache/vmaccess/backup' # overwrite the default logger logger.global_shared_context_logger = logger.Logger('/var/log/waagent.log', '/dev/stdout') def get_os_name(): if os.path.isfile(constants.os_release): return ext_utils.get_line_starting_with("NAME", constants.os_release) elif os.path.isfile(constants.system_release): return ext_utils.get_file_contents(constants.system_release) return None def get_linux_agent_conf_filename(os_name): if os_name is not None: if re.search("coreos", os_name, re.IGNORECASE) or re.search("flatcar", os_name, re.IGNORECASE): return "/usr/share/oem/waagent.conf" return "/etc/waagent.conf" class ConfigurationProvider(object): """ Parse amd store key:values in waagent.conf """ def __init__(self, wala_config_file): self.values = dict() if not os.path.isfile(wala_config_file): logger.warning("Missing configuration in {0}, setting default values for PasswordCryptId and PasswordCryptSaltLength".format(wala_config_file)) self.values["Provisioning.PasswordCryptId"] = "6" self.values["Provisioning.PasswordCryptSaltLength"] = 10 return try: for line in ext_utils.get_file_contents(wala_config_file).split('\n'): if not line.startswith("#") and "=" in line: parts = line.split()[0].split('=') value = parts[1].strip("\" ") if value != "None": 
self.values[parts[0]] = value else: self.values[parts[0]] = None # when get_file_contents returns none except AttributeError: logger.error("Unable to parse {0}".format(wala_config_file)) raise return def get(self, key): return self.values.get(key) def yes(self, key): config_value = self.get(key) if config_value is not None and config_value.lower().startswith("y"): return True else: return False def no(self, key): config_value = self.get(key) if config_value is not None and config_value.lower().startswith("n"): return True else: return False OSName = get_os_name() Configuration = ConfigurationProvider(get_linux_agent_conf_filename(OSName)) MyDistro = dist_utils.get_my_distro(Configuration, OSName) def main(): logger.log("%s started to handle." % ExtensionShortName) try: for a in sys.argv[1:]: if re.match("^([-/]*)(enable)", a): enable() except Exception as e: err_msg = "Failed with error: {0}, {1}".format(e, traceback.format_exc()) logger.error(err_msg) def enable(): hutil = handler_util.HandlerUtility() hutil.do_parse_context('Enable') try: hutil.exit_if_enabled(remove_protected_settings=True) # If no new seqNum received, exit. 
reset_ssh = None remove_user = None restore_backup_ssh = None protect_settings = hutil.get_protected_settings() if protect_settings: reset_ssh = protect_settings.get('reset_ssh', False) remove_user = protect_settings.get('remove_user') restore_backup_ssh = protect_settings.get('restore_backup_ssh', False) if remove_user and _is_sshd_config_modified(protect_settings): ext_utils.add_extension_event(name=hutil.get_name(), op=constants.WALAEventOperation.Enable, is_success=False, message="(03002)Argument error, conflicting operations") raise Exception("Cannot reset sshd_config and remove a user in one operation.") _forcibly_reset_chap(hutil) if reset_ssh or restore_backup_ssh: _open_ssh_port() hutil.log("Succeeded in check and open ssh port.") ext_utils.add_extension_event(name=hutil.get_name(), op="scenario", is_success=True, message="reset-ssh") _reset_sshd_config(hutil, restore_backup_ssh) hutil.log("Succeeded in {0} sshd_config.".format("resetting" if reset_ssh else "restoring")) if remove_user: ext_utils.add_extension_event(name=hutil.get_name(), op="scenario", is_success=True, message="remove-user") _remove_user_account(remove_user, hutil) _set_user_account_pub_key(protect_settings, hutil) if _is_sshd_config_modified(protect_settings): MyDistro.restart_ssh_service() check_and_repair_disk(hutil) hutil.do_exit(0, 'Enable', 'success', '0', 'Enable succeeded.') except Exception as e: hutil.error(("Failed to enable the extension with error: {0}, " "stack trace: {1}").format(str(e), traceback.format_exc())) hutil.do_exit(1, 'Enable', 'error', '0', "Enable failed: {0}".format(str(e))) def _forcibly_reset_chap(hutil): name = "ChallengeResponseAuthentication" _backup_and_update_sshd_config(hutil, name, "no") MyDistro.restart_ssh_service() def _is_sshd_config_modified(protected_settings): result = protected_settings.get('reset_ssh') or protected_settings.get('restore_backup_ssh') or protected_settings.get('password') return result is not None def 
_remove_user_account(user_name, hutil): hutil.log("Removing user account") try: sudoers = _get_other_sudoers(user_name) MyDistro.delete_account(user_name) _save_other_sudoers(sudoers) except Exception as e: ext_utils.add_extension_event(name=hutil.get_name(), op=constants.WALAEventOperation.Enable, is_success=False, message="(02102)Failed to remove user.") raise Exception("Failed to remove user {0}".format(e)) ext_utils.add_extension_event(name=hutil.get_name(), op=constants.WALAEventOperation.Enable, is_success=True, message="Successfully removed user") def _set_user_account_pub_key(protect_settings, hutil): ovf_env = None try: ovf_xml = ext_utils.get_file_contents('/var/lib/waagent/ovf-env.xml') if ovf_xml is not None: ovf_env = ovf_utils.OvfEnv.parse(ovf_xml, Configuration, False, False) except (EnvironmentError, ValueError, KeyError, AttributeError, TypeError): pass if ovf_env is None: # default ovf_env with empty data ovf_env = ovf_utils.OvfEnv() logger.log("could not load ovf-env.xml") # user name must be provided if set ssh key or password if not protect_settings or 'username' not in protect_settings: return user_name = protect_settings['username'] user_pass = protect_settings.get('password') cert_txt = protect_settings.get('ssh_key') expiration = protect_settings.get('expiration') remove_prior_keys = protect_settings.get('remove_prior_keys') enable_passwordless_access = protect_settings.get('enable_passwordless_access', False) no_convert = False if not user_pass and not cert_txt and not ovf_env.SshPublicKeys: raise Exception("No password or ssh_key is specified.") if user_pass is not None and len(user_pass) == 0: user_pass = None hutil.log("empty passwords are not allowed, ignoring password reset") # Reset user account and password, password could be empty sudoers = _get_other_sudoers(user_name) error_string = MyDistro.create_account( user_name, user_pass, expiration, None, enable_passwordless_access) _save_other_sudoers(sudoers) if error_string is not 
None: err_msg = "Failed to create the account or set the password" ext_utils.add_extension_event(name=hutil.get_name(), op=constants.WALAEventOperation.Enable, is_success=False, message="(02101)" + err_msg) raise Exception(err_msg + " with " + error_string) hutil.log("Succeeded in creating the account or setting the password.") # Allow password authentication if user_pass is provided if user_pass is not None: ext_utils.add_extension_event(name=hutil.get_name(), op="scenario", is_success=True, message="create-user-with-password") _allow_password_auth(hutil) # Reset ssh key with the new public key passed in or reuse old public key. if cert_txt: # support for SSH2-compatible format for public keys in addition to OpenSSH-compatible format if cert_txt.strip().startswith(BeginSSHTag): ext_utils.set_file_contents("temp.pub", cert_txt.strip()) retcode, output = ext_utils.run_command_get_output(['ssh-keygen', '-i', '-f', 'temp.pub']) if retcode > 0: raise Exception("Failed to convert SSH2 key to OpenSSH key.") hutil.log("Succeeded in converting SSH2 key to OpenSSH key.") cert_txt = output os.remove("temp.pub") if cert_txt.strip().lower().startswith("ssh-rsa") or cert_txt.strip().lower().startswith("ssh-ed25519"): no_convert = True try: pub_path = os.path.join('/home/', user_name, '.ssh', 'authorized_keys') ovf_env.UserName = user_name if no_convert: if cert_txt: pub_path = ovf_env.prepare_dir(pub_path, MyDistro) final_cert_txt = cert_txt if not cert_txt.endswith("\n"): final_cert_txt = final_cert_txt + "\n" if remove_prior_keys == True: ext_utils.set_file_contents(pub_path, final_cert_txt) hutil.log("Removed prior ssh keys and added new key for user %s" % user_name) else: ext_utils.append_file_contents(pub_path, final_cert_txt) MyDistro.set_se_linux_context( pub_path, 'unconfined_u:object_r:ssh_home_t:s0') ext_utils.change_owner(pub_path, user_name) ext_utils.add_extension_event(name=hutil.get_name(), op="scenario", is_success=True, message="create-user") 
                # NOTE(review): the enclosing `def` of this try/except starts above this chunk;
                # indentation below is reconstructed from the syntax of the collapsed dump.
                hutil.log("Succeeded in resetting ssh_key.")
            else:
                err_msg = "Failed to reset ssh key because the cert content is empty."
                ext_utils.add_extension_event(name=hutil.get_name(),
                                              op=constants.WALAEventOperation.Enable,
                                              is_success=False,
                                              message="(02100)" + err_msg)
        else:
            # do the certificate conversion
            # we support PKCS8 certificates besides ssh-rsa public keys
            _save_cert_str_as_file(cert_txt, 'temp.crt')
            pub_path = ovf_env.prepare_dir(pub_path, MyDistro)
            # Extract the public key out of the x509 certificate with openssl.
            retcode = ext_utils.run_command_and_write_stdout_to_file(
                [constants.Openssl, 'x509', '-in', 'temp.crt', '-noout', '-pubkey'], "temp.pub")
            if retcode > 0:
                raise Exception("Failed to generate public key file.")
            MyDistro.ssh_deploy_public_key('temp.pub', pub_path)
            os.remove('temp.pub')
            os.remove('temp.crt')
            ext_utils.add_extension_event(name=hutil.get_name(), op="scenario", is_success=True,
                                          message="create-user")
            hutil.log("Succeeded in resetting ssh_key.")
    except Exception as e:
        hutil.log(str(e))
        ext_utils.add_extension_event(name=hutil.get_name(),
                                      op=constants.WALAEventOperation.Enable,
                                      is_success=False,
                                      message="(02100)Failed to reset ssh key.")
        raise e


def _get_other_sudoers(user_name):
    """Return the lines of /etc/sudoers.d/waagent that do NOT belong to
    user_name, or None if the file does not exist."""
    sudoers_file = '/etc/sudoers.d/waagent'
    if not os.path.isfile(sudoers_file):
        return None
    sudoers = ext_utils.get_file_contents(sudoers_file).split("\n")
    # Drop every line that starts with "<user_name><whitespace>".
    pattern = '^{0}\s'.format(user_name)
    sudoers = list(filter(lambda x: re.match(pattern, x) is None, sudoers))
    return sudoers


def _save_other_sudoers(sudoers):
    """Append the given sudoers lines back to /etc/sudoers.d/waagent.

    No-op when sudoers is None (the file did not exist). Mode 0440 is the
    permission sudo requires for files under /etc/sudoers.d.
    """
    sudoers_file = '/etc/sudoers.d/waagent'
    if sudoers is None:
        return
    ext_utils.append_file_contents(sudoers_file, "\n".join(sudoers))
    os.chmod("/etc/sudoers.d/waagent", 0o440)


def _allow_password_auth(hutil):
    """Enable PasswordAuthentication in sshd_config and, when present, in
    cloud-init's sshd drop-in so it cannot override the main config."""
    name = "PasswordAuthentication"
    _backup_and_update_sshd_config(hutil, name, "yes")
    # Also patch the cloud-init drop-in, which would otherwise win over
    # the value we just wrote to the main sshd_config.
    cloudInitConfigPath = "/etc/ssh/sshd_config.d/50-cloud-init.conf"
    config = ext_utils.get_file_contents(cloudInitConfigPath)
    if config is not None:
        config = config.split("\n")
        _set_sshd_config(config, name, "yes")
        ext_utils.replace_file_with_contents_atomic(cloudInitConfigPath, "\n".join(config))


def _backup_and_update_sshd_config(hutil, attr_name, attr_value):
    """Set attr_name to attr_value in sshd_config, taking a backup first.

    Skips the write entirely when the attribute already carries the value.
    """
    config = ext_utils.get_file_contents(SshdConfigPath).split("\n")
    for i in range(0, len(config)):
        if config[i].startswith(attr_name) and attr_value in config[i].lower():
            hutil.log("%s already set to %s in sshd_config, skip update." % (attr_name, attr_value))
            return
    hutil.log("Setting %s to %s in sshd_config." % (attr_name, attr_value))
    _backup_sshd_config(hutil)
    _set_sshd_config(config, attr_name, attr_value)
    ext_utils.replace_file_with_contents_atomic(SshdConfigPath, "\n".join(config))


def _set_sshd_config(config, name, val):
    """Set "name val" in the in-memory sshd_config line list.

    Replaces every existing "name ..." line; otherwise inserts the setting
    before the first "Match" block (sshd treats everything after a Match
    line as conditional). Mutates and returns config.
    """
    notfound = True
    i = None
    for i in range(0, len(config)):
        if config[i].startswith(name):
            config[i] = "{0} {1}".format(name, val)
            notfound = False
        elif config[i].startswith("Match"):
            # Match block must be put in the end of sshd config
            break
    if notfound:
        if i is None:
            # config was empty; insert at the top.
            i = 0
        config.insert(i, "{0} {1}".format(name, val))
    return config


def _get_default_ssh_config_filename():
    """Map the detected OS name to the bundled default sshd_config filename."""
    if OSName is not None:
        # the default ssh config files are present in
        # /var/lib/waagent/Microsoft.OSTCExtensions.VMAccessForLinux-<version>/resources/
        if re.search("centos", OSName, re.IGNORECASE):
            return "centos_default"
        if re.search("debian", OSName, re.IGNORECASE):
            return "debian_default"
        if re.search("fedora", OSName, re.IGNORECASE):
            return "fedora_default"
        if re.search("red\s?hat", OSName, re.IGNORECASE):
            return "redhat_default"
        if re.search("suse", OSName, re.IGNORECASE):
            return "SuSE_default"
        if re.search("ubuntu", OSName, re.IGNORECASE):
            return "ubuntu_default"
    return "default"


def _reset_sshd_config(hutil, restore_backup_ssh):
    """Replace sshd_config with a known-good file and restart sshd.

    The source file is, in order of preference: the previously taken backup
    (when restore_backup_ssh is set and the backup exists), the distro-specific
    default shipped in resources/, or the generic resources/default.
    """
    ssh_default_config_filename = _get_default_ssh_config_filename()
    ssh_default_config_file_path = os.path.join(os.getcwd(), 'resources', ssh_default_config_filename)
    if not os.path.exists(ssh_default_config_file_path):
        ssh_default_config_file_path = os.path.join(os.getcwd(), 'resources', 'default')
    if restore_backup_ssh:
        if os.path.exists(SshdConfigBackupPath):
            ssh_default_config_file_path = SshdConfigBackupPath
    # handle CoreOS differently
    if isinstance(MyDistro, dist_utils.CoreOSDistro):
        # Parse sshd port from ssh_default_config_file_path
        sshd_port = 22
        regex = re.compile(r"^Port\s+(\d+)", re.VERBOSE)
        with open(ssh_default_config_file_path) as f:
            for line in f:
                match = regex.match(line)
                if match:
                    sshd_port = match.group(1)
                    break
        # Prepare cloud init config for coreos-cloudinit
        f = tempfile.NamedTemporaryFile(delete=False)
        f.close()
        cfg_tempfile = f.name
        cfg_content = "#cloud-config\n\n"
        # Overwrite /etc/ssh/sshd_config
        cfg_content += "write_files:\n"
        cfg_content += " - path: {0}\n".format(SshdConfigPath)
        cfg_content += " permissions: 0600\n"
        cfg_content += " owner: root:root\n"
        cfg_content += " content: |\n"
        for line in ext_utils.get_file_contents(ssh_default_config_file_path).split('\n'):
            cfg_content += " {0}\n".format(line)
        # Change the sshd port in /etc/systemd/system/sshd.socket
        cfg_content += "\ncoreos:\n"
        cfg_content += " units:\n"
        cfg_content += " - name: sshd.socket\n"
        cfg_content += " command: restart\n"
        cfg_content += " content: |\n"
        cfg_content += " [Socket]\n"
        cfg_content += " ListenStream={0}\n".format(sshd_port)
        cfg_content += " Accept=yes\n"
        ext_utils.set_file_contents(cfg_tempfile, cfg_content)
        ext_utils.run(['coreos-cloudinit', '-from-file', cfg_tempfile], chk_err=False)
        os.remove(cfg_tempfile)
    else:
        shutil.copyfile(ssh_default_config_file_path, SshdConfigPath)
        if ssh_default_config_file_path == SshdConfigBackupPath:
            hutil.log("sshd_config restored from backup, remove backup file.")
            # Remove backup config once sshd_config restored
            os.remove(ssh_default_config_file_path)
    MyDistro.restart_ssh_service()


def _backup_sshd_config(hutil):
    """Take a one-time backup of sshd_config, preserving mode and ownership.

    Does nothing when the backup already exists, so the backup always
    reflects the state before the first VMAccess modification.
    """
    if os.path.exists(SshdConfigPath) and not os.path.exists(SshdConfigBackupPath):
        # Create VMAccess cache folder if doesn't exist
        if not os.path.exists(os.path.dirname(SshdConfigBackupPath)):
            os.makedirs(os.path.dirname(SshdConfigBackupPath))
        hutil.log("Create backup ssh config file")
        open(SshdConfigBackupPath, 'a').close()
        # When copying, make sure to preserve permissions and ownership.
        ownership = os.stat(SshdConfigPath)
        shutil.copy2(SshdConfigPath, SshdConfigBackupPath)
        os.chown(SshdConfigBackupPath, ownership.st_uid, ownership.st_gid)


def _save_cert_str_as_file(cert_txt, file_name):
    """Normalize a PEM certificate string and write it to file_name.

    Strips anything outside the BEGIN/END certificate tags and re-wraps the
    body between fresh tags so downstream openssl calls get clean PEM input.
    """
    cert_start = cert_txt.find(BeginCertificateTag)
    if cert_start >= 0:
        cert_txt = cert_txt[cert_start + len(BeginCertificateTag):]
    cert_end = cert_txt.find(EndCertificateTag)
    if cert_end >= 0:
        cert_txt = cert_txt[:cert_end]
    cert_txt = cert_txt.strip()
    cert_txt = "{0}\n{1}\n{2}\n".format(BeginCertificateTag, cert_txt, EndCertificateTag)
    ext_utils.set_file_contents(file_name, cert_txt)


def _open_ssh_port():
    """Remove iptables rules blocking ssh (tcp/22) and insert ACCEPT rules."""
    _del_rule_if_exists(['INPUT', '-p', 'tcp', '-m', 'tcp', '--dport', '22', '-j', 'DROP'])
    _del_rule_if_exists(['INPUT', '-p', 'tcp', '-m', 'tcp', '--dport', '22', '-j', 'REJECT'])
    # NOTE(review): '-p' with no protocol argument looks malformed for
    # iptables -- these two deletes likely never match; confirm intent.
    _del_rule_if_exists(['INPUT', '-p', '-j', 'DROP'])
    _del_rule_if_exists(['INPUT', '-p', '-j', 'REJECT'])
    _insert_rule_if_not_exists(['INPUT', '-p', 'tcp', '-m', 'tcp', '--dport', '22', '-j', 'ACCEPT'])
    _del_rule_if_exists(['OUTPUT', '-p', 'tcp', '-m', 'tcp', '--sport', '22', '-j', 'DROP'])
    _del_rule_if_exists(['OUTPUT', '-p', 'tcp', '-m', 'tcp', '--sport', '22', '-j', 'REJECT'])
    _del_rule_if_exists(['OUTPUT', '-p', '-j', 'DROP'])
    _del_rule_if_exists(['OUTPUT', '-p', '-j', 'REJECT'])
    # NOTE(review): OUTPUT ACCEPT uses --dport while the deletes above use
    # --sport; presumably --sport was intended here -- confirm.
    _insert_rule_if_not_exists(['OUTPUT', '-p', 'tcp', '-m', 'tcp', '--dport', '22', '-j', 'ACCEPT'])


def _del_rule_if_exists(rule_string):
    """Delete every instance of the given iptables rule.

    Loops because the same rule can appear multiple times; each pass deletes
    one instance and re-checks iptables-save output.
    """
    rule_string_for_cmp = " ".join(rule_string)
    cmd_result = ext_utils.run_command_get_output(['iptables-save'])
    while cmd_result[0] == 0 and (rule_string_for_cmp in cmd_result[1]):
        ext_utils.run(['iptables', '-D'] + rule_string)
        cmd_result = ext_utils.run_command_get_output(['iptables-save'])


def _insert_rule_if_not_exists(rule_string):
    """Insert the given iptables rule unless it is already present."""
    rule_string_for_cmp = " ".join(rule_string)
    cmd_result = ext_utils.run_command_get_output(['iptables-save'])
    if cmd_result[0] == 0 and (rule_string_for_cmp not in cmd_result[1]):
        ext_utils.run_command_get_output(['iptables', '-I'] + rule_string)


def check_and_repair_disk(hutil):
    """Run fsck check or repair based on the extension's public settings.

    'check_disk' and 'repair_disk' are mutually exclusive; specifying both
    is a fatal configuration error.
    """
    public_settings = hutil.get_public_settings()
    if public_settings:
        check_disk = public_settings.get('check_disk')
        repair_disk = public_settings.get('repair_disk')
        disk_name = public_settings.get('disk_name')
        if check_disk and repair_disk:
            err_msg = ("check_disk and repair_disk was both specified."
                       "Only one of them can be specified")
            hutil.error(err_msg)
            hutil.do_exit(1, 'Enable', 'error', '0', 'Enable failed.')
        if check_disk:
            ext_utils.add_extension_event(name=hutil.get_name(), op="scenario", is_success=True,
                                          message="check_disk")
            outretcode = _fsck_check(hutil)
            hutil.log("Successfully checked disk")
            return outretcode
        if repair_disk:
            ext_utils.add_extension_event(name=hutil.get_name(), op="scenario", is_success=True,
                                          message="repair_disk")
            outdata = _fsck_repair(hutil, disk_name)
            hutil.log("Repaired and remounted disk")
            return outdata


def _fsck_check(hutil):
    """Run a read-only fsck over all fstab filesystems; exit on failure."""
    try:
        retcode = ext_utils.run(['fsck', '-As', '-y'])
        if retcode > 0:
            hutil.log(retcode)
            raise Exception("Disk check was not successful")
        else:
            return retcode
    except Exception as e:
        hutil.error("Failed to run disk check with error: {0}, {1}".format(
            str(e), traceback.format_exc()))
        hutil.do_exit(1, 'Check', 'error', '0', 'Check failed.')


def _fsck_repair(hutil, disk_name):
    """Unmount disk_name, run repairing fsck, and return mount output.

    Exits the handler on any failure.
    """
    # first unmount disks and loop devices lazy + forced
    try:
        cmd_result = ext_utils.run(['umount', '-f', '/' + disk_name])
        if cmd_result != 0:
            # Fail fast
            hutil.log("Failed to unmount disk: %s" % disk_name)
        # run repair
        retcode = ext_utils.run(['fsck', '-AR', '-y'])
        hutil.log("Ran fsck with return code: %d" % retcode)
        if retcode == 0:
            retcode, output = ext_utils.run_command_get_output(["mount"])
            hutil.log(output)
            return output
        else:
            raise Exception("Failed to mount disks")
    except Exception as e:
        hutil.error("{0}, {1}".format(str(e), traceback.format_exc()))
        hutil.do_exit(1, 'Repair', 'error', '0', 'Repair failed.')


if __name__ == '__main__':
    main()

================================================ FILE: VMBackup/.gitignore ================================================
# VMBackup debughelper builds
debughelper/msft_snap_monit

================================================ FILE: VMBackup/HandlerManifest.json ================================================
[ { "handlerManifest": { "disableCommand": "main/handle.sh disable", "enableCommand": "main/handle.sh enable", "installCommand": "main/handle.sh install", "rebootAfterInstall": false, "reportHeartbeat": false, "uninstallCommand": "main/handle.sh uninstall", "updateCommand": "main/handle.sh update" }, "name": "MyBackupTestLinuxInt", "version": "1.0.9120.0" } ]

================================================ FILE: VMBackup/MANIFEST.in ================================================
include HandlerManifest.json handler.py
prune test

================================================ FILE: VMBackup/README.txt ================================================
VMBackup extension is used by Azure Backup service to provide application consistent backup for Linux VMs running in Azure. **Note:** This extension is not recommended to be installed outside Azure Backup service context. ## Deploying the extension to a VM This extension gets deployed as part of first scheduled backup of the VM post you configure VM for backup. You can configure VM to be backed up using [Azure Portal](https://docs.microsoft.com/azure/backup/quick-backup-vm-portal), [Azure PowerShell](https://docs.microsoft.com/azure/backup/quick-backup-vm-powershell) or Azure CLI(https://docs.microsoft.com/azure/backup/quick-backup-vm-cli).

================================================ FILE: VMBackup/VMBackup.pyproj ================================================
 Debug 2.0 {a09c7cdb-874f-4214-bab2-90f888eac208} . test\handle.py . .
VMBackup VMBackup true false true false 10.0 $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\Python Tools\Microsoft.PythonTools.targets ================================================ FILE: VMBackup/debughelper/README.md ================================================ # Diagnostic app for Snapshot Extensions ## What? This is a very experimental stage program to capture system level logs and metrics while a snapshot operation is in progress. This kind of data is often private to each customer and hence we have no plans right now of including this as part of the normal workflow. This tool will hopefully help us gather critical data to debug issues that are not solved with the normal logs we collect. ## Build - Please install [Go](https://go.dev/doc/install) and then - Copy this directory somewhere. Let's say: `/var/log/azure/Microsoft.Azure.RecoveryServices.VMSnapshotLinux/debughelper` - Change directory and build ```sh sudo su cd /var/log/azure/Microsoft.Azure.RecoveryServices.VMSnapshotLinux/debughelper # Make sure location of go binary is in PATH # If go was extracted to /usr/local/go - then run # # export PATH="/usr/local/go/bin:$PATH" # # put the export statement above in your ~/.profile or ~/.bashrc # or whatever`rc` file depending on your shell to persist across shell # restarts # # If you used your package manager to install `go` then # you, most likely, don't need to do any of this. go build ls -la ``` - On listing files you should see an entry called `msft_snap_monit` - That's it! 
## Run as part of a snapshot operation - Create or edit "/etc/azure/vmbackup.conf" - Add a new section `[Monitor]` to the file and - under that add the following options - `Run=yes` - `Strace=no` - `Location=/var/log/azure/Microsoft.Azure.RecoveryServices.VMSnapshotLinux/debughelper` - Of course if you want strace to also run while taking a snapshot enable the option - Your `/etc/azure/vmbackup.conf` file should look something like this after these changes ```text [Monitor] Run=yes Strace=no Location=/var/log/azure/Microsoft.Azure.RecoveryServices.VMSnapshotLinux/debughelper ``` - I've not included other sections that the file might already include so do not delete if there are any. - That's it! Now every time after a snapshot/restore point is taken a new folder will be created at `/var/log/azure/Microsoft.Azure.RecoveryServices.VMSnapshotLinux/debughelper/`. - The name of this new folder will be a `ULID` which can be sorted by time - Inside this directory you'll see several files called `cpu.log`, `mem.log`, `disk.log`, `strace.log` - It is pretty self explanatory what metrics/logs each file contains. I'll add a more detailed section about how to read the data in each file soon. For the curious, it's pretty much the data in the `/proc/*` files. ## Running manually When run manually, right now, it will keep running till it receives an OS Interrupt (Ctrl+c) after the binary has been executed. The default behavior (do `./msft_snap_monit --help` for all options) will log everything to a shared memory location (`/dev/shm/Microsoft.Azure.Snapshots.Diagnostics/`) and after it has been interrupted will move the log subdirectory (see section below) to the working directory - which by default is the current directory. 
```sh ./msft_snap_monit ``` You should see a lot of logs like: ``` 2023/10/19 12:05:34 [monitorCPU] -> sending new metric 2023/10/19 12:05:34 [logMem] -> received new metric 2023/10/19 12:05:34 [logMem] -> writing to log file 2023/10/19 12:05:34 [logDisk] -> received new metric 2023/10/19 12:05:34 [logDisk] -> writing to log file 2023/10/19 12:05:34 [logDisk] -> received new metric 2023/10/19 12:05:34 [logDisk] -> writing to log file 2023/10/19 12:05:34 [logDisk] -> received new metric 2023/10/19 12:05:34 [logDisk] -> writing to log file 2023/10/19 12:05:34 [logDisk] -> received new metric 2023/10/19 12:05:34 [logDisk] -> writing to log file 2023/10/19 12:05:34 [logDisk] -> received new metric 2023/10/19 12:05:34 [logDisk] -> writing to log file 2023/10/19 12:05:34 [logDisk] -> received new metric 2023/10/19 12:05:34 [logDisk] -> writing to log file 2023/10/19 12:05:34 [logDisk] -> received new metric 2023/10/19 12:05:34 [logDisk] -> writing to log file 2023/10/19 12:05:35 [monitorCPU] -> sending new metric 2023/10/19 12:05:35 [logMem] -> received new metric 2023/10/19 12:05:35 [logMem] -> writing to log file 2023/10/19 12:05:35 [logDisk] -> received new metric 2023/10/19 12:05:35 [logDisk] -> writing to log file 2023/10/19 12:05:35 [logDisk] -> received new metric ... and so on ``` Ignore this and in another terminal window: ```sh # go to the shared memory location. this directory is in memory so fsfreeze will # not affect it cd /dev/shm/Microsoft.Azure.Snapshots.Diagnostics ls -l ``` You should see a subdirectory here that looks something like `01H7J4WD653PA49Y2X3J1RVYHS`. This is a ULID (see the ULID section below). `cd` into it and list files. ```sh cd 01H7J4WD653PA49Y2X3J1RVYHS ls ``` Now you should see some `.log` files here. `tail` them to see data as its written: ```sh # tail the cpu file tail -f cpu.log # or tail all logs files tail -f *.log ``` ## ULID Each run will generate a fresh [ULID](https://github.com/ulid/spec). 
This ID is unique to this run and all associated logs will be stored in a subdirectory inside the working directory with the ID as its name. ULID has the nice property of encoding the Unix timestamp in the generated ID - so it will be easy later to make correlations based on time. ```sh go install github.com/oklog/ulid/v2/cmd/ulid@latest # Let's assume we have a ULID: 01H7J38F44J44RZ5CYYJHKMVHB ulid 01H7J38F44J44RZ5CYYJHKMVHB ``` The output should be ```sh Fri Aug 11 10:46:15.94 UTC 2023 ``` ### NB: Running this with strace enabled hasn't been completely tested yet - give it a whirl if you want. Of course please make sure strace is installed. You will need the PID of a running process. In one terminal run: ```sh watch ls -la /tmp ``` In another terminal run ```sh ps -ef | grep watch | grep -v grep | awk '{print $2}' ``` Let's say the process ID is: `35151` ```sh ./msft_snap_monit --strace --tracepid 35151 ``` ## Plan There are quite a few more resources to log and monitor like processes and known applications that conflict with snapshots like antiviruses or network monitors but the broader structure of the code should not need too many changes. Please test it out and open issues for bugs and feature requests that you think would help in debugging snapshots. 
### Thank You

================================================ FILE: VMBackup/debughelper/checkMounts.go ================================================
package main

type Mount struct{}

// Get all mounts - to check for noatime vs reltime
// NOTE(review): stub — always returns nil; mount collection not implemented yet.
func mounts() []Mount {
	return nil
}

================================================ FILE: VMBackup/debughelper/go.mod ================================================
module msft_snap_monit

go 1.21.0

require github.com/oklog/ulid/v2 v2.1.0

================================================ FILE: VMBackup/debughelper/go.sum ================================================
github.com/oklog/ulid/v2 v2.1.0 h1:+9lhoxAP56we25tyYETBBY1YLA2SaoLvUFgrP2miPJU=
github.com/oklog/ulid/v2 v2.1.0/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ=
github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o=

================================================ FILE: VMBackup/debughelper/main.go ================================================
package main

import (
	"bytes"
	"context"
	"flag"
	"fmt"
	"io"
	"log"
	"os"
	"os/exec"
	"os/signal"
	"runtime"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/oklog/ulid/v2"
)

// Command-line flags controlling where logs go and whether strace runs.
var (
	working_directory = flag.String(
		"wd",
		"./",
		"Location which this application will use for all it's processing and persisting data. Please make sure this location does not get frozen during a snapshot operation",
	)
	extension_command = flag.String("extcmd", "", "The command to execute extensions with")
	// NOTE(review): "Daignose" typo is in the runtime flag description string; left as-is.
	run_diagnosis = flag.Bool("diagnose", false, "Daignose the system")
	with_strace   = flag.Bool("strace", false, "The tool will run with strace enabled")
	strace_pid    = flag.Int64("tracepid", 0, "The PID to apply strace on")
	log_to_mem    = flag.Bool("logtomem", true, "Will temporarily log to memory before moving all log files to working directory")
)

// wrapErr prefixes err with the caller's function name and optional messages,
// joined by " -> ", to produce a poor-man's stack context.
func wrapErr(err error, msgs ...string) error {
	pc := make([]uintptr, 15)
	n := runtime.Callers(2, pc)
	frames := runtime.CallersFrames(pc[:n])
	frame, _ := frames.Next()
	src := frame.Function
	s := strings.Join(append([]string{src}, msgs...), " -> ")
	return fmt.Errorf("%s -> %s", s, err.Error())
}

// checkBinExistence reports whether binary c resolves via `which`.
func checkBinExistence(c string) bool {
	if len(c) == 0 {
		return false
	}
	cmd := exec.Command("which", c)
	bs, err := cmd.CombinedOutput()
	if err != nil {
		log.Println(wrapErr(err, "CombinedOutput failed"))
		return false
	}
	if cmd.ProcessState.ExitCode() != 0 {
		return false
	}
	if len(bs) == 0 {
		return false
	}
	return true
}

// checkSvcExistence reports whether a systemd service unit matching s exists,
// by piping `systemctl list-unit-files --type service` through `grep -e s`.
func checkSvcExistence(s string) bool {
	if len(s) == 0 {
		return false
	}
	cmd := exec.Command("systemctl", "list-unit-files", "--type", "service")
	cmd2 := exec.Command("grep", "-e", s)
	r, w := io.Pipe()
	cmd.Stdout = w
	cmd2.Stdin = r
	var b2 bytes.Buffer
	cmd2.Stdout = &b2
	cmd.Start()
	cmd2.Start()
	cmd.Wait()
	w.Close()
	cmd2.Wait()
	bs := b2.Bytes()
	if len(bs) == 0 {
		return false
	}
	return true
}

// envVarExists reports whether environment variable v is set and non-empty.
func envVarExists(v string) bool {
	return len(os.Getenv(v)) > 0
}

// databaseText formats the diagnosis warning for a detected database d.
func databaseText(d string) string {
	return fmt.Sprintf("Unsupported database detected: \"%s\". Please make sure the database is not in use during a snapshot operation. The heavy disk IO behavior of databases can conflict with disk freezing", d)
}

// avText formats the diagnosis warning for a detected antivirus a.
// NOTE(review): "Anitivirus" typo is part of the runtime message string; left as-is.
func avText(a string) string {
	return fmt.Sprintf("Anitivirus detected: \"%s\". Make sure no files, directories, or mountpoints are being scanned during a snapshot operation", a)
}

// diagnoseDbs returns warnings for databases detected on the host.
func diagnoseDbs() []string {
	dbreport := []string{}
	if checkSvcExistence("postgresql.service") {
		dbreport = append(dbreport, databaseText("PostgreSQL"))
	}
	if checkSvcExistence("mongod") {
		dbreport = append(dbreport, databaseText("MongoDB"))
	}
	if checkBinExistence("mysqld") || checkBinExistence("mysql") {
		dbreport = append(dbreport, databaseText("MySQL"))
	}
	return dbreport
}

// diagnoseAvs returns warnings for antivirus products detected on the host.
func diagnoseAvs() []string {
	clamAVExists := checkBinExistence("clamscan")
	bitDefenderExists := checkSvcExistence("bdsec*")
	avreport := []string{}
	if clamAVExists {
		avreport = append(avreport, avText("ClamAV"))
	}
	if bitDefenderExists {
		avreport = append(avreport, avText("Bitdefender"))
	}
	return avreport
}

// main parses flags, then either runs a one-shot diagnosis or starts the
// monitoring goroutines until interrupted (or auto-stopped after 20 minutes).
func main() {
	flag.Parse()
	opID := fmt.Sprintf("%s", ulid.Make())
	if *with_strace && *strace_pid == 0 {
		log.Printf("Cannot trace PID: 0")
		return
	}
	r, rf, err := NewRun(*working_directory, opID, *with_strace, *strace_pid, *log_to_mem)
	if err != nil {
		log.Println(wrapErr(err))
		return
	}
	defer rf.Close()
	if *run_diagnosis {
		lf := r.diagnose()
		log.Printf("Diagnosis has been written to:\n%s\n", lf)
		return
	}
	wg := sync.WaitGroup{}
	ctx, cancel := context.WithCancel(context.Background())
	wg.Add(1)
	go func() {
		defer wg.Done()
		r.monitor(ctx)
	}()
	inter := make(chan os.Signal, 1)
	// Auto kill after 20 minutes
	go func(inter chan os.Signal) {
		i := 0
		ticker := time.NewTicker(time.Second)
		for range ticker.C {
			i++
			if i >= (20 * 60) {
				break
			}
		}
		ticker.Stop()
		inter <- syscall.SIGINT
	}(inter)
	signal.Notify(inter, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
	<-inter
	cancel()
	wg.Wait()
}

================================================ FILE: VMBackup/debughelper/run.go ================================================
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"os"
	"os/exec"
	"path"
	"strconv"
	"strings"
	"sync"
	"time"
)

// Log file names created inside each run's working directory.
const LF_RUN = "run.log"
const LF_STRACE = "strace.log"
const LF_CPU = "cpu.log"
const LF_MEM = "mem.log"
const LF_DISK = "disk.log"
const LF_DIAG = "diagnosis.log"

// Run holds the state for one monitoring/diagnosis session.
type Run struct {
	wd       string      // working directory supplied by the user
	opID     string      // ULID naming this run's log subdirectory
	log      *log.Logger // run-scoped logger writing to LF_RUN
	strace   bool        // whether to attach strace
	tracePID int64       // PID for strace
	logToMem bool        // stage logs in shared memory first
	inMemDir string      // shared-memory staging root
}

// NewRun builds a Run, creates its (possibly in-memory) log directory, and
// opens the run log. Returns the Run, the open log file (caller closes it),
// or an error.
func NewRun(workingDir, opID string, with_strace bool, trace_pid int64, logToMem bool) (*Run, *os.File, error) {
	r := &Run{
		wd:       workingDir,
		opID:     opID,
		strace:   with_strace,
		tracePID: trace_pid,
		inMemDir: "/dev/shm/Microsoft.Azure.Snapshots.Diagnostics",
		logToMem: logToMem,
	}
	if r.logToMem {
		if err := os.MkdirAll(path.Join(r.inMemDir, r.opID), 0755); err != nil {
			return nil, nil, wrapErr(err, "os.MkdirAll failed")
		}
	}
	f, err := os.OpenFile(path.Join(r.workDir(), LF_RUN), os.O_CREATE, 0644)
	if err != nil {
		return nil, nil, wrapErr(err, "os.OpenFile failed")
	}
	r.log = log.New(f, "", log.Ldate|log.Ltime|log.LUTC)
	return r, f, nil
}

// workDir returns the per-run log directory: the shared-memory staging dir
// when logToMem is set, otherwise a subdirectory of wd named after opID.
func (r Run) workDir() string {
	p := path.Join(r.wd, r.opID)
	if r.logToMem {
		p = path.Join(r.inMemDir, r.opID)
	}
	return p
}

// startStrace attaches strace to tracePID (when enabled), writing output to
// LF_STRACE until ctx is cancelled.
func (r Run) startStrace(ctx context.Context) error {
	if !r.strace {
		return nil
	}
	if r.tracePID == 0 {
		return fmt.Errorf("empty process ID")
	}
	command := exec.CommandContext(ctx, "strace", "-t", "-p", fmt.Sprintf("%d", r.tracePID), "-f", "-o", path.Join(r.workDir(), LF_STRACE))
	_, err := command.CombinedOutput()
	if err != nil {
		r.log.Println(wrapErr(err, "CombinedOutput failed"))
	}
	return nil
}

// diagnose writes the antivirus and database reports to LF_DIAG, moves the
// staging directory into wd, and returns the final diagnosis file path.
func (r Run) diagnose() string {
	avreport := diagnoseAvs()
	dbreport := diagnoseDbs()
	logFile := path.Join(r.workDir(), LF_DIAG)
	f, err := os.OpenFile(logFile, os.O_CREATE, 0644)
	if err != nil {
		r.log.Println(wrapErr(err, "os.OpenFile failed"))
	}
	defer f.Close()
	l := ""
	if len(avreport) > 0 {
		l = l + "========== ANTIVIRUS ============\n\n"
		l = l + strings.Join(avreport, "\n\n")
	}
	f.WriteString(l)
	if len(l) > 0 {
		l = "\n\n\n"
	}
	if len(dbreport) > 0 {
		// NOTE(review): "DATABSES" typo is in the runtime output string; left as-is.
		l = l + "========== DATABSES ============\n\n"
		l = l + strings.Join(dbreport, "\n\n")
	}
	f.WriteString(l)
	f.WriteString("\n")
	r.persistInMemDir()
	return path.Join(r.wd, r.opID, LF_DIAG)
}

// LoadAvg is one sample of /proc/loadavg.
type LoadAvg struct {
	TS         int64  `json:"timestamp_millis"`
	One        string `json:"one"`
	Five       string `json:"five"`
	Fifteen    string `json:"fifteen"`
	SchedRatio string `json:"scheduled_ratio"`
	LP         string `json:"last_pid"`
}

// monitorCPU samples /proc/loadavg once per second and sends each sample on
// cpuStream; sends nil and exits when ctx is cancelled.
func (r Run) monitorCPU(ctx context.Context, cpuStream chan *LoadAvg) {
	// log.Println("[monitorCPU] -> Fired")
	ticker := time.NewTicker(time.Second)
	ctx1, cancel := context.WithCancel(ctx)
outer:
	for {
		select {
		case <-ctx.Done():
			cancel()
			ticker.Stop()
			cpuStream <- nil
			break outer
		case <-ticker.C:
			go func() {
				command := exec.CommandContext(ctx1, "cat", "/proc/loadavg")
				bs, err := command.CombinedOutput()
				if err != nil {
					r.log.Println(wrapErr(err, "CombinedOutput failed"))
				} else {
					fields := strings.Fields(strings.Trim(string(bs), " \n"))
					if len(fields) != 5 {
						r.log.Println(wrapErr(fmt.Errorf("/proc/loadavg returned invalid number of strings")))
					} else {
						la := LoadAvg{
							One:        fields[0],
							Five:       fields[1],
							Fifteen:    fields[2],
							SchedRatio: fields[3],
							LP:         fields[4],
							TS:         time.Now().UnixMilli(),
						}
						log.Println("[monitorCPU] -> sending new metric")
						cpuStream <- &la
					}
				}
			}()
		}
	}
}

// logCPU drains cpuStream, writing each sample as one JSON line to LF_CPU
// until ctx is cancelled.
func (r Run) logCPU(ctx context.Context, cpuStream chan *LoadAvg) error {
	f, err := os.Create(path.Join(r.workDir(), LF_CPU))
	if err != nil {
		return wrapErr(err, "os.Create failed")
	}
	// logger := log.New(f, "", log.Ldate|log.Ltime|log.LUTC)
	defer f.Close()
outer:
	for {
		select {
		case <-ctx.Done():
			break outer
		case lav := <-cpuStream:
			// log.Println("[logCPU] -> new metric received")
			bs, err := json.Marshal(lav)
			if err != nil {
				r.log.Println(wrapErr(err, "json.Marshal failed"))
			} else {
				// log.Println("[logCPU] -> writing to log file")
				f.WriteString(fmt.Sprintf("%s\n", string(bs)))
			}
		}
	}
	return nil
}

// Mem is one sample of selected /proc/meminfo counters (all in kB).
type Mem struct {
	TS           int64 `json:"timestamp_millis"`
	TotalKb      int64 `json:"total_kb"`
	AvailKb      int64 `json:"avail_kb"`
	FreeKb       int64 `json:"free_kb"`
	CachedKb     int64 `json:"cached_kb"`
	SwapCachedKb int64 `json:"swap_cached_kb"`
	SwapTotalKb  int64 `json:"swap_total_kb"`
	SwapFreeKb   int64 `json:"swap_free_kb"`
}

// monitorMem samples /proc/meminfo once per second, parses the fields it
// cares about, and sends each sample on memStream; sends nil on cancel.
func (r Run) monitorMem(ctx context.Context, memStream chan *Mem) {
	ticker := time.NewTicker(time.Second)
	ctx1, cancel := context.WithCancel(context.TODO())
outer:
	for {
		select {
		case <-ctx.Done():
			cancel()
			ticker.Stop()
			memStream <- nil
			break outer
		case <-ticker.C:
			command := exec.CommandContext(ctx1, "cat", "/proc/meminfo")
			bs, err := command.CombinedOutput()
			if err != nil {
				log.Println(wrapErr(err, "CombinedOutput failed"))
			} else {
				m := Mem{
					TS: time.Now().UnixMilli(),
				}
				flag := false
				for _, line := range strings.Split(string(bs), "\n") {
					if len(line) == 0 {
						continue
					}
					fields := strings.Fields(line)
					if len(fields) != 3 {
						continue
					}
					switch fields[0] {
					case "MemTotal:":
						v, err := strconv.Atoi(fields[1])
						if err != nil {
							log.Println("MemTotal conversion to int failed: ", err)
						} else {
							m.TotalKb = int64(v)
							flag = true
						}
					case "MemAvailable:":
						v, err := strconv.Atoi(fields[1])
						if err != nil {
							log.Println("MemAvailable conversion to int failed: ", err)
						} else {
							m.AvailKb = int64(v)
							flag = true
						}
					case "MemFree:":
						v, err := strconv.Atoi(fields[1])
						if err != nil {
							log.Println("MemFree conversion to int failed: ", err)
						} else {
							m.FreeKb = int64(v)
							flag = true
						}
					case "Cached:":
						v, err := strconv.Atoi(fields[1])
						if err != nil {
							log.Println("Cached Mem conversion to int failed: ", err)
						} else {
							m.CachedKb = int64(v)
							flag = true
						}
					case "SwapCached:":
						v, err := strconv.Atoi(fields[1])
						if err != nil {
							log.Println("SwapCached conversion to int failed: ", err)
						} else {
							m.SwapCachedKb = int64(v)
							flag = true
						}
					case "SwapTotal:":
						v, err := strconv.Atoi(fields[1])
						if err != nil {
							log.Println("SwapTotal conversion to int failed: ", err)
						} else {
							m.SwapTotalKb = int64(v)
							flag = true
						}
					case "SwapFree:":
						v, err := strconv.Atoi(fields[1])
						if err != nil {
							log.Println("SwapFree conversion to int failed: ", err)
						} else {
							m.SwapFreeKb = int64(v)
							flag = true
						}
					}
				}
				if flag {
					// log.Println("[monitorMem] -> sending new metric")
					memStream <- &m
				}
			}
		}
	}
}

// logMem drains memStream, writing each sample as one JSON line to LF_MEM
// until ctx is cancelled or a nil sentinel arrives.
func (r Run) logMem(ctx context.Context, memStream chan *Mem) error {
	// log.Println("[logMem] -> Fired")
	f, err := os.Create(path.Join(r.workDir(), LF_MEM))
	if err != nil {
		return wrapErr(err, "OpenFile failed")
	}
	defer f.Close()
outer:
	for {
		select {
		case <-ctx.Done():
			break outer
		case lav := <-memStream:
			if lav == nil {
				break outer
			}
			log.Println("[logMem] -> received new metric")
			bs, err := json.Marshal(lav)
			if err != nil {
				r.log.Println(wrapErr(err, "json.Marshal failed"))
			} else {
				log.Println("[logMem] -> writing to log file")
				f.WriteString(fmt.Sprintf("%s\n", string(bs)))
			}
		}
	}
	return nil
}

// DiskLog is one per-device sample of /proc/diskstats fields.
type DiskLog struct {
	TS                          int64  `json:"timestamp_millis"`
	MajorNum                    string `json:"major_num"`
	MinorNum                    string `json:"minor_num"`
	DeviceName                  string `json:"device_name"`
	ReadsCompleted              string `json:"reads_completed_successfully"`
	ReadsMerged                 string `json:"reads_merged"`
	SectorsRead                 string `json:"sectors_read"`
	TimeSpentReadingMs          string `json:"time_spent_reading_ms"`
	WritesCompleted             string `json:"writes_completed"`
	WriteMerged                 string `json:"writes_merged"`
	SectorsWritten              string `json:"sectors_written"`
	TimeSpentWritingMs          string `json:"time_spent_writing"`
	IosInProgress               string `json:"ios_currently_in_progress"`
	TimeSpentIosMs              string `json:"time_spent_doing_ios_ms"`
	WeightedTimeSpentDoingIosMs string `json:"weighted_time_spent_doing_ios_ms"`
	DiscardsCompleted           string `json:"discards_completed_successfully"`
	DiscardsMerged              string `json:"discards_merged"`
	SectorsDiscarded            string `json:"sectors_discarded"`
	TimeSpentDiscardingMs       string `json:"time_sspent_discarding"`
	FlushRequestsCompleted      string `json:"flush_requests_completed_successfully"`
	TimeSpentFlushingMs         string `json:"time_spent_flushing"`
}

// monitorDisk samples /proc/diskstats once per second, emitting one DiskLog
// per sd*/nvme* device on diskChan; sends nil on cancel.
func (r Run) monitorDisk(ctx context.Context, diskChan chan *DiskLog) {
	ticker := time.NewTicker(time.Second)
	ctx1, cancel := context.WithCancel(context.TODO())
outer:
	for {
		select {
		case <-ctx.Done():
			cancel()
			ticker.Stop()
			diskChan <- nil
			break outer
		case <-ticker.C:
			command := exec.CommandContext(ctx1, "cat", "/proc/diskstats")
			bs, err := command.CombinedOutput()
			if err != nil {
				log.Println(wrapErr(err, "CombinedOutput failed"))
				continue outer
			}
			for _, line := range strings.Split(string(bs), "\n") {
				fields := strings.Fields(line)
				// Get only sata or nvme disks
				if len(fields) == 0 {
					continue
				}
				// NOTE(review): indexes fields[2] after only a len==0 check;
				// a line with <3 fields would panic — confirm /proc/diskstats
				// always yields >=14 fields per non-empty line.
				if !strings.Contains(fields[2], "sd") && !strings.Contains(fields[2], "nvme") {
					continue
				}
				dl := DiskLog{
					TS:                          time.Now().UnixMilli(),
					MajorNum:                    fields[0],
					MinorNum:                    fields[1],
					DeviceName:                  fields[2],
					ReadsCompleted:              fields[3],
					ReadsMerged:                 fields[4],
					SectorsRead:                 fields[5],
					TimeSpentReadingMs:          fields[6],
					WritesCompleted:             fields[7],
					WriteMerged:                 fields[8],
					SectorsWritten:              fields[9],
					TimeSpentWritingMs:          fields[10],
					IosInProgress:               fields[11],
					TimeSpentIosMs:              fields[12],
					WeightedTimeSpentDoingIosMs: fields[13],
				}
				lf := len(fields)
				// Kernel 4.18+ will have the following fields
				if lf >= 18 {
					dl.DiscardsCompleted = fields[14]
					dl.DiscardsMerged = fields[15]
					dl.SectorsDiscarded = fields[16]
					dl.TimeSpentDiscardingMs = fields[17]
				}
				// Kernel 5.5+ further have the following fields
				if lf >= 20 {
					dl.FlushRequestsCompleted = fields[18]
					dl.TimeSpentFlushingMs = fields[19]
				}
				diskChan <- &dl
			}
		}
	}
}

// logDisk drains diskChan, writing each sample as one JSON line to LF_DISK
// until ctx is cancelled or a nil sentinel arrives.
func (r Run) logDisk(ctx context.Context, diskChan chan *DiskLog) error {
	f, err := os.Create(path.Join(r.workDir(), LF_DISK))
	if err != nil {
		return wrapErr(err, "os.Create failed")
	}
	defer f.Close()
outer:
	for {
		select {
		case <-ctx.Done():
			break outer
		case lav := <-diskChan:
			if lav == nil {
				break outer
			}
			log.Println("[logDisk] -> received new metric")
			bs, err := json.Marshal(lav)
			if err != nil {
				r.log.Println(wrapErr(err, "json.Marshal failed"))
			} else {
				log.Println("[logDisk] -> writing to log file")
				f.WriteString(fmt.Sprintf("%s\n", string(bs)))
			}
		}
	}
	return nil
}

// persistInMemDir moves the shared-memory staging directory into the user's
// working directory; no-op when logging directly to disk.
func (r Run) persistInMemDir() {
	// log.Println("[persistInMemDir] -> Fired")
	if !r.logToMem {
		return
	}
	log.Printf("moving: \"%s\" to \"%s\"\n", r.workDir(), r.wd)
	cmd := exec.Command("mv", r.workDir(), fmt.Sprintf("%s/", r.wd))
	if _, err := cmd.CombinedOutput(); err != nil {
		r.log.Println(wrapErr(err, fmt.Sprintf("moving from shared memory to path: \"%s\" failed", r.wd)))
	}
}

// monitor spawns the strace, CPU, memory, and disk monitor/logger goroutine
// pairs, writes a pid file, and on ctx cancellation tears everything down
// and persists the staged logs.
func (r Run) monitor(ctx context.Context) {
	// log.Println("[monitor] -> Fired")
	wg := sync.WaitGroup{}
	// save pid file
	pf, err := os.Create(path.Join(r.workDir(), "monitor.pid"))
	if err != nil {
		r.log.Println("error creating pid file")
		return
	}
	pf.WriteString(fmt.Sprintf("%d", os.Getpid()))
	// strace
	tctx, tcancel := context.WithCancel(ctx)
	wg.Add(1)
	go func() {
		defer wg.Done()
		r.startStrace(tctx)
	}()
	// CPU ==============================
	cpuStream := make(chan *LoadAvg, 1)
	cpuCtx, cpuCancel := context.WithCancel(ctx)
	logCpuCtx, logCpuCancel := context.WithCancel(ctx)
	wg.Add(1)
	go func() {
		defer wg.Done()
		r.monitorCPU(cpuCtx, cpuStream)
	}()
	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := r.logCPU(logCpuCtx, cpuStream); err != nil {
			r.log.Println(wrapErr(err))
		}
	}()
	// RAM
	memStream := make(chan *Mem, 1)
	memCtx, memCancel := context.WithCancel(ctx)
	logMemCtx, logMemCancel := context.WithCancel(ctx)
	wg.Add(1)
	go func() {
		defer wg.Done()
		r.monitorMem(memCtx, memStream)
	}()
	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := r.logMem(logMemCtx, memStream); err != nil {
			r.log.Println(wrapErr(err))
		}
	}()
	// Disk
	diskChan := make(chan *DiskLog, 20)
	diskCtx, diskCancel := context.WithCancel(ctx)
	logDiskCtx, logDiskCancel := context.WithCancel(ctx)
	wg.Add(1)
	go func() {
		defer wg.Done()
		r.monitorDisk(diskCtx, diskChan)
	}()
	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := r.logDisk(logDiskCtx, diskChan); err != nil {
			r.log.Println(wrapErr(err))
		}
	}()
	<-ctx.Done()
	tcancel()
	cpuCancel()
	logCpuCancel()
	memCancel()
	logMemCancel()
	diskCancel()
	logDiskCancel()
	wg.Wait()
	r.persistInMemDir()
}

================================================ FILE: VMBackup/main/ExtensionErrorCodeHelper.py ================================================
from Utils import Status


class ExtensionErrorCodeEnum():
    # Terminal / success states.
    success_appconsistent = 0
    success = 1
    error = 2
    SuccessAlreadyProcessedInput = 3
    ExtensionTempTerminalState = 4
    # Generic input/transport errors.
    error_parameter = 11
    error_12 = 12
    error_wrong_time = 13
    error_same_taskid = 14
    error_http_failure = 15
    FailedHandlerGuestAgentCertificateNotFound = 16
    #error_upload_status_blob = 16
    FailedInvalidDataDiskLunList = 17
    # Snapshot / freeze failures.
    FailedRetryableSnapshotFailedNoNetwork = 76
    FailedHostSnapshotRemoteServerError = 556
    FailedSnapshotLimitReached = 85
    FailedRetryableSnapshotRateExceeded = 173
    FailedRetryableSnapshotFailedRestrictedNetwork = 761
    FailedRetryableFsFreezeFailed = 121
    FailedRetryableFsFreezeTimeout = 122
    FailedRetryableUnableToOpenMount = 123
    FailedSafeFreezeBinaryNotFound = 124
    # Pre/post script plugin errors (3xx).
    FailedPrepostPreScriptFailed = 300
    FailedPrepostPostScriptFailed = 301
    FailedPrepostPreScriptNotFound = 302
    FailedPrepostPostScriptNotFound = 303
    FailedPrepostPluginhostConfigParsing = 304
    FailedPrepostPluginConfigParsing = 305
    FailedPrepostPreScriptPermissionError = 306
    FailedPrepostPostScriptPermissionError = 307
    FailedPrepostPreScriptTimeout = 308
    FailedPrepostPostScriptTimeout = 309
    FailedPrepostPluginhostPreTimeout = 310
    FailedPrepostPluginhostPostTimeout = 311
    FailedPrepostCheckSumMismatch = 312
    FailedPrepostPluginhostConfigNotFound = 313
    FailedPrepostPluginhostConfigPermissionError = 314
    FailedPrepostPluginhostConfigOwnershipError = 315
    FailedPrepostPluginConfigNotFound = 316
    FailedPrepostPluginConfigPermissionError = 317
    FailedPrepostPluginConfigOwnershipError = 318
    # Guest-agent timing errors.
    FailedGuestAgentInvokedCommandTooLate = 402
    # Workload (application-consistent) backup errors (5xx).
    FailedWorkloadPreError = 500
    FailedWorkloadConfParsingError = 501
    FailedWorkloadInvalidRole = 502
    FailedWorkloadInvalidWorkloadName = 503
    FailedWorkloadPostError = 504
    FailedWorkloadAuthorizationMissing = 505
    FailedWorkloadConnectionError = 506
    FailedWorkloadIPCDirectoryMissing = 507
    FailedWorkloadDatabaseStatusChanged = 508
    FailedWorkloadQuiescingError = 509
    FailedWorkloadQuiescingTimeout = 510
    FailedWorkloadDatabaseInNoArchiveLog = 511
    FailedWorkloadLogModeChanged = 512
class ExtensionErrorCodeHelper:
    """Maps ExtensionErrorCodeEnum values to VM health states and printable names."""

    # Error code -> VM health color reported to the service
    # (green = benign/success, yellow = degraded, red = failing).
    # NOTE(review): FailedHostSnapshotRemoteServerError (556) appears in
    # ExtensionErrorCodeNameDict below but has no entry here — confirm whether
    # a health mapping is intentionally omitted.
    ExtensionErrorCodeDict = {
        ExtensionErrorCodeEnum.success_appconsistent : Status.ExtVmHealthStateEnum.green,
        ExtensionErrorCodeEnum.success : Status.ExtVmHealthStateEnum.green,
        ExtensionErrorCodeEnum.ExtensionTempTerminalState : Status.ExtVmHealthStateEnum.green,
        ExtensionErrorCodeEnum.error : Status.ExtVmHealthStateEnum.green,
        ExtensionErrorCodeEnum.error_12 : Status.ExtVmHealthStateEnum.green,
        ExtensionErrorCodeEnum.SuccessAlreadyProcessedInput : Status.ExtVmHealthStateEnum.green,
        ExtensionErrorCodeEnum.FailedRetryableSnapshotRateExceeded : Status.ExtVmHealthStateEnum.green,
        ExtensionErrorCodeEnum.FailedInvalidDataDiskLunList : Status.ExtVmHealthStateEnum.green,

        ExtensionErrorCodeEnum.FailedSafeFreezeBinaryNotFound: Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedRetryableFsFreezeFailed : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedRetryableFsFreezeTimeout : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedRetryableUnableToOpenMount : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.error_parameter : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedHandlerGuestAgentCertificateNotFound : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPreScriptFailed : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPostScriptFailed : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPreScriptNotFound : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPostScriptNotFound : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPluginhostConfigParsing : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPluginConfigParsing : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPreScriptPermissionError : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPostScriptPermissionError : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPreScriptTimeout : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPostScriptTimeout : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPluginhostPreTimeout : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPluginhostPostTimeout : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostCheckSumMismatch : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPluginhostConfigNotFound : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPluginhostConfigPermissionError : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPluginhostConfigOwnershipError : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPluginConfigNotFound : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPluginConfigPermissionError : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedPrepostPluginConfigOwnershipError : Status.ExtVmHealthStateEnum.yellow,

        ExtensionErrorCodeEnum.error_http_failure : Status.ExtVmHealthStateEnum.red,
        ExtensionErrorCodeEnum.FailedRetryableSnapshotFailedRestrictedNetwork : Status.ExtVmHealthStateEnum.red,
        ExtensionErrorCodeEnum.FailedRetryableSnapshotFailedNoNetwork : Status.ExtVmHealthStateEnum.red,
        ExtensionErrorCodeEnum.FailedSnapshotLimitReached : Status.ExtVmHealthStateEnum.red,
        ExtensionErrorCodeEnum.FailedGuestAgentInvokedCommandTooLate : Status.ExtVmHealthStateEnum.red,

        ExtensionErrorCodeEnum.FailedWorkloadPreError : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedWorkloadConfParsingError : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedWorkloadInvalidRole : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedWorkloadInvalidWorkloadName : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedWorkloadPostError : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedWorkloadAuthorizationMissing : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedWorkloadConnectionError : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedWorkloadIPCDirectoryMissing : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedWorkloadDatabaseStatusChanged : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedWorkloadQuiescingError : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedWorkloadQuiescingTimeout : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedWorkloadDatabaseInNoArchiveLog : Status.ExtVmHealthStateEnum.yellow,
        ExtensionErrorCodeEnum.FailedWorkloadLogModeChanged : Status.ExtVmHealthStateEnum.yellow
        }

    # Error code -> symbolic name, used to build StatusCode strings for telemetry.
    ExtensionErrorCodeNameDict = {
        ExtensionErrorCodeEnum.success : "success",
        ExtensionErrorCodeEnum.success_appconsistent : "success_appconsistent",
        ExtensionErrorCodeEnum.ExtensionTempTerminalState : "ExtensionTempTerminalState",
        ExtensionErrorCodeEnum.error : "error",
        ExtensionErrorCodeEnum.error_12 : "error_12",
        ExtensionErrorCodeEnum.SuccessAlreadyProcessedInput : "SuccessAlreadyProcessedInput",
        ExtensionErrorCodeEnum.FailedInvalidDataDiskLunList : "FailedInvalidDataDiskLunList",
        ExtensionErrorCodeEnum.FailedRetryableFsFreezeFailed : "FailedRetryableFsFreezeFailed",
        ExtensionErrorCodeEnum.FailedRetryableFsFreezeTimeout : "FailedRetryableFsFreezeTimeout",
        ExtensionErrorCodeEnum.FailedRetryableUnableToOpenMount : "FailedRetryableUnableToOpenMount",
        ExtensionErrorCodeEnum.error_parameter : "error_parameter",
        ExtensionErrorCodeEnum.FailedHandlerGuestAgentCertificateNotFound : "FailedHandlerGuestAgentCertificateNotFound",
        ExtensionErrorCodeEnum.FailedSafeFreezeBinaryNotFound : "FailedSafeFreezeBinaryNotFound",
        ExtensionErrorCodeEnum.FailedPrepostPreScriptFailed : "FailedPrepostPreScriptFailed",
        ExtensionErrorCodeEnum.FailedPrepostPostScriptFailed : "FailedPrepostPostScriptFailed",
        ExtensionErrorCodeEnum.FailedPrepostPreScriptNotFound : "FailedPrepostPreScriptNotFound",
        ExtensionErrorCodeEnum.FailedPrepostPostScriptNotFound : "FailedPrepostPostScriptNotFound",
        ExtensionErrorCodeEnum.FailedPrepostPluginhostConfigParsing : "FailedPrepostPluginhostConfigParsing",
        ExtensionErrorCodeEnum.FailedPrepostPluginConfigParsing : "FailedPrepostPluginConfigParsing",
        ExtensionErrorCodeEnum.FailedPrepostPreScriptPermissionError : "FailedPrepostPreScriptPermissionError",
        ExtensionErrorCodeEnum.FailedPrepostPostScriptPermissionError : "FailedPrepostPostScriptPermissionError",
        ExtensionErrorCodeEnum.FailedPrepostPreScriptTimeout : "FailedPrepostPreScriptTimeout",
        ExtensionErrorCodeEnum.FailedPrepostPostScriptTimeout : "FailedPrepostPostScriptTimeout",
        ExtensionErrorCodeEnum.FailedPrepostPluginhostPreTimeout : "FailedPrepostPluginhostPreTimeout",
        ExtensionErrorCodeEnum.FailedPrepostPluginhostPostTimeout : "FailedPrepostPluginhostPostTimeout",
        ExtensionErrorCodeEnum.FailedPrepostCheckSumMismatch : "FailedPrepostCheckSumMismatch",
        ExtensionErrorCodeEnum.FailedPrepostPluginhostConfigNotFound : "FailedPrepostPluginhostConfigNotFound",
        ExtensionErrorCodeEnum.FailedPrepostPluginhostConfigPermissionError : "FailedPrepostPluginhostConfigPermissionError",
        ExtensionErrorCodeEnum.FailedPrepostPluginhostConfigOwnershipError : "FailedPrepostPluginhostConfigOwnershipError",
        ExtensionErrorCodeEnum.FailedPrepostPluginConfigNotFound : "FailedPrepostPluginConfigNotFound",
        ExtensionErrorCodeEnum.FailedPrepostPluginConfigPermissionError : "FailedPrepostPluginConfigPermissionError",
        ExtensionErrorCodeEnum.FailedPrepostPluginConfigOwnershipError : "FailedPrepostPluginConfigOwnershipError",
        ExtensionErrorCodeEnum.error_http_failure : "error_http_failure",
        ExtensionErrorCodeEnum.FailedRetryableSnapshotFailedRestrictedNetwork : "FailedRetryableSnapshotFailedRestrictedNetwork",
        ExtensionErrorCodeEnum.FailedRetryableSnapshotFailedNoNetwork : "FailedRetryableSnapshotFailedNoNetwork",
        ExtensionErrorCodeEnum.FailedHostSnapshotRemoteServerError : "FailedHostSnapshotRemoteServerError",
        ExtensionErrorCodeEnum.FailedSnapshotLimitReached : "FailedSnapshotLimitReached",
        ExtensionErrorCodeEnum.FailedGuestAgentInvokedCommandTooLate : "FailedGuestAgentInvokedCommandTooLate",
        ExtensionErrorCodeEnum.FailedWorkloadPreError : "FailedWorkloadPreError",
        ExtensionErrorCodeEnum.FailedWorkloadConfParsingError : "FailedWorkloadConfParsingError",
        ExtensionErrorCodeEnum.FailedWorkloadInvalidRole : "FailedWorkloadInvalidRole",
        ExtensionErrorCodeEnum.FailedWorkloadInvalidWorkloadName : "FailedWorkloadInvalidWorkloadName",
        ExtensionErrorCodeEnum.FailedWorkloadPostError : "FailedWorkloadPostError",
        ExtensionErrorCodeEnum.FailedWorkloadAuthorizationMissing : "FailedWorkloadAuthorizationMissing",
        ExtensionErrorCodeEnum.FailedWorkloadConnectionError : "FailedWorkloadConnectionError",
        ExtensionErrorCodeEnum.FailedWorkloadIPCDirectoryMissing : "FailedWorkloadIPCDirectoryMissing",
        ExtensionErrorCodeEnum.FailedWorkloadDatabaseStatusChanged : "FailedWorkloadDatabaseStatusChanged",
        ExtensionErrorCodeEnum.FailedWorkloadQuiescingError : "FailedWorkloadQuiescingError",
        ExtensionErrorCodeEnum.FailedWorkloadQuiescingTimeout : "FailedWorkloadQuiescingTimeout",
        ExtensionErrorCodeEnum.FailedWorkloadDatabaseInNoArchiveLog : "FailedWorkloadDatabaseInNoArchiveLog",
        ExtensionErrorCodeEnum.FailedWorkloadLogModeChanged : "FailedWorkloadLogModeChanged"
        }

    @staticmethod
    def StatusCodeStringBuilder(ExtErrorCodeEnum):
        """Return the ' StatusCode.<Name>,' fragment for the given error code."""
        return " StatusCode." + ExtensionErrorCodeHelper.ExtensionErrorCodeNameDict[ExtErrorCodeEnum] + ","

================================================
FILE: VMBackup/main/HttpUtil.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import time import datetime import traceback try: import httplib as httplibs except ImportError: import http.client as httplibs import shlex import subprocess import sys from common import CommonVariables from subprocess import * from Utils.WAAgentUtil import waagent import Utils.HandlerUtil import sys class HttpUtil(object): """description of class""" __instance = None """Singleton class initialization""" def __new__(cls, hutil): if(cls.__instance is None): hutil.log("Creating HttpUtil") cls.__instance = super(HttpUtil, cls).__new__(cls) Config = None cls.__instance.proxyHost = None cls.__instance.proxyPort = None try: waagent.MyDistro = waagent.GetMyDistro() Config = waagent.ConfigurationProvider(None) except Exception as e: errorMsg = "Failed to construct ConfigurationProvider, which may be due to the old code with error: %s, stack trace: %s" % (str(e), traceback.format_exc()) hutil.log(errorMsg) Config = None cls.__instance.logger = hutil if Config != None: cls.__instance.proxyHost = Config.get("HttpProxy.Host") cls.__instance.proxyPort = Config.get("HttpProxy.Port") cls.__instance.tmpFile = './tmp_file_FD76C85E-406F-4CFA-8EB0-CF18B123365C' else: cls.__instance.logger = hutil cls.__instance.logger.log("Returning HttpUtil") return cls.__instance """ snapshot also called this. so we should not write the file/read the file in this method. 
""" def CallUsingCurl(self,method,sasuri_obj,data,headers): header_str = "" for key, value in headers.iteritems(): header_str = header_str + '-H ' + '"' + str(key) + ':' + str(value) + '"' if(self.proxyHost == None or self.proxyPort == None): commandToExecute = 'curl --request PUT --connect-timeout 10 --data-binary @-' + ' ' + header_str + ' "' + sasuri_obj.scheme + '://' + sasuri_obj.hostname + sasuri_obj.path + '?' + sasuri_obj.query + '"' + ' -v' else: commandToExecute = 'curl --request PUT --connect-timeout 10 --data-binary @-' + ' ' + header_str + ' "' + sasuri_obj.scheme + '://' + sasuri_obj.hostname + sasuri_obj.path + '?' + sasuri_obj.query + '"'\ + '--proxy ' + self.proxyHost + ':' + self.proxyPort + ' -v' args =Utils.HandlerUtil.HandlerUtility.split(self.logger, commandToExecute.encode('ascii')) proc = Popen(args,stdin=subprocess.PIPE,stdout=subprocess.PIPE,stderr=subprocess.PIPE) proc.stdin.write(data) curlResult,err = proc.communicate() returnCode = proc.wait() self.logger.log("curl error is: " + str(err)) self.logger.log("curl return code is : " + str(returnCode)) # what if the curl is returned successfully, but the http response is # 403 if(returnCode == 0): return CommonVariables.success else: return CommonVariables.error_http_failure def Call(self, method, sasuri_obj, data, headers, fallback_to_curl = False): try: result, resp, errorMsg = self.HttpCallGetResponse(method, sasuri_obj, data, headers) self.logger.log("HttpUtil Call : result: " + str(result) + ", errorMsg: " + str(errorMsg)) if(result == CommonVariables.success and resp != None): self.logger.log("resp-header: " + str(resp.getheaders())) else: self.logger.log("Http connection response is None") responseBody = resp.read() self.logger.log(" resp status: " + str(resp.status)) if(responseBody is not None): self.logger.log("responseBody: " + (responseBody).decode('utf-8-sig')) if(resp.status == 200 or resp.status == 201): return CommonVariables.success else: return resp.status except Exception 
as e: errorMsg = "Failed to call http with error: %s, stack trace: %s" % (str(e), traceback.format_exc()) self.logger.log(errorMsg) if(fallback_to_curl): return self.CallUsingCurl(method,sasuri_obj,data,headers) else: return CommonVariables.error_http_failure def HttpCallGetResponse(self, method, sasuri_obj, data, headers , responseBodyRequired = False, isHostCall = False): result = CommonVariables.error_http_failure resp = None responeBody = "" errorMsg = None responseBody = None try: resp = None self.logger.log("Entered HttpCallGetResponse, isHostCall: " + str(isHostCall)) if(isHostCall or self.proxyHost == None or self.proxyPort != None): if(isHostCall): connection = httplibs.HTTPConnection(sasuri_obj.hostname, timeout = 40) # making call with port 80 to make it http call else: connection = httplibs.HTTPSConnection(sasuri_obj.hostname, timeout = 10) self.logger.log("Details of sas uri object hostname: " + str(sasuri_obj.hostname) + " path: " + str(sasuri_obj.path)) connection.request(method=method, url=(sasuri_obj.path + '?' + sasuri_obj.query), body=data, headers = headers) resp = connection.getresponse() if(responseBodyRequired): responeBody = resp.read().decode('utf-8-sig') connection.close() else: connection = httplibs.HTTPSConnection(self.proxyHost, self.proxyPort, timeout = 10) connection.set_tunnel(sasuri_obj.hostname, 443) # If proxy is used, full url is needed. path = "https://{0}:{1}{2}".format(sasuri_obj.hostname, 443, (sasuri_obj.path + '?' 
+ sasuri_obj.query)) connection.request(method=method, url=(path), body=data, headers=headers) resp = connection.getresponse() connection.close() result = CommonVariables.success except Exception as e: errorMsg = str(datetime.datetime.utcnow()) + " Failed to call http with error: %s, stack trace: %s" % (str(e), traceback.format_exc()) self.logger.log(errorMsg) if sys.version[0] == 2 and sys.version[1] == 6: self.CallUsingCurl(method,sasuri_obj,data,headers) if(responseBodyRequired): return result, resp, errorMsg, responeBody else: return result, resp, errorMsg ================================================ FILE: VMBackup/main/IaaSExtensionSnapshotService/README.md ================================================ The systemd process manages the lifecycle of the service, including starting, stopping, restarting, and keeping track of whether it’s running or not. A service is registered typically by placing its .service file in /etc/systemd/system/ To start: sudo systemctl start ExecStart defines what executable or script is launched when the service starts. Systemd spawns a child process to run the command in ExecStart Typically, daemons write their PID to a PID file themselves (e.g., /var/run/.pid) to track their process. 
# ================================================
# FILE: VMBackup/main/IaaSExtensionSnapshotService/SnapshotServiceConstants.py
# ================================================
class SnapshotServiceConstants:
    """Well-known names, host endpoints and timing knobs for the snapshot service."""
    service_name = "Microsoft.Azure.RecoveryServices.VMSnapshotLinux.service"
    config_section = 'IaaSExtensionSnapshotService'
    pid_file = "VMSnapshotLinux.pid"
    # Azure wire-server address; all xdisksvc endpoints hang off it.
    HOST_IP_ADDRESS = "168.63.129.16"
    GET_SNAPSHOT_REQUESTS_URI = "http://{0}/xdisksvc/snapshotrequest".format(HOST_IP_ADDRESS)
    START_SNAPSHOT_REQUESTS_URI = "http://{0}/xdisksvc/startsnapshots".format(HOST_IP_ADDRESS)
    END_SNAPSHOT_REQUESTS_URI = "http://{0}/xdisksvc/endsnapshots".format(HOST_IP_ADDRESS)
    SERVICE_POLLING_INTERVAL_IN_SECS = 300
    EXTENSION_TIMEOUT_IN_MINS = 10


# ================================================
# FILE: VMBackup/main/IaaSExtensionSnapshotService/SnapshotServiceContracts.py
# ================================================
import json


class GetSnapshotResponseBody:
    """Payload returned by the get-snapshot-requests poll."""

    def __init__(self, snapshotId, diskInfo=None, extensionSettings=None):
        self.snapshotId = snapshotId
        self.diskInfo = diskInfo
        self.extensionSettings = extensionSettings

    def convertToDictionary(self):
        """Flatten to a plain dict, recursing into diskInfo when present."""
        disk_dict = self.diskInfo.convertToDictionary() if self.diskInfo else None
        return {
            'snapshotId': self.snapshotId,
            'diskInfo': disk_dict,
            'extensionSettings': self.extensionSettings,
        }


class StartSnapshotHostResponseBody:
    """Host response to a start-snapshots call."""

    def __init__(self, snapshotId, error=None):
        self.snapshotId = snapshotId
        self.error = error

    def convertToDictionary(self):
        err_dict = self.error.convertToDictionary() if self.error else None
        return {'snapshotId': self.snapshotId, 'error': err_dict}


class StartSnapshotHostRequestBody:
    """Request body sent to the host to begin snapshots."""

    def __init__(self, snapshotId):
        self.snapshotId = snapshotId

    def serialize_to_json_string(self):
        """JSON-encode the dictionary form of this request."""
        return json.dumps(self.convertToDictionary())

    def convertToDictionary(self):
        return {'snapshotId': self.snapshotId}


class EndSnapshotHostRequestBody:
    """Request body sent to the host when snapshots finish."""

    def __init__(self, snapshotId, error=None, provisioningDetails=None):
        self.snapshotId = snapshotId
        self.error = error
        self.provisioningDetails = provisioningDetails

    def serialize_to_json_string(self):
        """JSON-encode the dictionary form of this request."""
        return json.dumps(self.convertToDictionary())

    def convertToDictionary(self):
        err_dict = self.error.convertToDictionary() if self.error else None
        return {
            'snapshotId': self.snapshotId,
            'error': err_dict,
            'provisioningDetails': self.provisioningDetails,
        }


class EndSnapshotHostResponseBody:
    """Host response to an end-snapshots call."""

    def __init__(self, snapshotId, error=None):
        self.snapshotId = snapshotId
        self.error = error

    def convertToDictionary(self):
        err_dict = self.error.convertToDictionary() if self.error else None
        return {'snapshotId': self.snapshotId, 'error': err_dict}


class Error:
    """Simple code/message error pair."""

    def __init__(self, code, message=None):
        self.code = code
        self.message = message

    def convertToDictionary(self):
        return {'code': self.code, 'message': self.message}


class DiskInfo:
    """Set of data disks involved in a snapshot, plus the OS-disk flag."""

    def __init__(self, dataDiskInfo=None, isOSDiskIncluded=False):
        # Normalize a missing list to an empty one.
        self.dataDiskInfo = dataDiskInfo or []
        self.isOSDiskIncluded = isOSDiskIncluded

    def convertToDictionary(self):
        disk_dicts = [disk.convertToDictionary() for disk in self.dataDiskInfo]
        return {'dataDiskInfo': disk_dicts, 'isOSDiskIncluded': self.isOSDiskIncluded}


class DataDiskInfo:
    """Addressing info (controller + LUN) for a single data disk."""

    def __init__(self, controllerType, controllerId, lunId):
        self.controllerType = controllerType
        self.controllerId = controllerId
        self.lunId = lunId

    def convertToDictionary(self):
        return {
            'controllerType': self.controllerType,
            'controllerId': self.controllerId,
            'lunId': self.lunId,
        }


class XDiskSvcError:
    """Error shape returned by the xdisksvc host endpoints."""

    def __init__(self, code, message=None):
        self.code = code
        self.message = message

    def convertToDictionary(self):
        return {'code': self.code, 'message': self.message}


class ProvisioningDetails:
    """Outcome details reported at end-snapshot time."""

    def __init__(self, code, vmHealthInfo=None, storageDetails=None, message=None):
        self.code = code
        self.vmHealthInfo = vmHealthInfo
        self.storageDetails = storageDetails
        self.message = message

    def convertToDictionary(self):
        health_dict = self.vmHealthInfo.convertToDictionary() if self.vmHealthInfo else None
        storage_dict = self.storageDetails.convertToDictionary() if self.storageDetails else None
        return {
            'code': self.code,
            'vmHealthInfo': health_dict,
            'storageDetails': storage_dict,
            'message': self.message,
        }


# ================================================
# FILE: VMBackup/main/IaaSExtensionSnapshotService/__init__.py
# ================================================
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ================================================
# FILE: VMBackup/main/IaaSExtensionSnapshotService/service_metadata.json
# ================================================
{
    "Unit": {
        "Description": "Long running Snapshot service for Microsoft Azure Restore Points",
        "After": "multi-user.target"
    },
    "Service": {
        "Type": "simple",
        "Restart": "always",
        "WorkingDirectory": "../..",
        "ExecStart": ["/usr/bin/env", "python", "main/IaaSExtensionSnapshotService/PollingService.py"]
    },
    "Install": {
        "WantedBy": "multi-user.target"
    }
}

# ================================================
# FILE: VMBackup/main/LogSeverity.json
# ================================================
{
    "EventLogLevel": 2
}

# ================================================
# FILE: VMBackup/main/MachineIdentity.py
# ================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import subprocess import xml import xml.dom.minidom class MachineIdentity: def __init__(self): self.store_identity_file = './machine_identity_FD76C85E-406F-4CFA-8EB0-CF18B123365C' def current_identity(self): identity = None file = None try: if os.path.exists("/var/lib/waagent/HostingEnvironmentConfig.xml"): file = open("/var/lib/waagent/HostingEnvironmentConfig.xml",'r') xmlText = file.read() dom = xml.dom.minidom.parseString(xmlText) deployment = dom.getElementsByTagName("Role") identity=deployment[0].getAttribute("guid") finally: if file != None: if file.closed == False: file.close() return identity def save_identity(self): file = None try: file = open(self.store_identity_file,'w') machine_identity = self.current_identity() if( machine_identity != None ): file.write(machine_identity) finally: if file != None: if file.closed == False: file.close() def stored_identity(self): identity_stored = None file = None try: if(os.path.exists(self.store_identity_file)): file = open(self.store_identity_file,'r') identity_stored = file.read() finally: if file != None: if file.closed == False: file.close() return identity_stored ================================================ FILE: VMBackup/main/PluginHost.py ================================================ import time import sys import os import threading import platform try: import ConfigParser as ConfigParsers except ImportError: import configparser as ConfigParsers from common import CommonVariables from pwd import getpwuid from stat import * import traceback # [pre_post] # "timeout" : (in seconds), # # .... 
other params ... # # "pluginName0" : "oracle_plugin", the python plugin file will have same name # "pluginPath0" : "/abc/xyz/" # "pluginConfigPath0" : "sdf/sdf/abcd.json" # # # errorcode policy # errorcode = 0 (CommonVariables.PrePost_PluginStatus_Successs), means success, script runs without error, warnings maybe possible # errorcode = 5 (CommonVariables.PrePost_PluginStatus_Timeout), means timeout # errorcode = 10 (CommonVariables.PrePost_PluginStatus_ConfigNotFound), config file not found # errorcode = process return code, means bash script encountered some other error, like 127 for script not found class PluginHostError(object): def __init__(self, errorCode, pluginName): self.errorCode = errorCode self.pluginName = pluginName def __str__(self): return 'Plugin :- ', self.pluginName , ' ErrorCode :- ' + str(self.errorCode) class PluginHostResult(object): def __init__(self): self.errors = [] self.anyScriptFailed = False self.continueBackup = True self.errorCode = 0 self.fileCode = [] self.filePath = [] def __str__(self): errorStr = '' for error in self.errors: errorStr += (str(error)) + '\n' errorStr += 'Final Error Code :- ' + str(self.errorCode) + '\n' errorStr += 'Any script Failed :- ' + str(self.anyScriptFailed) + '\n' errorStr += 'Continue Backup :- ' + str(self.continueBackup) + '\n' return errorStr class PluginHost(object): """ description of class """ def __init__(self, logger): self.logger = logger self.modulesLoaded = False self.configLocation = '/etc/azure/VMSnapshotPluginHost.conf' self.timeoutInSeconds = 1800 self.plugins = [] self.pluginName = [] self.noOfPlugins = 0 self.preScriptCompleted = [] self.preScriptResult = [] self.postScriptCompleted = [] self.postScriptResult = [] self.pollTime = 3 def pre_check(self): self.logger.log('Loading script modules now...',True,'Info') errorCode = CommonVariables.PrePost_PluginStatus_Success dobackup = True fsFreeze_on = True # NS-BSD is already hardened, no checks and no freeze if 'NS-BSD' in 
platform.system():
            return errorCode, dobackup, False
        if not os.path.isfile(self.configLocation):
            self.logger.log('Plugin host Config file does not exist in the location ' + self.configLocation, True)
            # fall back to the config shipped alongside the extension
            self.configLocation = './main/VMSnapshotPluginHost.conf'
        permissions = self.get_permissions(self.configLocation)
        if not os.path.isfile(self.configLocation):
            self.logger.log('Plugin host Config file does not exist in the location ' + self.configLocation, True)
            errorCode = CommonVariables.FailedPrepostPluginhostConfigNotFound
        # group/other digits must be 0 or 4 (read-only at most)
        elif not (int(permissions[1]) == 0 or int(permissions[1]) == 4) or not (int(permissions[2]) == 0 or int(permissions[2]) == 4):
            self.logger.log('Plugin host Config file does not have desired permissions', True, 'Error')
            errorCode = CommonVariables.FailedPrepostPluginhostConfigPermissionError
        elif not self.find_owner(self.configLocation) == 'root':
            self.logger.log('The owner of the Plugin host Config file ' + self.configLocation + ' is ' + self.find_owner(self.configLocation) + ' but not root', True, 'Error')
            # NOTE(review): reports PermissionError for an ownership failure even
            # though FailedPrepostPluginhostConfigOwnershipError exists — confirm intent.
            errorCode = CommonVariables.FailedPrepostPluginhostConfigPermissionError
        else:
            errorCode, dobackup, fsFreeze_on = self.load_modules()
        return errorCode, dobackup, fsFreeze_on

    def load_modules(self):
        # Imports all plugin modules using the information in config.json
        # and initializes basic class variables associated with each plugin
        # NOTE(review): local 'len' shadows the builtin within this method.
        len = 0
        errorCode = CommonVariables.PrePost_PluginStatus_Success
        dobackup = True
        fsFreeze_on = True
        try:
            self.logger.log('config file: ' + str(self.configLocation), True, 'Info')
            config = ConfigParsers.ConfigParser()
            config.read(self.configLocation)
            if (config.has_option('pre_post', 'timeoutInSeconds')):
                # configured timeout may only shrink the 1800s ceiling
                self.timeoutInSeconds = min(int(config.get('pre_post', 'timeoutInSeconds')), self.timeoutInSeconds)
            if (config.has_option('pre_post', 'numberOfPlugins')):
                len = int(config.get('pre_post', 'numberOfPlugins'))
            self.logger.log('timeoutInSeconds: ' + str(self.timeoutInSeconds), True, 'Info')
            self.logger.log('numberOfPlugins: ' + str(len), True, 'Info')
            while len > 0:
                # pluginName<i>/pluginPath<i>/pluginConfigPath<i> keys are indexed
                # by the count of successfully registered plugins.
                pname = config.get('pre_post', 'pluginName' + str(self.noOfPlugins))
                ppath = config.get('pre_post', 'pluginPath' + str(self.noOfPlugins))
                pcpath = config.get('pre_post', 'pluginConfigPath' + str(self.noOfPlugins))
                self.logger.log('Name of the Plugin is ' + pname, True)
                self.logger.log('Plugin config path is ' + pcpath, True)
                errorCode = CommonVariables.PrePost_PluginStatus_Success
                dobackup = True
                if os.path.isfile(pcpath):
                    permissions = self.get_permissions(pcpath)
                    # owner digit must be even (no write by group/other at all)
                    if (int(permissions[0]) % 2 == 1) or int(permissions[1]) > 0 or int(permissions[2]) > 0:
                        self.logger.log('Plugin Config file does not have desired permissions', True, 'Error')
                        errorCode = CommonVariables.FailedPrepostPluginConfigPermissionError
                    if not self.find_owner(pcpath) == 'root':
                        self.logger.log('The owner of the Plugin Config file ' + pcpath + ' is ' + self.find_owner(pcpath) + ' but not root', True, 'Error')
                        # NOTE(review): ownership failure again mapped to the
                        # PermissionError code rather than OwnershipError — confirm.
                        errorCode = CommonVariables.FailedPrepostPluginConfigPermissionError
                else:
                    self.logger.log('Plugin host file does not exist in the location ' + pcpath, True, 'Error')
                    errorCode = CommonVariables.FailedPrepostPluginConfigNotFound
                if(errorCode == CommonVariables.PrePost_PluginStatus_Success):
                    # import the plugin module from its configured path and
                    # register its ScriptRunner
                    sys.path.append(ppath)
                    plugin = __import__(pname)
                    self.plugins.append(plugin.ScriptRunner(logger=self.logger, name=pname, configPath=pcpath, maxTimeOut=self.timeoutInSeconds))
                    errorCode, dobackup, fsFreeze_on, self.pollTime = self.plugins[self.noOfPlugins].validate_scripts()
                    self.logger.log('Validate Scripts output: errorCode - {0} dobackup - {1} fsFreeze_on - {2} pollTime - {3}'.format(errorCode, dobackup, fsFreeze_on, self.pollTime), True)
                    self.noOfPlugins = self.noOfPlugins + 1
                    self.pluginName.append(pname)
                    self.preScriptCompleted.append(False)
                    self.preScriptResult.append(None)
                    self.postScriptCompleted.append(False)
                    self.postScriptResult.append(None)
                # NOTE(review): statement grouping reconstructed from a
                # whitespace-collapsed dump; registration statements are assumed
                # to sit inside the success branch, decrement at loop level.
                len = len - 1
            if self.noOfPlugins != 0:
                self.modulesLoaded = True
        except Exception as err:
            errMsg = 'Error in reading PluginHost config file : %s, stack trace: %s' % (str(err), traceback.format_exc())
            self.logger.log(errMsg, True, 'Error')
            errorCode = CommonVariables.FailedPrepostPluginhostConfigParsing
        return errorCode, dobackup, fsFreeze_on

    def find_owner(self, filename):
        """Return the username owning `filename`, or '' on error."""
        file_owner = ''
        try:
            file_owner = getpwuid(os.stat(filename).st_uid).pw_name
        except Exception as err:
            errMsg = 'Error in fetching owner of the file : ' + filename + ': %s, stack trace: %s' % (str(err), traceback.format_exc())
            self.logger.log(errMsg, True, 'Error')
        return file_owner

    def get_permissions(self, filename):
        """Return the octal permission triplet of `filename` as a string ('777' on error)."""
        permissions = '777'
        try:
            permissions = oct(os.stat(filename)[ST_MODE])[-3:]
            self.logger.log('Permissions of the file ' + filename + ' are ' + permissions, True)
        except Exception as err:
            errMsg = 'Error in fetching permissions of the file : ' + filename + ': %s, stack trace: %s' % (str(err), traceback.format_exc())
            self.logger.log(errMsg, True, 'Error')
        return permissions

    def pre_script(self):
        # Runs pre_script() for all plugins and maintains a timer
        result = PluginHostResult()
        curr = 0
        # one worker thread per plugin; completion/result arrays are shared
        for plugin in self.plugins:
            t1 = threading.Thread(target=plugin.pre_script, args=(curr, self.preScriptCompleted, self.preScriptResult))
            t1.start()
            curr = curr + 1
        flag = True
        for i in range(0, int(self.timeoutInSeconds / self.pollTime) + 2):   #waiting 10 more seconds to escape race condition between Host and script timing out
            time.sleep(self.pollTime)
            flag = True
            for j in range(0, self.noOfPlugins):
                flag = flag & self.preScriptCompleted[j]
            if flag:
                break
        continueBackup = True
        #Plugin timed out
        if not flag:
            ecode = CommonVariables.FailedPrepostPluginhostPreTimeout
            result.anyScriptFailed = True
            presult = PluginHostError(errorCode = ecode, pluginName = self.pluginName[j])
            result.errors.append(presult)
        else:
            for j in range(0, self.noOfPlugins):
                ecode = CommonVariables.FailedPrepostPluginhostPreTimeout
                continueBackup = continueBackup & self.preScriptResult[j].continueBackup
                if self.preScriptCompleted[j]:
                    ecode =
self.preScriptResult[j].errorCode if ecode != CommonVariables.PrePost_PluginStatus_Success: result.anyScriptFailed = True presult = PluginHostError(errorCode = ecode, pluginName = self.pluginName[j]) result.errors.append(presult) result.continueBackup = continueBackup self.logger.log('Finished prescript execution from PluginHost side. Continue Backup: '+str(continueBackup),True,'Info') return result def post_script(self): # Runs post_script() for all plugins and maintains a timer result = PluginHostResult() if not self.modulesLoaded: return result self.logger.log('Starting postscript for all modules.',True,'Info') curr = 0 for plugin in self.plugins: t1 = threading.Thread(target=plugin.post_script, args=(curr, self.postScriptCompleted, self.postScriptResult)) t1.start() curr = curr + 1 flag = True for i in range(0, int(self.timeoutInSeconds/self.pollTime) + 2): #waiting 10 more seconds to escape race condition between Host and script timing out time.sleep(self.pollTime) flag = True for j in range(0,self.noOfPlugins): flag = flag & self.postScriptCompleted[j] if flag: break continueBackup = True #Plugin timed out if not flag: ecode = CommonVariables.FailedPrepostPluginhostPostTimeout result.anyScriptFailed = True presult = PluginHostError(errorCode = ecode, pluginName = self.pluginName[j]) result.errors.append(presult) else: for j in range(0,self.noOfPlugins): ecode = CommonVariables.FailedPrepostPluginhostPostTimeout continueBackup = continueBackup & self.postScriptResult[j].continueBackup if self.postScriptCompleted[j]: ecode = self.postScriptResult[j].errorCode if ecode != CommonVariables.PrePost_PluginStatus_Success: result.anyScriptFailed = True presult = PluginHostError(errorCode = ecode, pluginName = self.pluginName[j]) result.errors.append(presult) result.continueBackup = continueBackup self.logger.log('Finished postscript execution from PluginHost side. 
import json
import subprocess
import time
import os
from pwd import getpwuid
from stat import *
from common import CommonVariables
import traceback
from Utils import HandlerUtil

# config.json --------structure---------
# {
#     "pluginName" : "oracleLinux",
#     "timeoutInSeconds" : (in seconds),
#     "continueBackupOnFailure" : true/false,
#     ... other config params ...
#     "preScriptLocation" : "/abc/xyz.sh"
#     "postScriptLocation" : "/abc/def.sh"
#     "preScriptNoOfRetries" : 3,
#     "postScriptNoOfRetries" : 2,
#     "preScriptParams" : [ ... all params to be passed to prescript ... ],
#     "postScriptParams" : [ ... all params to be passed to postscript ... ]
# }
#
# errorcode policy
# errorcode = 0 (CommonVariables.PrePost_PluginStatus_Successs), means success, script runs without error, warnings maybe possible
# errorcode = 5 (CommonVariables.PrePost_PluginStatus_Timeout), means timeout
# errorcode = 10 (CommonVariables.PrePost_PluginStatus_ConfigNotFound), config file not found
# errorcode = process return code, means bash script encountered some other error, like 127 for script not found


class ScriptRunnerResult(object):
    """Outcome of one pre/post script run, handed back to PluginHost."""

    def __init__(self):
        self.errorCode = None           # final status (see errorcode policy above)
        self.continueBackup = True      # whether the backup may proceed
        self.noOfRetries = 0            # retries actually performed
        self.requiredNoOfRetries = 0    # retries allowed by configuration
        self.fileCode = []
        self.filePath = []

    def __str__(self):
        errorStr = 'ErrorCode :- ' + str(self.errorCode) + '\n'
        errorStr += 'Continue Backup :- ' + str(self.continueBackup) + '\n'
        errorStr += 'Number of Retries done :- ' + str(self.noOfRetries) + '\n'
        return errorStr


class ScriptRunner(object):
    """Runs a single plugin's pre/post snapshot scripts per its config.json."""

    def __init__(self, logger, name, configPath, maxTimeOut):
        self.logger = logger
        self.timeoutInSeconds = 10
        self.pollSleepTime = 3
        self.pollTotalCount = (self.timeoutInSeconds / self.pollSleepTime)
        self.configLocation = configPath
        self.pluginName = name
        self.continueBackupOnFailure = True
        self.preScriptParams = []
        self.postScriptParams = []
        self.preScriptLocation = None
        self.postScriptLocation = None
        self.preScriptNoOfRetries = 0
        self.postScriptNoOfRetries = 0
        self.fsFreeze_on = True
        self.configLoaded = False
        self.PreScriptCompletedSuccessfully = False
        self.maxTimeOut = maxTimeOut

    def get_config(self):
        """
            Get configuration information from config.json

            Optional keys keep their __init__ defaults; any IO/parse/key error
            is logged and leaves self.configLoaded False.
        """
        try:
            with open(self.configLocation, 'r') as configFile:
                configData = json.load(configFile)
            configDataKeys = configData.keys()
            if 'timeoutInSeconds' in configDataKeys:
                # A plugin may never exceed the host-imposed timeout ceiling.
                self.timeoutInSeconds = min(configData['timeoutInSeconds'],self.maxTimeOut)
            if 'pluginName' in configDataKeys:
                self.pluginName = configData['pluginName']
            if 'appName' in configDataKeys:
                HandlerUtil.HandlerUtility.add_to_telemetery_data('appName',configData['appName'])
            # Script locations are mandatory; a KeyError here is caught below.
            self.preScriptLocation = configData['preScriptLocation']
            self.postScriptLocation = configData['postScriptLocation']
            if 'preScriptParams' in configDataKeys:
                self.preScriptParams = configData['preScriptParams']
            if 'postScriptParams' in configDataKeys:
                self.postScriptParams = configData['postScriptParams']
            if 'continueBackupOnFailure' in configDataKeys:
                self.continueBackupOnFailure = configData['continueBackupOnFailure']
            if 'preScriptNoOfRetries' in configDataKeys:
                self.preScriptNoOfRetries = configData['preScriptNoOfRetries']
            if 'postScriptNoOfRetries' in configDataKeys:
                self.postScriptNoOfRetries = configData['postScriptNoOfRetries']
            if 'fsFreezeEnabled' in configDataKeys:
                self.fsFreeze_on = configData['fsFreezeEnabled']
            if 'ScriptsExecutionPollTimeSeconds' in configDataKeys and int(configData['ScriptsExecutionPollTimeSeconds']) >= 1 and int(configData['ScriptsExecutionPollTimeSeconds']) <=5:
                self.pollSleepTime = int(configData['ScriptsExecutionPollTimeSeconds'])
                self.pollTotalCount = (self.timeoutInSeconds / self.pollSleepTime)
            self.configLoaded = True
        # BUGFIX: was 'except IOError:' without 'as err', so the handler's
        # str(err) raised NameError and hid the real I/O error.
        except IOError as err:
            errMsg = 'Error in opening ' + self.pluginName + ' config file.' + ': %s, stack trace: %s' % (str(err), traceback.format_exc())
            self.logger.log(errMsg, True, 'Error')
        except ValueError as err:
            errMsg = 'Error in decoding ' + self.pluginName + ' config file.' + ': %s, stack trace: %s' % (str(err), traceback.format_exc())
            self.logger.log(errMsg, True, 'Error')
        except KeyError as err:
            errMsg = 'Error in fetching value for the key '+str(err) + ' in ' +self.pluginName+' config file.' + ': %s, stack trace: %s' % (str(err), traceback.format_exc())
            self.logger.log(errMsg, True, 'Error')

    def find_owner(self, filename):
        """Return the login name of *filename*'s owner, or '' on error."""
        file_owner = ''
        try:
            file_owner = getpwuid(os.stat(filename).st_uid).pw_name
        except Exception as err:
            errMsg = 'Error in fetching owner of the file : ' + filename + ': %s, stack trace: %s' % (str(err), traceback.format_exc())
            self.logger.log(errMsg, True, 'Error')
        return file_owner

    def validate_permissions(self, filename):
        """Return True iff *filename* grants no permissions to group/others.

        Any stat failure is treated as invalid permissions (fail closed).
        """
        valid_permissions = True
        try:
            permissions = oct(os.stat(filename)[ST_MODE])[-3:]
            self.logger.log('Permissions of the file ' + filename + ' are ' + permissions,True)
            if int(permissions[1]) > 0 :
                #validating permissions for group
                valid_permissions = False
            if int(permissions[2]) > 0 :
                #validating permissions for others
                valid_permissions = False
        except Exception as err:
            errMsg = 'Error in fetching permissions of the file : ' + filename + ': %s, stack trace: %s' % (str(err), traceback.format_exc())
            self.logger.log(errMsg, True, 'Error')
            valid_permissions = False
        return valid_permissions

    def validate_scripts(self):
        """Load the config and vet both scripts before any execution.

        Checks existence, group/other permission bits and root ownership of
        both pre- and post-scripts. Returns the tuple
        (errorCode, dobackup, fsFreeze_on, pollSleepTime).
        """
        errorCode = CommonVariables.PrePost_PluginStatus_Success
        dobackup = True
        self.get_config()
        self.logger.log('Plugin:'+str(self.pluginName)+' timeout:'+str(self.timeoutInSeconds)+' pollTotalCount:'+str(self.pollTotalCount) +' preScriptParams:'+str(self.preScriptParams)+' postScriptParams:' + str(self.postScriptParams)+ ' continueBackupOnFailure:' + str(self.continueBackupOnFailure) + ' preScriptNoOfRetries:' + str(self.preScriptNoOfRetries) + ' postScriptNoOfRetries:' + str(self.postScriptNoOfRetries) + ' Global FS Freeze on :' + str(self.fsFreeze_on), True, 'Info')
        if not self.configLoaded:
            errorCode = CommonVariables.FailedPrepostPluginConfigParsing
            self.logger.log('Cant run prescript for '+self.pluginName+' . Config File error.', True, 'Error')
            return errorCode,dobackup,self.fsFreeze_on, self.pollSleepTime
        dobackup = self.continueBackupOnFailure
        if not os.path.isfile(self.preScriptLocation):
            self.logger.log('Prescript file does not exist in the location '+self.preScriptLocation, True, 'Error')
            errorCode = CommonVariables.FailedPrepostPreScriptNotFound
            return errorCode,dobackup,self.fsFreeze_on, self.pollSleepTime
        if not self.validate_permissions(self.preScriptLocation):
            self.logger.log('Prescript file does not have desired permissions ', True, 'Error')
            errorCode = CommonVariables.FailedPrepostPreScriptPermissionError
            return errorCode,dobackup,self.fsFreeze_on, self.pollSleepTime
        if not self.find_owner(self.preScriptLocation) == 'root':
            self.logger.log('The owner of the PreScript file ' + self.preScriptLocation + ' is ' + self.find_owner(self.preScriptLocation) + ' but not root', True, 'Error')
            errorCode = CommonVariables.FailedPrepostPreScriptPermissionError
            return errorCode,dobackup,self.fsFreeze_on, self.pollSleepTime
        if not os.path.isfile(self.postScriptLocation):
            self.logger.log('Postscript file does not exist in the location ' + self.postScriptLocation, True, 'Error')
            errorCode = CommonVariables.FailedPrepostPostScriptNotFound
            return errorCode,dobackup,self.fsFreeze_on, self.pollSleepTime
        if not self.validate_permissions(self.postScriptLocation):
            self.logger.log('Postscript file does not have desired permissions ', True, 'Error')
            errorCode = CommonVariables.FailedPrepostPostScriptPermissionError
            return errorCode,dobackup,self.fsFreeze_on, self.pollSleepTime
        if not self.find_owner(self.postScriptLocation) == 'root':
            self.logger.log('The owner of the PostScript file ' + self.postScriptLocation + ' is '+ self.find_owner(self.postScriptLocation) + ' but not root', True, 'Error')
            errorCode = CommonVariables.FailedPrepostPostScriptPermissionError
            return errorCode,dobackup,self.fsFreeze_on, self.pollSleepTime
        return errorCode,dobackup,self.fsFreeze_on, self.pollSleepTime

    def pre_script(self, pluginIndex, preScriptCompleted, preScriptResult):
        # Generates a system call to run the prescript
        # -- pluginIndex is the index for the current plugin assigned by pluginHost
        # -- preScriptCompleted is a bool array, upon completion of script, true will be assigned at pluginIndex
        # -- preScriptResult is an array and it stores the result at pluginIndex
        result = ScriptRunnerResult()
        result.requiredNoOfRetries = self.preScriptNoOfRetries
        paramsStr = ['sh',str(self.preScriptLocation)]
        for param in self.preScriptParams:
            paramsStr.append(str(param))
        self.logger.log('Running prescript for '+self.pluginName+' module...',True,'Info')
        process = subprocess.Popen(paramsStr, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        flag_timeout = False
        curr = 0
        cnt = 0
        while True:
            # Poll the child until it exits or the overall timeout elapses.
            while process.poll() is None:
                if curr >= self.pollTotalCount:
                    self.logger.log('Prescript for '+self.pluginName+' timed out.',True,'Error')
                    flag_timeout = True
                    break
                curr = curr + 1
                time.sleep(self.pollSleepTime)
            # BUGFIX: compare with '==', not 'is' -- identity comparison of
            # ints only worked by accident of CPython small-int caching.
            if process.returncode == CommonVariables.PrePost_ScriptStatus_Success:
                break
            if flag_timeout:
                break
            if cnt >= self.preScriptNoOfRetries:
                break
            # NOTE(review): the retry loop never re-launches the process, so a
            # failed script is not actually re-executed -- confirm intent.
            self.logger.log('Prescript for '+self.pluginName+' failed. Retrying...',True,'Info')
            cnt = cnt + 1
        result.noOfRetries = cnt
        if not flag_timeout:
            result.errorCode = process.returncode
            if result.errorCode != CommonVariables.PrePost_ScriptStatus_Success:
                self.logger.log('Prescript for '+self.pluginName+' failed with error code: '+str(result.errorCode)+' .',True,'Error')
                result.continueBackup = self.continueBackupOnFailure
                result.errorCode = CommonVariables.FailedPrepostPreScriptFailed
            else:
                self.PreScriptCompletedSuccessfully = True
                self.logger.log('Prescript for '+self.pluginName+' successfully executed.',True,'Info')
        else:
            result.errorCode = CommonVariables.FailedPrepostPreScriptTimeout
            result.continueBackup = self.continueBackupOnFailure
        preScriptCompleted[pluginIndex] = True
        preScriptResult[pluginIndex] = result

    def post_script(self, pluginIndex, postScriptCompleted, postScriptResult):
        # Generates a system call to run the postscript
        # -- pluginIndex is the index for the current plugin assigned by pluginHost
        # -- postScriptCompleted is a bool array, upon completion of script, true will be assigned at pluginIndex
        # -- postScriptResult is an array and it stores the result at pluginIndex
        result = ScriptRunnerResult()
        result.requiredNoOfRetries = self.postScriptNoOfRetries
        paramsStr = ['sh',str(self.postScriptLocation)]
        for param in self.postScriptParams:
            paramsStr.append(str(param))
        self.logger.log('Running postscript for '+self.pluginName+' module...',True,'Info')
        process = subprocess.Popen(paramsStr, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        flag_timeout = False
        curr = 0
        cnt = 0
        while True:
            # Poll the child until it exits or the overall timeout elapses.
            while process.poll() is None:
                if curr >= self.pollTotalCount:
                    self.logger.log('Postscript for '+self.pluginName+' timed out.',True,'Error')
                    flag_timeout = True
                    break
                curr = curr + 1
                time.sleep(self.pollSleepTime)
            # BUGFIX: '==' instead of 'is' (see pre_script).
            if process.returncode == CommonVariables.PrePost_ScriptStatus_Success:
                break
            if flag_timeout:
                break
            if cnt >= self.postScriptNoOfRetries:
                break
            self.logger.log('Postscript for '+self.pluginName+' failed. Retrying...',True,'Info')
            cnt = cnt + 1
        result.noOfRetries = cnt
        if not flag_timeout:
            result.errorCode = process.returncode
            if result.errorCode != CommonVariables.PrePost_ScriptStatus_Success:
                self.logger.log('Postscript for '+self.pluginName+' failed with error code: '+str(result.errorCode)+' .',True,'Error')
                result.errorCode = CommonVariables.FailedPrepostPostScriptFailed
                result.continueBackup = self.continueBackupOnFailure
            else:
                self.logger.log('Postscript for '+self.pluginName+' successfully executed.',True,'Info')
        else:
            result.errorCode = CommonVariables.FailedPrepostPostScriptTimeout
            result.continueBackup = self.continueBackupOnFailure
        postScriptCompleted[pluginIndex] = True
        postScriptResult[pluginIndex] = result
import subprocess
import os
import os.path
import shlex
import sys
from subprocess import *
import shutil
import uuid
import glob
from common import DeviceItem
import Utils.HandlerUtil
import traceback
try:
    import ConfigParser as ConfigParsers
except ImportError:
    # Python 3 renamed the module.
    import configparser as ConfigParsers


class DiskUtil(object):
    """Singleton wrapper around lsblk/mount for device and mount-point discovery."""

    __instance__ = None   # singleton instance
    patching = None       # distro-specific paths/flags (lsblk_path, mount_path, usr_flag, distro_info)
    logger = None
    mount_output = None   # cached raw 'mount' output (filled lazily by get_mount_output)

    def __init__(self, patching, logger):
        if DiskUtil.__instance__ is None:
            self.patching = patching
            self.logger = logger
            self.mount_output = None
            DiskUtil.__instance__ = self
        else:
            # NOTE(review): returning a value from __init__ raises TypeError if
            # this branch is ever reached; use get_instance() instead -- confirm.
            return DiskUtil.__instance__

    @staticmethod
    def get_instance(patching, logger):
        """Return the singleton DiskUtil, creating it on first use."""
        if not DiskUtil.__instance__:
            DiskUtil(patching, logger)
        return DiskUtil.__instance__

    def get_device_items_property(self, lsblk_path, dev_name, property_name):
        """Query one lsblk column (e.g. FSTYPE) for /dev/<dev_name>; None if absent."""
        get_property_cmd = lsblk_path + " /dev/" + dev_name + " -b -nl -o NAME," + property_name
        get_property_cmd_args =Utils.HandlerUtil.HandlerUtility.split(self.logger, get_property_cmd)
        get_property_cmd_p = Popen(get_property_cmd_args,stdout=subprocess.PIPE,stderr=subprocess.PIPE)
        output,err = get_property_cmd_p.communicate()
        output= str(output)
        lines = output.splitlines()
        for i in range(0,len(lines)):
            item_value_str = lines[i].strip()
            if(item_value_str != ""):
                disk_info_item_array =Utils.HandlerUtil.HandlerUtility.split(self.logger, item_value_str)
                # First column is the device name; the requested property follows.
                if(dev_name == disk_info_item_array[0]):
                    if(len(disk_info_item_array) > 1):
                        return disk_info_item_array[1]
        return None

    def get_device_items_sles(self,dev_path):
        """Enumerate block devices on SLES 11, one lsblk call per property.

        SLES 11's lsblk lacks the -P (pairs) output used elsewhere.
        """
        self.logger.log("get_device_items_sles : getting the blk info from " + str(dev_path), True)
        device_items = []
        #first get all the device names
        if(dev_path is None):
            get_device_cmd = self.patching.lsblk_path + " -b -nl -o NAME"
        else:
            get_device_cmd = self.patching.lsblk_path + " -b -nl -o NAME " + dev_path
        get_device_cmd_args =Utils.HandlerUtil.HandlerUtility.split(self.logger, get_device_cmd)
        p = Popen(get_device_cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out_lsblk_output, err = p.communicate()
        out_lsblk_output = str(out_lsblk_output)
        lines = out_lsblk_output.splitlines()
        for i in range(0,len(lines)):
            item_value_str = lines[i].strip()
            if(item_value_str != ""):
                disk_info_item_array =Utils.HandlerUtil.HandlerUtility.split(self.logger, item_value_str)
                device_item = DeviceItem()
                device_item.name = disk_info_item_array[0]
                device_items.append(device_item)
        for i in range(0,len(device_items)):
            device_item = device_items[i]
            device_item.file_system = self.get_device_items_property(lsblk_path=self.patching.lsblk_path,dev_name=device_item.name,property_name='FSTYPE')
            device_item.mount_point = self.get_device_items_property(lsblk_path=self.patching.lsblk_path,dev_name=device_item.name,property_name='MOUNTPOINT')
            device_item.label = self.get_device_items_property(lsblk_path=self.patching.lsblk_path,dev_name=device_item.name,property_name='LABEL')
            device_item.uuid = self.get_device_items_property(lsblk_path=self.patching.lsblk_path,dev_name=device_item.name,property_name='UUID')
            #get the type of device
            model_file_path = '/sys/block/' + device_item.name + '/device/model'
            if(os.path.exists(model_file_path)):
                with open(model_file_path,'r') as f:
                    device_item.model = f.read().strip()
                if(device_item.model == 'Virtual Disk'):
                    self.logger.log("model is virtual disk", True)
                    device_item.type = 'disk'
            if(device_item.type != 'disk'):
                # A /sys/block/*/<name>/partition entry marks it as a partition.
                partition_files = glob.glob('/sys/block/*/' + device_item.name + '/partition')
                if(partition_files is not None and len(partition_files) > 0):
                    self.logger.log("partition files exists", True)
                    device_item.type = 'part'
        return device_items

    def get_device_items_from_lsblk_list(self, lsblk_path, dev_path):
        """Fallback enumeration via 'lsblk -nl' when '-P' (pairs) is unsupported.

        Only devices with a mount point get their remaining properties fetched.
        """
        self.logger.log("get_device_items_from_lsblk_list : getting the blk info from " + str(dev_path), True)
        device_items = []
        #first get all the device names
        if(dev_path is None):
            get_device_cmd = lsblk_path + " -b -nl -o NAME"
        else:
            get_device_cmd = lsblk_path + " -b -nl -o NAME " + dev_path
        get_device_cmd_args =Utils.HandlerUtil.HandlerUtility.split(self.logger, get_device_cmd)
        p = Popen(get_device_cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out_lsblk_output, err = p.communicate()
        if sys.version_info > (3,):
            out_lsblk_output =str(out_lsblk_output, encoding='utf-8', errors="backslashreplace")
        else:
            out_lsblk_output =str(out_lsblk_output)
        lines = out_lsblk_output.splitlines()
        device_items_temp = []
        for i in range(0,len(lines)):
            item_value_str = lines[i].strip()
            if(item_value_str != ""):
                disk_info_item_array =Utils.HandlerUtil.HandlerUtility.split(self.logger, item_value_str)
                device_item = DeviceItem()
                device_item.name = disk_info_item_array[0]
                device_items_temp.append(device_item)
        for i in range(0,len(device_items_temp)):
            device_item = device_items_temp[i]
            device_item.mount_point = self.get_device_items_property(lsblk_path=lsblk_path,dev_name=device_item.name,property_name='MOUNTPOINT')
            if (device_item.mount_point is not None):
                device_item.file_system = self.get_device_items_property(lsblk_path=lsblk_path,dev_name=device_item.name,property_name='FSTYPE')
                device_item.label = self.get_device_items_property(lsblk_path=lsblk_path,dev_name=device_item.name,property_name='LABEL')
                device_item.uuid = self.get_device_items_property(lsblk_path=lsblk_path,dev_name=device_item.name,property_name='UUID')
                device_item.type = self.get_device_items_property(lsblk_path=lsblk_path,dev_name=device_item.name,property_name='TYPE')
                device_items.append(device_item)
                self.logger.log("lsblk MOUNTPOINT=" + str(device_item.mount_point) + ", NAME=" + str(device_item.name) + ", TYPE=" + str(device_item.type) + ", FSTYPE=" + str(device_item.file_system) + ", LABEL=" + str(device_item.label) + ", UUID=" + str(device_item.uuid) + ", MODEL=" + str(device_item.model), True)
        return device_items

    def get_lsblk_pairs_output(self, lsblk_path, dev_path):
        """Run 'lsblk -b -n -P -o ...' and return (path_wrong, stdout, stderr).

        path_wrong is True when Popen itself failed (bad lsblk path).
        """
        self.logger.log("get_lsblk_pairs_output : getting the blk info from " + str(dev_path) + " using lsblk_path " + str(lsblk_path), True)
        # If an alternate user is specified in vmbackup.conf, run lsblk command through that user, not with root access.
        # Fixes issues found in some SUSE-related distros where lsblk command gets stuck with root access
        # Sample vmbackup.conf file with such alternate user setting:
        # [lsblkUser]
        # username: vmadmin
        configfile = '/etc/azure/vmbackup.conf'
        command_user = ''
        alternate_user = False
        try :
            if os.path.exists(configfile):
                config = ConfigParsers.ConfigParser()
                config.read(configfile)
                if config.has_option('lsblkUser','username'):
                    lsblk_user = config.get('lsblkUser','username')
                    command_user = "su - " + lsblk_user + " -c"
                    if (dev_path is None):
                        command_user = command_user + ' \'' + 'lsblk -b -n -P -o NAME,TYPE,FSTYPE,MOUNTPOINT,LABEL,UUID,MODEL,SIZE' + '\''
                    else:
                        command_user = command_user + ' \'' + 'lsblk -b -n -P -o NAME,TYPE,FSTYPE,MOUNTPOINT,LABEL,UUID,MODEL,SIZE' + ' ' + dev_path + '\''
                    alternate_user = True
        except Exception as e:
            # Best-effort: a broken config simply means no alternate user.
            pass
        out_lsblk_output = None
        error_msg = None
        is_lsblk_path_wrong = False
        try:
            if (alternate_user):
                self.logger.log("Switching to alternate user to run this lsblk command: " + str(command_user), True)
                p = Popen(command_user, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
            elif(dev_path is None):
                p = Popen([str(lsblk_path), '-b', '-n','-P','-o','NAME,TYPE,FSTYPE,MOUNTPOINT,LABEL,UUID,MODEL,SIZE'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            else:
                p = Popen([str(lsblk_path), '-b', '-n','-P','-o','NAME,TYPE,FSTYPE,MOUNTPOINT,LABEL,UUID,MODEL,SIZE',dev_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except Exception as e:
            errMsg = 'Exception in lsblk command, error: %s, stack trace: %s' % (str(e), traceback.format_exc())
            self.logger.log(errMsg, True, 'Error')
            is_lsblk_path_wrong = True
        if is_lsblk_path_wrong == False :
            out_lsblk_output, err = p.communicate()
            if sys.version_info > (3,):
                out_lsblk_output = str(out_lsblk_output, encoding='utf-8', errors="backslashreplace")
            else:
                out_lsblk_output = str(out_lsblk_output)
            error_msg = str(err)
            if(error_msg is not None and error_msg.strip() != ""):
                self.logger.log(str(err), True)
        return is_lsblk_path_wrong, out_lsblk_output, error_msg

    def get_which_command_result(self, program_to_locate):
        """Locate *program_to_locate* via 'which'; returns (first_line_or_None, stderr)."""
        self.logger.log("getting the which info for " + str(program_to_locate), True)
        out_which_output = None
        error_msg = None
        try:
            p = Popen(['which', str(program_to_locate)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            out_which_output, err = p.communicate()
            if sys.version_info > (3,):
                out_which_output = str(out_which_output, encoding='utf-8', errors="backslashreplace")
            else:
                out_which_output = str(out_which_output)
            error_msg = str(err)
            if(error_msg is not None and error_msg.strip() != ""):
                self.logger.log(str(err), True)
            self.logger.log("which command result :" + str(out_which_output), True)
            if (out_which_output is not None):
                # Keep only the first resolved path.
                out_which_output = out_which_output.splitlines()[0]
        except Exception as e:
            errMsg = 'Exception in which command, error: %s, stack trace: %s' % (str(e), traceback.format_exc())
            self.logger.log(errMsg, True, 'Error')
        return out_which_output, error_msg

    def get_device_items(self, dev_path):
        """Return mounted DeviceItems, trying several lsblk locations/output modes.

        Fallback chain: configured lsblk path -> /bin vs /usr/bin swap
        (CentOS/RedHat only) -> 'which lsblk'; and '-P' pairs parsing ->
        per-property list mode when '-P' is unsupported.
        """
        if(self.patching.distro_info[0].lower() == 'suse' and self.patching.distro_info[1] == '11'):
            return self.get_device_items_sles(dev_path)
        else:
            self.logger.log("getting the blk info from " + str(dev_path), True)
            device_items = []
            lsblk_path = self.patching.lsblk_path
            # Get lsblk command output using lsblk_path as self.patching.lsblk_path
            is_lsblk_path_wrong, out_lsblk_output, error_msg = self.get_lsblk_pairs_output(lsblk_path, dev_path)
            # if lsblk_path was wrong, use /bin/lsblk or usr/bin/lsblk based on self.patching.usr_flag to get lsblk command output again for centos/redhat distros
            if (is_lsblk_path_wrong == True) and (self.patching.distro_info[0].lower() == 'centos' or self.patching.distro_info[0].lower() == 'redhat'):
                if self.patching.usr_flag == 1:
                    self.logger.log("lsblk path is wrong, removing /usr prefix", True, 'Warning')
                    lsblk_path = "/bin/lsblk"
                else:
                    self.logger.log("lsblk path is wrong, adding /usr prefix", True, 'Warning')
                    lsblk_path = "/usr/bin/lsblk"
                is_lsblk_path_wrong, out_lsblk_output, error_msg = self.get_lsblk_pairs_output(lsblk_path, dev_path)
            # if lsblk_path was still wrong, lsblk_path using "which" command
            if (is_lsblk_path_wrong == True):
                self.logger.log("lsblk path is wrong. finding path using which command", True, 'Warning')
                out_which_output, which_error_msg = self.get_which_command_result('lsblk')
                # get lsblk command output
                if (out_which_output is not None):
                    lsblk_path = str(out_which_output)
                    is_lsblk_path_wrong, out_lsblk_output, error_msg = self.get_lsblk_pairs_output(lsblk_path, dev_path)
            # if error_msg contains "invalid option" or "P" (rely on only "-P" optiont in error to handle non-English locales), then get device_items using method get_device_items_from_lsblk_list
            if (error_msg is not None and error_msg.strip() != "" and ('invalid option' in error_msg or 'P' in error_msg)):
                device_items = self.get_device_items_from_lsblk_list(lsblk_path, dev_path)
            # else get device_items from parsing the lsblk command output
            elif (out_lsblk_output is not None):
                lines = out_lsblk_output.splitlines()
                for i in range(0,len(lines)):
                    item_value_str = lines[i].strip()
                    if(item_value_str != ""):
                        disk_info_item_array =Utils.HandlerUtil.HandlerUtility.split(self.logger, item_value_str)
                        device_item = DeviceItem()
                        disk_info_item_array_length = len(disk_info_item_array)
                        for j in range(0, disk_info_item_array_length):
                            disk_info_property = disk_info_item_array[j]
                            # Each token is KEY="value".
                            property_item_pair = disk_info_property.split('=')
                            if(property_item_pair[0] == 'NAME'):
                                device_item.name = property_item_pair[1].strip('"')
                            if(property_item_pair[0] == 'TYPE'):
                                device_item.type = property_item_pair[1].strip('"')
                            if(property_item_pair[0] == 'FSTYPE'):
                                device_item.file_system = property_item_pair[1].strip('"')
                            if(property_item_pair[0] == 'MOUNTPOINT'):
                                device_item.mount_point = property_item_pair[1].strip('"')
                            if(property_item_pair[0] == 'LABEL'):
                                device_item.label = property_item_pair[1].strip('"')
                            if(property_item_pair[0] == 'UUID'):
                                device_item.uuid = property_item_pair[1].strip('"')
                            if(property_item_pair[0] == 'MODEL'):
                                device_item.model = property_item_pair[1].strip('"')
                        self.logger.log("lsblk MOUNTPOINT=" + str(device_item.mount_point) + ", NAME=" + str(device_item.name) + ", TYPE=" + str(device_item.type) + ", FSTYPE=" + str(device_item.file_system) + ", LABEL=" + str(device_item.label) + ", UUID=" + str(device_item.uuid) + ", MODEL=" + str(device_item.model), True)
                        # Only mounted devices are of interest to the backup flow.
                        if(device_item.mount_point is not None and device_item.mount_point != "" and device_item.mount_point != " "):
                            device_items.append(device_item)
            return device_items

    def get_mount_command_output(self, mount_path):
        """Run the 'mount' binary at *mount_path*; returns (path_wrong, stdout, stderr)."""
        self.logger.log("getting the mount info using mount_path " + str(mount_path), True)
        out_mount_output = None
        error_msg = None
        is_mount_path_wrong = False
        try:
            p = Popen([str(mount_path)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except Exception as e:
            errMsg = 'Exception in mount command, error: %s, stack trace: %s' % (str(e), traceback.format_exc())
            self.logger.log(errMsg, True, 'Error')
            is_mount_path_wrong = True
        if is_mount_path_wrong == False :
            out_mount_output, err = p.communicate()
            if sys.version_info > (3,):
                out_mount_output = str(out_mount_output, encoding='utf-8', errors="backslashreplace")
            else:
                out_mount_output = str(out_mount_output)
            self.logger.log("getting the mount info using mount_path " + out_mount_output, True)
            error_msg = str(err)
            if(error_msg is not None and error_msg.strip() != ""):
                self.logger.log(str(err), True)
        return is_mount_path_wrong, out_mount_output, error_msg

    def get_mount_points(self):
        """Parse 'mount' output into (mount_points, [(mount_point, device, fstype)]).

        Duplicates keep only the most recent mount (output is scanned in
        reverse); network filesystems trigger a telemetry flag.
        """
        mount_points_info = []
        mount_points = []
        fs_types = []
        out_mount_output = self.get_mount_output()
        if (out_mount_output is not None):
            #Extract the list of mnt_point in order
            lines = out_mount_output.splitlines()
            #Reverse the mount command output to go-through from last-to-first mounts in output
            lines.reverse()
            for line in lines:
                line = line.strip()
                if(line != ""):
                    deviceName =Utils.HandlerUtil.HandlerUtility.split(self.logger, line)[0]
                    # Lines look like: <device> on /<mount_point> type <fstype> (<options>)
                    mountPrefixStr = " on /"
                    prefixIndex = line.find(mountPrefixStr)
                    if(prefixIndex >= 0):
                        mountpointStart = prefixIndex + len(mountPrefixStr) - 1
                        fstypePrefixStr = " type "
                        mountpointEnd = line.find(fstypePrefixStr, mountpointStart)
                        if(mountpointEnd >= 0):
                            mount_point = line[mountpointStart:mountpointEnd]
                            fs_type = ""
                            fstypeStart = line.find(fstypePrefixStr) + len(fstypePrefixStr) - 1
                            if(line.find(fstypePrefixStr) >= 0):
                                fstypeEnd = line.find(" ", fstypeStart+1)
                                if(fstypeEnd >=0):
                                    fs_type = line[fstypeStart+1:fstypeEnd]
                            # If there is a duplicate, keep only the first instance
                            if (mount_point not in mount_points):
                                self.logger.log("mount command, adding mount :" + str(mount_point) + ": device :" + str(deviceName) + ": fstype :"+ str(fs_type) + ":", True)
                                fs_types.append(fs_type)
                                mount_points.append(mount_point)
                                mount_points_info.append((mount_point,deviceName,fs_type))
                            else:
                                self.logger.log("####### mount command, not adding duplicate mount :" + str(mount_point) + ": device :" + str(deviceName) + ": fstype :"+ str(fs_type) + ":", True)
        #Now reverse the mount_points & fs_types lists to make them in the same order as mount command output order
        mount_points_info.reverse()
        mount_points.reverse()
        for fstype in fs_types:
            if ("fuse" in fstype.lower() or "nfs" in fstype.lower() or "cifs" in fstype.lower()):
                Utils.HandlerUtil.HandlerUtility.add_to_telemetery_data("networkFSTypePresentInMount","True")
                break
        return mount_points,mount_points_info
    def get_mount_file_systems(self):
        """Parse 'mount' output into [(device, fstype, mount_point)] tuples.

        Scans the output in reverse so duplicated mount points keep only the
        most recent mount, then restores the original output order.
        """
        out_mount_output = self.get_mount_output()
        file_systems_info = []
        mount_points = []
        if (out_mount_output is not None):
            lines = out_mount_output.splitlines()
            #Reverse the mount command output to go-through from last-to-first mounts in output
            lines.reverse()
            for line in lines:
                self.logger.log("print line by line :" + line , True)
                line = line.strip()
                if(line != ""):
                    file_system =Utils.HandlerUtil.HandlerUtility.split(self.logger, line)[0]
                    # Lines look like: <device> on /<mount_point> type <fstype> (<options>)
                    mountPrefixStr = " on /"
                    prefixIndex = line.find(mountPrefixStr)
                    if(prefixIndex >= 0):
                        mountpointStart = prefixIndex + len(mountPrefixStr) - 1
                        fstypePrefixStr = " type "
                        mountpointEnd = line.find(fstypePrefixStr, mountpointStart)
                        if(mountpointEnd >= 0):
                            mount_point = line[mountpointStart:mountpointEnd]
                            fs_type = ""
                            fstypeStart = line.find(fstypePrefixStr) + len(fstypePrefixStr) - 1
                            if(line.find(fstypePrefixStr) >= 0):
                                fstypeEnd = line.find(" ", fstypeStart+1)
                                if(fstypeEnd >=0):
                                    fs_type = line[fstypeStart+1:fstypeEnd]
                            # If there is a duplicate, keep only the first instance
                            if (mount_point not in mount_points):
                                file_systems_info.append((file_system,fs_type,mount_point))
                                mount_points.append(mount_point)
        #Now reverse the file_systems_info list to make them in the same order as mount command output order
        file_systems_info.reverse()
        return file_systems_info

    def get_mount_output(self):
        """Return raw 'mount' output, cached after the first successful call.

        Falls back across mount binary locations: configured path ->
        /bin vs /usr/bin swap (per usr_flag) -> 'which mount'.
        """
        if self.mount_output is not None:
            return self.mount_output
        else :
            # Get the output on the mount command
            self.logger.log("getting the mount-points info using mount command ", True)
            mount_path = self.patching.mount_path
            is_mount_path_wrong, out_mount_output, error_msg = self.get_mount_command_output(mount_path)
            if (is_mount_path_wrong == True):
                if self.patching.usr_flag == 1:
                    self.logger.log("mount path is wrong.removing /usr prefix", True, 'Warning')
                    mount_path = "/bin/mount"
                else:
                    self.logger.log("mount path is wrong.Adding /usr prefix", True, 'Warning')
                    mount_path = "/usr/bin/mount"
                is_mount_path_wrong, out_mount_output, error_msg = self.get_mount_command_output(mount_path)
                # if mount_path was still wrong, mount_path using "which" command
                if (is_mount_path_wrong == True):
                    self.logger.log("mount path is wrong. finding path using which command", True, 'Warning')
                    out_which_output, which_error_msg = self.get_which_command_result('mount')
                    # get mount command output
                    if (out_which_output is not None):
                        mount_path = str(out_which_output)
                        is_mount_path_wrong, out_mount_output, error_msg = self.get_mount_command_output(mount_path)
            self.mount_output = out_mount_output
            return out_mount_output
from datetime import datetime
import os
import sys

if sys.version_info[0] == 3:
    import threading
else:
    # to make it compatible with python version less than 3
    import thread as threading


class Event:
    '''
    A single telemetry event destined for the guest agent.

    The agent will only pick the first 3K - 3072 characters; anything beyond
    that is discarded, so callers split long messages into sub-3K chunks
    before constructing Event instances.
    '''

    def __init__(self, level, message, task_name, operation_id, version):
        self.version = version
        # Capture the creation moment in ISO-8601 UTC.
        self.timestamp = datetime.utcnow().isoformat()
        self.task_name = task_name
        self.event_level = level
        self.message = message
        self.event_pid = str(os.getpid())
        # Thread id, left-padded with zeros to at least 8 characters.
        self.event_tid = str(threading.get_ident()).zfill(8)
        self.operation_id = operation_id

    def convertToDictionary(self):
        '''Serialize this event into the key/value shape the agent expects.'''
        return {
            'Version': self.version,
            'Timestamp': self.timestamp,
            'TaskName': self.task_name,
            'EventLevel': self.event_level,
            'Message': self.message,
            'EventPid': self.event_pid,
            'EventTid': self.event_tid,
            'OperationId': str(self.operation_id),
        }
from Utils.Event import Event

# NOTE(review): this chunk is a whitespace-collapsed dump; line breaks and
# indentation below are reconstructed — token stream is unchanged.


class EventLogger:
    """Buffers log messages and flushes them as JSON event files into the
    guest agent's events folder. Singleton: obtain via GetInstance(). When
    use_async_event_logging == 1, a background thread periodically drains the
    queue; otherwise events are processed on demand (e.g. at dispose())."""

    _instance = None
    _lock = threading.Lock()

    def __init__(self, event_directory, severity_level, use_async_event_logging = 0):
        global logger
        # events are staged in <events>/Temp then moved into the events folder
        self.temporary_directory = os.path.join(event_directory, 'Temp')
        self.space_available_in_event_directory = 0
        self.event_processing_interval = 0
        self.disposed = False
        self.event_processing_task = None
        self.current_message_len = 0
        self.event_logging_enabled = False
        self.event_logging_error_count = 0
        self.events_folder = event_directory
        # enabled only when the agent handed us a non-empty events folder
        self.event_logging_enabled = bool(self.events_folder)
        self.async_event_logging = use_async_event_logging
        self.filehelper = FileHelpers()
        if self.event_logging_enabled:
            # extension version inferred from the install directory name
            self.extension_version = os.path.basename(os.getcwd())
            self.operation_id = uuid.UUID(int=0)
            self.log_severity_level = severity_level
            logger.log("Information: EventLogging severity level setting is {0}".format(self.log_severity_level))
            # creating a temp directory
            if not os.path.exists(self.temporary_directory):
                os.makedirs(self.temporary_directory)
            FileHelpers.clearOldJsonFilesInDirectory(self.temporary_directory)
            FileHelpers.clearOldJsonFilesInDirectory(self.events_folder)
            self.current_message = ''
            self.event_queue = queue.Queue()
            space_available = LoggingConstants.MaxEventDirectorySize - FileHelpers.getSizeOfDir(self.events_folder)
            self.space_available_in_event_directory = max(0, space_available)
            print("Information: Space available in event directory : %sB" %(self.space_available_in_event_directory))
            if( self.async_event_logging == 1):
                # an event object that runs continuously until signal is set
                self.event_processing_signal = threading.Event()
                self.event_processing_interval = LoggingConstants.MinEventProcesingInterval
                print("Information: Setting event reporting interval to %ss" %(self.event_processing_interval))
                self.begin_event_queue_polling()
                # NOTE(review): bare attribute access below is a no-op — the loop is
                # already started by begin_event_queue_polling(); looks like leftover
                # code, left untouched here.
                self._event_processing_loop
        else:
            print("Warning: EventsFolder parameter is empty. Guest Agent does not support event logging.")

    @staticmethod
    def GetInstance(backup_logger, event_directory, severity_level, use_async_event_logging = 0):
        """Double-checked-locking singleton accessor; also (re)binds the module
        global `logger` to the supplied backup_logger."""
        global logger
        try:
            logger = backup_logger
            if EventLogger._instance is None:
                with EventLogger._lock:
                    if EventLogger._instance is None:
                        EventLogger._instance = EventLogger(event_directory, severity_level, use_async_event_logging)
        except Exception as e:
            print("Exception has occurred {0}".format(str(e)))
        return EventLogger._instance

    def update_properties(self, task_id):
        # operation id stamped onto every subsequent event
        self.operation_id = task_id

    def severity(self, severity_level):
        """Map a severity name to its numeric level (unknown names -> 3)."""
        level = 0
        if(severity_level == "Verbose"):
            level = 0
        elif(severity_level == "Info"):
            level = 1
        elif(severity_level == "Warning"):
            level = 2
        else:
            level = 3
        return level

    def trace_message(self, severity_level, message):
        """Accept a log line if it meets the configured severity; long messages
        are split into MaxMessageLenLimit-sized numbered chunks. After more than
        10 failures event logging disables itself."""
        global logger
        level = self.severity(severity_level)
        if self.event_logging_enabled and level >= self.log_severity_level:
            stringhelper = StringHelper()
            message = stringhelper.resolve_string(severity_level, message)
            try:
                message_len = len(message)
                message_max_len = LoggingConstants.MaxMessageLenLimit
                if message_len > message_max_len:
                    # ceiling division: number of chunks needed
                    num_chunks = (message_len + message_max_len - 1) // message_max_len
                    msg_date_time = datetime.datetime.utcnow().strftime(u'%Y-%m-%dT%H:%M:%S.%fZ')
                    for string_part in range(num_chunks):
                        start_index = string_part * message_max_len
                        length = min(message_max_len, message_len - start_index)
                        # prefix each chunk with timestamp and [i/n] marker
                        message_part = '%s [%d/%d] %s' % (msg_date_time, string_part + 1, num_chunks, message[start_index:start_index+length])
                        self.log_event(message_part)
                else:
                    self.log_event(message)
            except Exception as ex:
                self.event_logging_error_count += 1
                if self.event_logging_error_count > 10:
                    self.event_logging_enabled = False
                    print("Warning: Count(EventLoggingErrors) > 10. Disabling eventLogging. Continue with execution")
                print("Exception: {0}" .format(str(ex)))

    def log_event(self, message):
        """Append to the in-progress message buffer; when the buffer would exceed
        MaxMessageLengthPerEvent, enqueue it as an Event and start a new buffer."""
        global logger
        try:
            if self.current_message_len + len(message) > LoggingConstants.MaxMessageLengthPerEvent:
                self.event_queue.put(Event("Info", self.current_message, LoggingConstants.DefaultEventTaskName, self.operation_id, self.extension_version).convertToDictionary())
                # Reset the current message
                self.current_message = message
                self.current_message_len = len(message)
            else:
                self.current_message += message
                self.current_message_len += len(message)
        except Exception as ex:
            print("Warning: Error adding extension event to queue. Exception: {0}" .format(str(ex)))

    def begin_event_queue_polling(self):
        """Start the background thread that runs _event_processing_loop."""
        global logger
        print("Event logging via polling is starting...")
        #using threads
        try:
            self.event_processing_task = threading.Thread(target=self._event_processing_loop)
            self.event_processing_task.start()
        except Exception as e:
            print("Exception in begin_event_queue_polling {0}".format(str(e)))

    def _event_processing_loop(self):
        """Async mode: drain the queue every event_processing_interval seconds
        until event_processing_signal is set. Sync mode: drain once."""
        global logger
        if(self.async_event_logging == 1):
            # Event.wait returns False on timeout (keep looping) and True once
            # the dispose path sets the signal.
            while not self.event_processing_signal.wait(self.event_processing_interval):
                try:
                    self._process_events()
                except Exception as ex:
                    print("Warning: Event processing has failed. Exception: {0}" .format(str(ex)))
        else:
            try:
                self._process_events()
            except Exception as ex:
                print("Warning: Event processing has failed. Exception: {0}" .format(str(ex)))
        print("Information: Exiting function polling...")

    def _process_events(self):
        """Swap out the current queue, serialize it into a temp JSON file named
        by a nanosecond timestamp, and move the file into the events folder."""
        global logger
        try:
            if self.space_available_in_event_directory == 0:
                # There is no space available in the events directory then a check is made to see if space has been
                # created (no files). If there is space available we reset our flags and proceed with processing.
                if not os.listdir(self.events_folder):
                    self.space_available_in_event_directory = LoggingConstants.MaxEventDirectorySize
                    logger.log("Event directory has space for new event files. Resuming event reporting.")
                else:
                    # still full: drop pending events
                    self.event_queue = queue.Queue()
                    return
            if not self.event_queue.empty():
                if sys.version_info[0] == 2:
                    event_file_path = os.path.join(self.temporary_directory, "{}.json".format(int(time.time() * 1000000000)))
                else:
                    event_file_path = os.path.join(self.temporary_directory, "{}.json".format(int(datetime.datetime.utcnow().timestamp() * 1000000000)))
                with self._create_event_file(event_file_path) as file:
                    if file is None:
                        logger.log("Warning: Could not create the event file in the path mentioned.")
                        return
                    print("Clearing out event queue for processing...")
                    # swap queues so writers can keep enqueueing while we flush
                    old_queue = self.event_queue
                    self.event_queue = queue.Queue()
                    self._write_events_to_event_file(file, old_queue, event_file_path)
                # move after the with-block so the handle is closed first
                # (NOTE(review): nesting reconstructed from collapsed source — confirm)
                self._send_event_file_to_event_directory(event_file_path, self.events_folder)
        except Exception as e:
            print("Exception occurred in _process_events {0}".format(str(e)))

    def _create_event_file(self, event_file_path):
        """Open the temp event file for writing, with retries; returns the
        stream (or the retry helper's failure result)."""
        print("Information: Attempting to create a new event file...")
        success_msg = "Successfully created new event file: %s" % event_file_path
        retry_msg = "Failed to write events to file: %s. Retrying..." % event_file_path
        err_msg = "Failed to write events to file %s after %d attempts. No longer retrying. Events for this iteration will not be reported." % (event_file_path, LoggingConstants.MaxAttemptsForEventFileCreationWriteMove)
        stream_writer = self.filehelper.execute_with_retries(
            LoggingConstants.MaxAttemptsForEventFileCreationWriteMove,
            LoggingConstants.ThreadSleepDuration,
            success_msg, retry_msg, err_msg,
            lambda: open(event_file_path, "w"))
        return stream_writer

    def _write_events_to_event_file(self, file, events, event_file_path):
        """Drain `events` (a Queue of event dicts) and write them to `file` as
        one JSON array, with retries."""
        data_list = []
        while not events.empty():
            data = events.get()
            data_list.append(data)
        json_data = json.dumps(data_list)
        if not json_data:
            print("Warning: Unable to serialize events. Events for this iteration will not be reported.")
            return
        success_msg = "Successfully wrote events to file: %s" % event_file_path
        retry_msg = "Failed to write events to file: %s. Retrying..." % event_file_path
        err_msg = "Failed to write events to file %s after %d attempts. No longer retrying. Events for this iteration will not be reported." % (event_file_path, LoggingConstants.MaxAttemptsForEventFileCreationWriteMove)
        self.filehelper.execute_with_retries(
            LoggingConstants.MaxAttemptsForEventFileCreationWriteMove,
            LoggingConstants.ThreadSleepDuration,
            success_msg, retry_msg, err_msg,
            lambda: file.write(json_data))

    def _send_event_file_to_event_directory(self, file_path, events_folder):
        """Move the finished temp file into the events folder if the capacity
        budget allows; otherwise delete it and mark the directory as full."""
        file_info = os.stat(file_path)
        file_size = file_info.st_size
        if self.space_available_in_event_directory - file_size >= 0:
            new_path_for_event_file = os.path.join(events_folder, os.path.basename(file_path))
            success_msg = "Successfully moved event file to event directory: %s" % new_path_for_event_file
            retry_msg = "Unable to move event file to event directory: %s. Retrying..." % file_path
            err_msg = "Unable to move event file to event directory: %s . No longer retrying. Events for this iteration will not be reported." % file_path
            self.filehelper.execute_with_retries(
                LoggingConstants.MaxAttemptsForEventFileCreationWriteMove,
                LoggingConstants.ThreadSleepDuration,
                success_msg, retry_msg, err_msg,
                lambda: shutil.move(file_path, new_path_for_event_file))
            self.space_available_in_event_directory -= file_size
        else:
            self.space_available_in_event_directory = 0
            FileHelpers.deleteFile(file_path)
            print("Information: Event reporting has paused due to reaching maximum capacity in the Event directory. Reporting will resume once space is available. Events for this iteration will not be reported.")

    def clear_temp_directory(self, directory_path):
        """Remove the temp staging directory (rmdir if empty, rmtree otherwise)."""
        try:
            if os.path.exists(directory_path):
                if len(os.listdir(directory_path)) == 0:
                    os.rmdir(directory_path)
                else:
                    shutil.rmtree(directory_path)
        except Exception as ex:
            print("Warning: Error clearing the temp directory. Exception: {0}".format(str(ex)))

    def dispose(self):
        """Public shutdown entry point."""
        print("Information: Dispose(), called on EventLogger. Event processing is terminating...")
        self._dispose(True)

    def _dispose(self, disposing):
        """Stop the polling thread, flush any buffered/queued events, and clear
        the temp directory. Idempotent via self.disposed."""
        global logger
        try:
            if not self.disposed:
                if disposing and self.event_logging_enabled:
                    if self.async_event_logging == 1:
                        # signal the loop to exit and wait for the thread
                        self.event_processing_signal.set()
                        self.event_processing_task.join()
                        self.event_processing_signal.clear()
                    if (self.current_message != ''):
                        # flush the partially-filled message buffer as a final event
                        self.event_queue.put(Event("Info", self.current_message, LoggingConstants.DefaultEventTaskName, self.operation_id, self.extension_version).convertToDictionary())
                    if not self.event_queue.empty():
                        try:
                            self._process_events()
                            self.current_message = ''
                            # NOTE(review): re-enters dispose(); terminates only because
                            # the queue/buffer are empty on the second pass.
                            self.dispose()
                        except Exception as ex:
                            logger.log("Warning: Unable to process events before termination of extension. Exception: {0}" .format(str(ex)))
                self.disposed = True
                print("Information: Event Logger has terminated")
                print("Clearing the temp directory")
                self.clear_temp_directory(self.temporary_directory)
                self.event_logging_enabled = False
        except Exception as ex:
            print("Warning: Processing Dispose() of EventLogger resulted in Exception: {0}" .format(str(ex)))



================================================
FILE: VMBackup/main/Utils/HandlerUtil.py
================================================
#
# Handler library for Linux IaaS
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
JSON def:
HandlerEnvironment.json
[{
  "name": "ExampleHandlerLinux",
  "seqNo": "seqNo",
  "version": "1.0",
  "handlerEnvironment": {
    "logFolder": "",
    "eventFolder": "",
    "configFolder": "",
    "statusFolder": "",
    "heartbeatFile": "",
  }
}]

Example ./config/1.settings
"{"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"1BE9A13AA1321C7C515EF109746998BAB6D86FD1","protectedSettings":
"MIIByAYJKoZIhvcNAQcDoIIBuTCCAbUCAQAxggFxMIIBbQIBADBVMEExPzA9BgoJkiaJk/IsZAEZFi9XaW5kb3dzIEF6dXJlIFNlcnZpY2UgTWFuYWdlbWVudCBmb3IgR+nhc6VHQTQpCiiV2zANBgkqhkiG9w0BAQEFAASCAQCKr09QKMGhwYe+O4/a8td+vpB4eTR+BQso84cV5KCAnD6iUIMcSYTrn9aveY6v6ykRLEw8GRKfri2d6tvVDggUrBqDwIgzejGTlCstcMJItWa8Je8gHZVSDfoN80AEOTws9Fp+wNXAbSuMJNb8EnpkpvigAWU2v6pGLEFvSKC0MCjDTkjpjqciGMcbe/r85RG3Zo21HLl0xNOpjDs/qqikc/ri43Y76E/Xv1vBSHEGMFprPy/Hwo3PqZCnulcbVzNnaXN3qi/kxV897xGMPPC3IrO7Nc++AT9qRLFI0841JLcLTlnoVG1okPzK9w6ttksDQmKBSHt3mfYV+skqs+EOMDsGCSqGSIb3DQEHATAUBggqhkiG9w0DBwQITgu0Nu3iFPuAGD6/QzKdtrnCI5425fIUy7LtpXJGmpWDUA==","publicSettings":{"port":"3000"}}}]}"

Example HeartBeat
{
"version": 1.0,
    "heartbeat" : {
        "status": "ready",
        "code": 0,
        "Message": "Sample Handler running. Waiting for a new configuration from user."
    }
}
Example Status Report:
[{"version":"1.0","timestampUTC":"2014-05-29T04:20:13Z","status":{"name":"Chef Extension Handler","operation":"chef-client-run","status":"success","code":0,"formattedMessage":{"lang":"en-US","message":"Chef-client run success"}}}]

"""
import os
import os.path
import shlex
import sys
import re
try:
    import imp as imp
except ImportError:
    import importlib as imp
import base64
import json
import tempfile
import time
from os.path import join
import Utils.WAAgentUtil
from Utils.WAAgentUtil import waagent
import logging
import logging.handlers
try:
    import ConfigParser as ConfigParsers
except ImportError:
    import configparser as ConfigParsers
from common import CommonVariables
import platform
import subprocess
import datetime
import Utils.Status
from Utils.EventLoggerUtil import EventLogger
from Utils.LogHelper import LoggingLevel, LoggingConstants, FileHelpers
# Handle the deprecation of platform.dist() in Python 3.8+
try:
    import distro
    HAS_DISTRO = True
except ImportError:
    HAS_DISTRO = False
from MachineIdentity import MachineIdentity
import ExtensionErrorCodeHelper
import traceback

# timestamp format used in status reports
DateTimeFormat = "%Y-%m-%dT%H:%M:%SZ"


class HandlerContext:
    """Per-invocation context for the extension handler (name, version, and —
    filled in by HandlerUtility.try_parse_context — config/log/status paths,
    sequence number and parsed settings)."""
    def __init__(self,name):
        self._name = name
        self._version = '0.0'
        return


class HandlerUtility:
    """Utility hub for the VMBackup extension handler: parses the handler
    environment and settings files, logs, and builds status/telemetry reports."""

    # process-wide telemetry key/value store, serialized into status reports
    telemetry_data = {}
    serializable_telemetry_data = []
    ExtErrorCode = ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.success
    SnapshotConsistency = Utils.Status.SnapshotConsistencyType.none
    HealthStatusCode = -1

    def __init__(self, log, error, short_name):
        # log/error are callables supplied by the host (waagent wrappers)
        self._log = log
        self._error = error
        self.log_message = ""
        self._short_name = short_name
        self.patching = None
        self.storageDetailsObj = None
        self.partitioncount = 0
        self.logging_file = None
        self.pre_post_enabled = False
        # get_severity_level is defined elsewhere in this class (outside this view)
        self.severity_level = self.get_severity_level()
        self.event_dir = None
        self.eventlogger = None
        self.operation = None

    def _get_log_prefix(self):
        # e.g. "[ExtensionName-1.0]"
        return '[%s-%s]' % (self._context._name, self._context._version)

    # Look through all .settings files in the
config folder and, # Retrieve the most recent modified file's seq# def _get_current_seq_no(self, config_folder): seq_no = -1 cur_seq_no = -1 freshest_time = None for subdir, dirs, files in os.walk(config_folder): for file in files: try: if(file.endswith('.settings')): cur_seq_no = int(os.path.basename(file).split('.')[0]) if(freshest_time == None): freshest_time = os.path.getmtime(join(config_folder,file)) seq_no = cur_seq_no else: current_file_m_time = os.path.getmtime(join(config_folder,file)) if(current_file_m_time > freshest_time): freshest_time = current_file_m_time seq_no = cur_seq_no except ValueError: continue return seq_no def get_last_seq(self): if(os.path.isfile('mrseq')): seq = waagent.GetFileContents('mrseq') if(seq): return int(seq) return -1 def exit_if_same_seq(self): current_seq = int(self._context._seq_no) last_seq = self.get_last_seq() if(current_seq == last_seq): self.log("the sequence number are same, so skip, current:" + str(current_seq) + "== last:" + str(last_seq)) self.update_settings_file() if(self.eventlogger is not None): self.eventlogger.dispose() sys.exit(0) def set_event_logger(self, eventlogger): self.eventlogger = eventlogger def log(self, message,level='Info'): try: self.log_with_no_try_except(message, level) except IOError: pass except Exception as e: try: errMsg = str(e) + 'Exception in hutil.log' self.log_with_no_try_except(errMsg, 'Warning') except Exception as e: pass def log_with_no_try_except(self, message, level='Info'): WriteLog = self.get_strvalue_from_configfile('WriteLog','True') if (WriteLog == None or WriteLog == 'True'): if sys.version_info > (3,): if self.logging_file is not None: self.log_py3(message) if self.eventlogger != None: self.eventlogger.trace_message(level, message) else: pass else: self._log(self._get_log_prefix() + message) if self.eventlogger != None: self.eventlogger.trace_message(level, message) message = "{0} {1} {2} \n".format(str(datetime.datetime.utcnow()) , level , message) self.log_message = 
self.log_message + message def log_py3(self, msg): if type(msg) is not str: msg = str(msg, errors="backslashreplace") msg = str(datetime.datetime.utcnow()) + " " + str(self._get_log_prefix()) + msg + "\n" try: with open(self.logging_file, "a+") as C : C.write(msg) except IOError: pass def error(self, message): self._error(self._get_log_prefix() + message) def fetch_log_message(self): return self.log_message def _decrypt_protected_settings(self, encrypted_file, cert_path, pkey_path): """ Decrypt protected settings with FIPS 140-3 AES256 support and defensive fallback. For FIPS 140-3 compliance, CRP is upgrading encryption to AES256. Opted-in VMs receive protected settings encrypted with AES256, while other VMs continue using DES_EDE3_CBC. The 'cms' command supports both AES256 and DES_EDE3_CBC encryption, while 'smime' only supports DES_EDE3_CBC. We try 'cms' first and fallback to 'smime' for compatibility. Args: encrypted_file: Path to temporary file containing encrypted settings cert_path: Path to certificate file (.crt) pkey_path: Path to private key file (.prv) Returns: Decrypted cleartext string """ cleartxt = None # Determine base64 decode command based on platform if 'NS-BSD' in platform.system(): # base64 tool is not available with NSBSD, use openssl base64_cmd = self.patching.openssl_path + " base64 -d -A -in " + encrypted_file else: base64_cmd = self.patching.base64_path + " -d " + encrypted_file # Try OpenSSL CMS command first (supports both AES256 and DES_EDE3_CBC) try: cms_cmd = base64_cmd + " | " + self.patching.openssl_path + " cms -inform DER -decrypt -recip " + cert_path + " -inkey " + pkey_path self.log("Attempting decryption using OpenSSL CMS command (supports AES256 and DES_EDE3_CBC)") result = waagent.RunGetOutput(cms_cmd, chk_err=False, log_cmd=False) if result[0] == 0 and result[1]: # Success (return code 0) and non-empty output cleartxt = result[1] self.log("Successfully decrypted protected settings using CMS command") return cleartxt else: 
self.log("CMS decryption failed with return code: " + str(result[0]) + ", attempting fallback to SMIME") except Exception as e: self.log("CMS decryption failed with exception: " + type(e).__name__ + ", attempting fallback to SMIME") # Fallback to OpenSSL SMIME command (supports DES_EDE3_CBC only) try: smime_cmd = base64_cmd + " | " + self.patching.openssl_path + " smime -inform DER -decrypt -recip " + cert_path + " -inkey " + pkey_path self.log("Attempting decryption using OpenSSL SMIME command (fallback - supports DES_EDE3_CBC only)") result = waagent.RunGetOutput(smime_cmd, chk_err=False, log_cmd=False) if result[0] == 0 and result[1]: # Success (return code 0) and non-empty output cleartxt = result[1] self.log("Successfully decrypted protected settings using SMIME command (fallback)") return cleartxt else: self.error("SMIME decryption also failed with return code: " + str(result[0])) except Exception as e: self.error("SMIME decryption failed with exception: " + type(e).__name__) # If both methods fail, raise an error if not cleartxt: self.error("Failed to decrypt protected settings using both CMS and SMIME commands") return cleartxt def _parse_config(self, ctxt): config = None try: config = json.loads(ctxt) except: self.error('JSON exception decoding settings file') if config == None: self.error('JSON error processing settings file') else: handlerSettings = config['runtimeSettings'][0]['handlerSettings'] if 'protectedSettings' in handlerSettings and \ "protectedSettingsCertThumbprint" in handlerSettings and \ handlerSettings['protectedSettings'] is not None and \ handlerSettings["protectedSettingsCertThumbprint"] is not None: protectedSettings = handlerSettings['protectedSettings'] thumb = handlerSettings['protectedSettingsCertThumbprint'] cert = waagent.LibDir + '/' + thumb + '.crt' pkey = waagent.LibDir + '/' + thumb + '.prv' f = tempfile.NamedTemporaryFile(delete=False) f.close() 
waagent.SetFileContents(f.name,config['runtimeSettings'][0]['handlerSettings']['protectedSettings']) cleartxt = None # Decrypt protected settings with FIPS 140-3 AES256 support and defensive fallback # Try cms command first (supports both AES256 and DES_EDE3_CBC), fallback to smime if needed cleartxt = self._decrypt_protected_settings(f.name, cert, pkey) jctxt = {} try: jctxt = json.loads(cleartxt) self.log('Config decoded correctly.') except: self.error('JSON exception decoding decrypted protected settings') handlerSettings['protectedSettings'] = jctxt # cleaning/removing the temp files created try: if os.path.isfile(f.name): os.remove(f.name) except Exception as e: self.log('Failed to remove the temporary file ' + str(e)) return config def do_parse_context(self, operation, seqNo): self.operation = operation _context = self.try_parse_context(seqNo) getWaagentPathUsed = Utils.WAAgentUtil.GetPathUsed() if(getWaagentPathUsed == 0): self.log("waagent old path is used") else: self.log("waagent new path is used") if not _context: self.log("maybe no new settings file found") if(self.eventlogger is not None): self.eventlogger.dispose() sys.exit(0) return _context def try_parse_context(self, seqNo): self._context = HandlerContext(self._short_name) handler_env = None config = None ctxt = None code = 0 try: self.log('try_parse_context : Sequence Number received ' + str(seqNo)) # get the HandlerEnvironment.json. 
According to the extension handler # spec, it is always in the ./ directory self.log('cwd is ' + os.path.realpath(os.path.curdir)) handler_env_file = './HandlerEnvironment.json' if not os.path.isfile(handler_env_file): self.error("Unable to locate " + handler_env_file) return None ctxt = waagent.GetFileContents(handler_env_file) if ctxt == None : self.error("Unable to read " + handler_env_file) try: handler_env = json.loads(ctxt) except: pass if handler_env == None : self.log("JSON error processing " + handler_env_file) return None if type(handler_env) == list: handler_env = handler_env[0] self._context._name = handler_env['name'] self._context._version = str(handler_env['version']) self._context._config_dir = handler_env['handlerEnvironment']['configFolder'] self._context._log_dir = handler_env['handlerEnvironment']['logFolder'] self._context._log_file = os.path.join(handler_env['handlerEnvironment']['logFolder'],'extension.log') self.logging_file=self._context._log_file self._context._shell_log_file = os.path.join(handler_env['handlerEnvironment']['logFolder'],'shell.log') self._change_log_file() try: if(self.get_intvalue_from_configfile("disable_logging", 0) == 0): self._context._event_dir = handler_env['handlerEnvironment']['eventsFolder'] self.event_dir = self._context._event_dir except Exception as e: self._context._event_dir = None self.event_dir = None errorMsg = 'The eventsFolder field is missing in handlerEnvironment.json file. Hence skipping event logging!' 
self.log(errorMsg, 'Error') self.log(repr(e), 'Error') self._context._status_dir = handler_env['handlerEnvironment']['statusFolder'] self._context._heartbeat_file = handler_env['handlerEnvironment']['heartbeatFile'] if seqNo != -1: self._context._seq_no = seqNo else: self._context._seq_no = self._get_current_seq_no(self._context._config_dir) if self._context._seq_no < 0: self.error("Unable to locate a .settings file!") return None self._context._seq_no = str(self._context._seq_no) if seqNo != -1: self.log('sequence number from environment variable is ' + self._context._seq_no) else: self.log('sequence number based on config file-names is ' + self._context._seq_no) self._context._status_file = os.path.join(self._context._status_dir, self._context._seq_no + '.status') self._context._settings_file = os.path.join(self._context._config_dir, self._context._seq_no + '.settings') self.log("setting file path is" + self._context._settings_file) ctxt = None ctxt = waagent.GetFileContents(self._context._settings_file) if ctxt == None : error_msg = 'Unable to read ' + self._context._settings_file + '. 
' self.error(error_msg) return None else: if(self.operation is not None and self.operation.lower() == "enable"): # we should keep the current status file self.backup_settings_status_file(self._context._seq_no) self._context._config = self._parse_config(ctxt) except Exception as e: errorMsg = "Unable to parse context, error: %s, stack trace: %s" % (str(e), traceback.format_exc()) self.log(errorMsg, 'Error') raise return self._context def _change_log_file(self): self.log("Change log file to " + self._context._log_file) waagent.LoggerInit(self._context._log_file,'/dev/stdout') self._log = waagent.Log self._error = waagent.Error def save_seq(self): self.set_last_seq(self._context._seq_no) self.log("set most recent sequence number to " + self._context._seq_no) def set_last_seq(self,seq): waagent.SetFileContents('mrseq', str(seq)) ''' Sample /etc/azure/vmbackup.conf [SnapshotThread] seqsnapshot = 1 isanysnapshotfailed = False UploadStatusAndLog = True WriteLog = True onlyLocalFilesystems = True seqsnapshot valid values(0-> parallel snapshot, 1-> programatically set sequential snapshot , 2-> customer set it for sequential snapshot) ''' def get_value_from_configfile(self, key): global backup_logger value = None configfile = '/etc/azure/vmbackup.conf' try : if os.path.exists(configfile): config = ConfigParsers.ConfigParser() config.read(configfile) if config.has_option('SnapshotThread',key): value = config.get('SnapshotThread',key) except Exception as e: pass return value def get_strvalue_from_configfile(self, key, default): value = self.get_value_from_configfile(key) if value == None or value == '': value = default try : value_str = str(value) except ValueError : self.log('Not able to parse the read value as string, falling back to default value', 'Warning') value = default return value def get_intvalue_from_configfile(self, key, default): value = default value = self.get_value_from_configfile(key) if value == None or value == '': value = default try : value_int = 
int(value) except ValueError : self.log('Not able to parse the read value as int, falling back to default value', 'Warning') value = default return int(value) def set_value_to_configfile(self, key, value): configfile = '/etc/azure/vmbackup.conf' try : self.log('setting ' + str(key) + 'in config file to ' + str(value) , 'Info') if not os.path.exists(os.path.dirname(configfile)): os.makedirs(os.path.dirname(configfile)) config = ConfigParsers.RawConfigParser() if os.path.exists(configfile): config.read(configfile) if config.has_section('SnapshotThread'): if config.has_option('SnapshotThread', key): config.remove_option('SnapshotThread', key) else: config.add_section('SnapshotThread') else: config.add_section('SnapshotThread') config.set('SnapshotThread', key, value) with open(configfile, 'w') as config_file: config.write(config_file) except Exception as e: errorMsg = " Unable to set config file.key is "+ key +"with error: %s, stack trace: %s" % (str(e), traceback.format_exc()) self.log(errorMsg, 'Warning') return value def get_machine_id(self): machine_id_file = "/etc/azure/machine_identity_FD76C85E-406F-4CFA-8EB0-CF18B123358B" machine_id = "" file_pointer = None try: if not os.path.exists(os.path.dirname(machine_id_file)): os.makedirs(os.path.dirname(machine_id_file)) if os.path.exists(machine_id_file): file_pointer = open(machine_id_file, "r") machine_id = file_pointer.readline() file_pointer.close() else: mi = MachineIdentity() if(mi.stored_identity() != None): machine_id = mi.stored_identity()[1:-1] file_pointer = open(machine_id_file, "w") file_pointer.write(machine_id) file_pointer.close() except Exception as e: errMsg = 'Failed to retrieve the unique machine id with error: %s, stack trace: %s' % (str(e), traceback.format_exc()) self.log(errMsg, 'Error') finally : if file_pointer != None : if file_pointer.closed == False : file_pointer.close() self.log("Unique Machine Id : {0}".format(machine_id)) return machine_id def 
    def get_storage_details(self,total_size,failure_flag):
        # Build and cache the StorageDetails status object from the partition
        # count gathered earlier plus the computed total size and failure flag.
        self.storageDetailsObj = Utils.Status.StorageDetails(self.partitioncount, total_size, False, failure_flag)
        self.log("partition count : {0}, total used size : {1}, is storage space present : {2}, is size computation failed : {3}".format(self.storageDetailsObj.partitionCount, self.storageDetailsObj.totalUsedSizeInBytes, self.storageDetailsObj.isStoragespacePresent, self.storageDetailsObj.isSizeComputationFailed))
        return self.storageDetailsObj

    def SetExtErrorCode(self, extErrorCode):
        # First-writer-wins: only overwrite while the code is still "success",
        # so the first real error recorded is the one reported.
        if self.ExtErrorCode == ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.success :
            self.ExtErrorCode = extErrorCode

    def SetSnapshotConsistencyType(self, snapshotConsistency):
        self.SnapshotConsistency = snapshotConsistency

    def SetHealthStatusCode(self, healthStatusCode):
        self.HealthStatusCode = healthStatusCode

    def do_status_json(self, operation, status, sub_status, status_code, message, telemetrydata, taskId, commandStartTimeUTCTicks, snapshot_info, vm_health_obj,total_size,failure_flag):
        # Assemble the TopLevelStatus object that gets serialized into the
        # extension .status file.
        tstamp = time.strftime(DateTimeFormat, time.gmtime())
        formattedMessage = Utils.Status.FormattedMessage("en-US",message)
        stat_obj = Utils.Status.StatusObj(self._context._name, operation, status, sub_status, status_code, formattedMessage, telemetrydata, self.get_storage_details(total_size,failure_flag), self.get_machine_id(), taskId, commandStartTimeUTCTicks, snapshot_info, vm_health_obj)
        top_stat_obj = Utils.Status.TopLevelStatus(self._context._version, tstamp, stat_obj)
        return top_stat_obj

    def get_extension_version(self):
        # The extension is installed under a directory named
        # "<publisher>.<name>-<version>"; the version is the suffix after the
        # last '-' of the current working directory name.
        try:
            cur_dir = os.getcwd()
            cur_extension = cur_dir.split("/")[-1]
            extension_version = cur_extension.split("-")[-1]
            return extension_version
        except Exception as e:
            errMsg = 'Failed to retrieve the Extension version with error: %s, stack trace: %s' % (str(e), traceback.format_exc())
            self.log(errMsg)
            extension_version="Unknown"
            return extension_version

    def get_wala_version(self):
        # Scrape the guest-agent version from /var/log/waagent.log; fall back
        # to invoking the agent binary when no version line is found.
        try:
            file_pointer = open('/var/log/waagent.log','r')
            waagent_version = ''
            for line in file_pointer:
                if 'Azure Linux Agent Version' in line:
                    waagent_version = line.split(':')[-1]
            if waagent_version[:-1]=="": #for removing the trailing '\n' character
                waagent_version = self.get_wala_version_from_command()
                return waagent_version
            else:
                waagent_version = waagent_version[:-1].split("-")[-1] #getting only version number
                return waagent_version
        except Exception as e:
            errMsg = 'Failed to retrieve the wala version with error: %s, stack trace: %s' % (str(e), traceback.format_exc())
            self.log(errMsg)
            waagent_version="Unknown"
            return waagent_version

    def get_wala_version_from_command(self):
        # Run '/usr/sbin/waagent -version' (30s budget) and parse the version
        # out of its stdout.
        try:
            cur_dir = os.getcwd()
            os.chdir("..")
            out = self.command_output_from_subprocess(['/usr/sbin/waagent', '-version'],30)
            if "Goal state agent: " in out:
                waagent_version = out.split("Goal state agent: ")[1].strip()
            else:
                out = out.split(" ")
                waagent = out[0]
                waagent_version = waagent.split("-")[-1] #getting only version number
            os.chdir(cur_dir)
            return waagent_version
        except Exception as e:
            errMsg = 'Failed to retrieve the wala version with error: %s, stack trace: %s' % (str(e), traceback.format_exc())
            self.log(errMsg)
            # NOTE(review): if os.getcwd() itself raised, cur_dir is unbound
            # here and this chdir raises NameError -- confirm and guard.
            os.chdir(cur_dir)
            waagent_version="Unknown"
            return waagent_version

    def get_dist_info(self):
        # Return (distro-name-version, kernel-release) using a chain of
        # fallbacks: BSD platforms, the 'distro' package, the deprecated
        # platform.linux_distribution()/platform.dist(), then /etc/os-release.
        try:
            if 'FreeBSD' in platform.system():
                release = re.sub('\\-.*$', '', str(platform.release()))
                return "FreeBSD",release
            if 'NS-BSD' in platform.system():
                release = re.sub('\\-.*$', '', str(platform.release()))
                return "NS-BSD", release
            # Try modern approach first (Python 3.8+ compatible)
            if HAS_DISTRO:
                try:
                    distro_name = distro.name()
                    distro_version = distro.version()
                    if distro_name and distro_version:
                        return distro_name + "-" + distro_version, platform.release()
                except Exception as e:
                    self.log('Warning: distro package failed with error: %s' % str(e))
            # Fallback to linux_distribution (deprecated in Python 3.5, removed in Python 3.8)
            if hasattr(platform, 'linux_distribution'):
                try:
                    distinfo = list(platform.linux_distribution(full_distribution_name=0))
                    # remove trailing whitespace in distro name
                    if(distinfo[0] == ''):
                        osfile= open("/etc/os-release", "r")
                        for line in osfile:
                            lists=str(line).split("=")
                            if(lists[0]== "NAME"):
                                distroname = lists[1].split("\"")
                            if(lists[0]=="VERSION"):
                                distroversion = lists[1].split("\"")
                        osfile.close()
                        # NOTE(review): if /etc/os-release lacks NAME or
                        # VERSION, distroname/distroversion are unbound here
                        # (NameError, swallowed by the inner except).
                        return distroname[1]+"-"+distroversion[1],platform.release()
                    distinfo[0] = distinfo[0].strip()
                    return distinfo[0]+"-"+distinfo[1],platform.release()
                except Exception as e:
                    self.log('Warning: platform.linux_distribution failed with error: %s' % str(e))
            # Fallback to platform.dist() (deprecated in Python 3.5, removed in Python 3.8+)
            if hasattr(platform, 'dist'):
                try:
                    distinfo = platform.dist()
                    return distinfo[0]+"-"+distinfo[1],platform.release()
                except Exception as e:
                    self.log('Warning: platform.dist failed with error: %s' % str(e))
            # Final fallback: try to parse /etc/os-release manually
            try:
                distroname = None
                distroversion = None
                with open("/etc/os-release", "r") as osfile:
                    for line in osfile:
                        lists = str(line.strip()).split("=", 1)
                        if len(lists) >= 2:
                            key = lists[0].strip()
                            value = lists[1].strip().strip('"')
                            if key == "NAME":
                                distroname = value
                            elif key == "VERSION" or key == "VERSION_ID":
                                distroversion = value
                if distroname and distroversion:
                    return distroname + "-" + distroversion, platform.release()
                elif distroname:
                    return distroname + "-Unknown", platform.release()
            except Exception as e:
                self.log('Warning: Failed to parse /etc/os-release with error: %s' % str(e))
            # If all else fails, return unknown
            return "Unknown", "Unknown"
        except Exception as e:
            errMsg = 'Failed to retrieve the distinfo with error: %s, stack trace: %s' % (str(e), traceback.format_exc())
            self.log(errMsg)
            return "Unknown","Unknown"

    def substat_new_entry(self,sub_status,code,name,status,formattedmessage):
        # Append one SubstatusObj to the given list and return the list.
        sub_status_obj = Utils.Status.SubstatusObj(code,name,status,formattedmessage)
        sub_status.append(sub_status_obj)
        return sub_status
    def timedelta_total_seconds(self, delta):
        # Python 2.6 compatibility: timedelta.total_seconds() was added in 2.7.
        if not hasattr(datetime.timedelta, 'total_seconds'):
            return delta.days * 86400 + delta.seconds
        else:
            return delta.total_seconds()

    @staticmethod
    def add_to_telemetery_data(key,value):
        HandlerUtility.telemetry_data[key]=value

    @staticmethod
    def get_telemetry_data(key):
        return HandlerUtility.telemetry_data[key]

    def add_telemetry_data(self):
        # Populate the class-level telemetry dictionary with environment info.
        os_version,kernel_version = self.get_dist_info()
        workloads = self.get_workload_running()
        HandlerUtility.add_to_telemetery_data("guestAgentVersion",self.get_wala_version_from_command())
        HandlerUtility.add_to_telemetery_data("extensionVersion",self.get_extension_version())
        HandlerUtility.add_to_telemetery_data("osVersion",os_version)
        HandlerUtility.add_to_telemetery_data("kernelVersion",kernel_version)
        HandlerUtility.add_to_telemetery_data("workloads",str(workloads))
        HandlerUtility.add_to_telemetery_data("prePostEnabled", str(self.pre_post_enabled))

    def convert_telemetery_data_to_bcm_serializable_format(self):
        # Re-shape {k: v} telemetry into a list of {"Key": k, "Value": v}
        # records as expected by the backup service.
        HandlerUtility.serializable_telemetry_data = []
        for k,v in HandlerUtility.telemetry_data.items():
            each_telemetry_data = {}
            each_telemetry_data["Value"] = v
            each_telemetry_data["Key"] = k
            HandlerUtility.serializable_telemetry_data.append(each_telemetry_data)

    def do_status_report(self, operation, status, status_code, message, taskId = None, commandStartTimeUTCTicks = None, snapshot_info = None,total_size = 0,failure_flag = True ):
        # Build both status payloads: 'stat_rept' (sent to the backup service)
        # and 'stat_rept_file' (written to the .status file, with stat_rept
        # embedded as a sub-status). Returns (stat_rept, stat_rept_file).
        self.log("{0},{1},{2},{3},{4}".format(operation, status, status_code, message, failure_flag ))
        sub_stat = []
        stat_rept = []
        self.add_telemetry_data()
        # Fold selected telemetry entries into the human-readable message.
        snapshotTelemetry = ""
        if CommonVariables.snapshotCreator in HandlerUtility.telemetry_data.keys():
            snapshotTelemetry = "{0}{1}={2}, ".format(snapshotTelemetry , CommonVariables.snapshotCreator , HandlerUtility.telemetry_data[CommonVariables.snapshotCreator])
        if CommonVariables.hostStatusCodePreSnapshot in HandlerUtility.telemetry_data.keys():
            snapshotTelemetry = "{0}{1}={2}, ".format(snapshotTelemetry , CommonVariables.hostStatusCodePreSnapshot , HandlerUtility.telemetry_data[CommonVariables.hostStatusCodePreSnapshot])
        if CommonVariables.hostStatusCodeDoSnapshot in HandlerUtility.telemetry_data.keys():
            snapshotTelemetry = "{0}{1}={2}, ".format(snapshotTelemetry , CommonVariables.hostStatusCodeDoSnapshot , HandlerUtility.telemetry_data[CommonVariables.hostStatusCodeDoSnapshot])
        if CommonVariables.statusBlobUploadError in HandlerUtility.telemetry_data.keys():
            message = "{0} {1}={2}, ".format(message , CommonVariables.statusBlobUploadError , HandlerUtility.telemetry_data[CommonVariables.statusBlobUploadError])
        message = message + snapshotTelemetry
        vm_health_obj = Utils.Status.VmHealthInfoObj(ExtensionErrorCodeHelper.ExtensionErrorCodeHelper.ExtensionErrorCodeDict[self.ExtErrorCode], int(self.ExtErrorCode))
        # Derive the snapshot consistency level from the status code unless it
        # was already pinned to crash-consistent.
        consistencyTypeStr = CommonVariables.consistency_crashConsistent
        if (self.SnapshotConsistency != Utils.Status.SnapshotConsistencyType.crashConsistent):
            if (status_code == CommonVariables.success_appconsistent):
                self.SnapshotConsistency = Utils.Status.SnapshotConsistencyType.applicationConsistent
                consistencyTypeStr = CommonVariables.consistency_applicationConsistent
            elif (status_code == CommonVariables.success):
                self.SnapshotConsistency = Utils.Status.SnapshotConsistencyType.fileSystemConsistent
                consistencyTypeStr = CommonVariables.consistency_fileSystemConsistent
            else:
                self.SnapshotConsistency = Utils.Status.SnapshotConsistencyType.none
                consistencyTypeStr = CommonVariables.consistency_none
        HandlerUtility.add_to_telemetery_data("consistencyType", consistencyTypeStr)
        extensionResponseObj = Utils.Status.ExtensionResponse(message, self.SnapshotConsistency, "", failure_flag)
        message = str(json.dumps(extensionResponseObj, cls = ComplexEncoder))
        self.convert_telemetery_data_to_bcm_serializable_format()
        stat_rept = self.do_status_json(operation, status, sub_stat, status_code, message, HandlerUtility.serializable_telemetry_data, taskId, commandStartTimeUTCTicks, snapshot_info, vm_health_obj, total_size,failure_flag)
        time_delta = datetime.datetime.utcnow() - datetime.datetime(1970, 1, 1)
        time_span = self.timedelta_total_seconds(time_delta) * 1000
        # A GUID placeholder is serialized first, then swapped for the
        # Microsoft JSON date literal (order matters: json.dumps would escape
        # the slashes if the date string were set before serialization).
        date_place_holder = 'e2794170-c93d-4178-a8da-9bc7fd91ecc0'
        stat_rept.timestampUTC = date_place_holder
        date_string = r'\/Date(' + str((int)(time_span)) + r')\/'
        # Convert TopLevelStatus object to JSON array string
        # Before: stat_rept is TopLevelStatus object with timestampUTC="e2794170-c93d-4178-a8da-9bc7fd91ecc0"
        # After: stat_rept = '[{"version":"1.0","timestampUTC":"e2794170-c93d-4178-a8da-9bc7fd91ecc0","status":{"name":"VMSnapshotLinux",...}}]'
        stat_rept = "[" + json.dumps(stat_rept, cls = ComplexEncoder) + "]"
        # Replace placeholder GUID with actual Microsoft JSON date format first
        # Before: "timestampUTC":"e2794170-c93d-4178-a8da-9bc7fd91ecc0"
        # After: "timestampUTC":"\/Date(time_span)\/"
        stat_rept = stat_rept.replace(date_place_holder,date_string)
        # Now remove JSON-escaped forward slashes to get clean date format for C# DateTimeOffset
        # Before: "timestampUTC":"\/Date(time_span)\/"
        # After: "timestampUTC":"/Date(time_span)/"
        stat_rept = stat_rept.replace(r'\/', '/') # To fix the datetime format of CreationTime to be consumed by C# DateTimeOffset
        # Add Status as sub-status for Status to be written on Status-File
        sub_stat = self.substat_new_entry(sub_stat,'0',stat_rept,'success',None)
        if self.get_public_settings()[CommonVariables.vmType].lower() == CommonVariables.VmTypeV2.lower() and CommonVariables.isTerminalStatus(status) :
            status = CommonVariables.status_success
        stat_rept_file = self.do_status_json(operation, status, sub_stat, status_code, message, None, taskId, commandStartTimeUTCTicks, None, None,total_size,failure_flag)
        stat_rept_file = "[" + json.dumps(stat_rept_file, cls = ComplexEncoder) + "]"
        # rename all other status files, or the WALA would report the wrong
        # status file.
        # because the wala choose the status file with the highest sequence
        # number to report.
        return stat_rept, stat_rept_file

    def write_to_status_file(self, stat_rept_file):
        # Write to a temp file in the status directory, then rename over the
        # real status file so WALA never observes a half-written file.
        try:
            tempStatusFile = os.path.join(self._context._status_dir, CommonVariables.TempStatusFileName)
            if self._context._status_file:
                with open(tempStatusFile,'w+') as f:
                    f.write(stat_rept_file)
                os.rename(tempStatusFile, self._context._status_file)
        except Exception as e:
            errMsg = 'Status file creation failed with error: %s, stack trace: %s' % (str(e), traceback.format_exc())
            self.log(errMsg)

    def is_status_file_exists(self):
        # True when the current sequence's status file exists; False on error.
        try:
            if os.path.exists(self._context._status_file):
                return True
            else:
                return False
        except Exception as e:
            self.log("exception is getting status file" + traceback.format_exc())
            return False

    # Rename all .settings and .status files that do not belong to current seq# with '_' suffix
    # Clear older files in the respective directories
    def backup_settings_status_file(self, _seq_no):
        self.log("current seq no is " + _seq_no)
        file_extn_settings = '.settings'
        file_extn_status = '.status'
        maxLimitOfFiles = 60
        for subdir, dirs, files in os.walk(self._context._config_dir):
            for file in files:
                try:
                    if(file.endswith(file_extn_settings) and file != (_seq_no + file_extn_settings)):
                        new_file_name = file.replace(".","_")
                        os.rename(join(self._context._config_dir,file), join(self._context._config_dir,new_file_name))
                except Exception as e:
                    # NOTE(review): message says "status file" even though this
                    # branch renames settings files.
                    self.log("failed to rename the status file.")
        try:
            FileHelpers.clearOldFilesInDirectory(self._context._config_dir, '_settings', maxLimitOfFiles)
        except Exception as e:
            pass # Ignore the exception in clearing old files and continue
        for subdir, dirs, files in os.walk(self._context._status_dir):
            for file in files:
                try:
                    if(file.endswith(file_extn_status) and file != (_seq_no + file_extn_status)):
                        new_file_name = file.replace(".","_")
                        os.rename(join(self._context._status_dir,file), join(self._context._status_dir, new_file_name))
                except Exception as e:
                    self.log("failed to rename the status file.")
        try:
            FileHelpers.clearOldFilesInDirectory(self._context._status_dir, '_status', maxLimitOfFiles)
        except Exception as e:
            pass # Ignore the exception in clearing old files and continue

    def do_exit(self, exit_code, operation,status,code,message):
        # Final status report + event logger teardown, then process exit.
        try:
            HandlerUtility.add_to_telemetery_data("extErrorCode", str(ExtensionErrorCodeHelper.ExtensionErrorCodeHelper.ExtensionErrorCodeDict[self.ExtErrorCode]))
            self.do_status_report(operation, status,code,message)
        except Exception as e:
            self.log("Can't update status: " + str(e))
        if(self.eventlogger is not None):
            self.eventlogger.dispose()
        sys.exit(exit_code)

    def get_handler_settings(self):
        return self._context._config['runtimeSettings'][0]['handlerSettings']

    def get_protected_settings(self):
        return self.get_handler_settings().get('protectedSettings')

    def get_public_settings(self):
        return self.get_handler_settings().get('publicSettings')

    def is_prev_in_transition(self):
        # Inspect the previous sequence's status file for a "Transition"
        # marker, meaning the prior enable never reached a terminal state.
        curr_seq = self.get_last_seq()
        last_seq = curr_seq - 1
        if last_seq >= 0:
            self.log("previous status and path: " + str(last_seq) + " " + str(self._context._status_dir))
            status_file_prev = os.path.join(self._context._status_dir, str(last_seq) + '_status')
            if os.path.isfile(status_file_prev) and os.access(status_file_prev, os.R_OK):
                searchfile = open(status_file_prev, "r")
                for line in searchfile:
                    if "Transition" in line:
                        self.log("transitioning found in the previous status file")
                        searchfile.close()
                        return True
                searchfile.close()
        return False

    def get_prev_log(self):
        # Return up to the last 300 lines of the extension log.
        with open(self._context._log_file, "r") as f:
            lines = f.readlines()
            if(len(lines) > 300):
                lines = lines[-300:]
                return ''.join(str(x) for x in lines)
            else:
                return ''.join(str(x) for x in lines)

    def get_shell_script_log(self):
        # Return up to the last 10 lines of the pre/post shell-script log;
        # empty string when the log is unreadable.
        lines = ""
        try:
            with open(self._context._shell_log_file, "r") as f:
                lines = f.readlines()
                if(len(lines) > 10):
                    lines = lines[-10:]
                return ''.join(str(x) for x in lines)
        except Exception as e:
            self.log("Can't receive shell log file: " + str(e))
        return lines
    def update_settings_file(self):
        # Strip protectedSettings from the in-memory config and persist the
        # scrubbed settings file so secrets do not linger on disk.
        if(self._context._config['runtimeSettings'][0]['handlerSettings'].get('protectedSettings') != None):
            del self._context._config['runtimeSettings'][0]['handlerSettings']['protectedSettings']
            self.log("removing the protected settings")
            waagent.SetFileContents(self._context._settings_file,json.dumps(self._context._config))

    def UriHasSpecialCharacters(self, blobs):
        # True when any blob URI (query string excluded) contains '%'.
        uriHasSpecialCharacters = False
        if blobs is not None:
            for blob in blobs:
                blobUri = str(blob.split("?")[0])
                if '%' in blobUri:
                    self.log(blobUri + " URI has special characters")
                    uriHasSpecialCharacters = True
        return uriHasSpecialCharacters

    def get_workload_running(self):
        # Scan /proc/<pid>/cmdline for known database workloads; returns the
        # list of matched names (possibly empty).
        workloads = []
        try:
            # NOTE(review): ",mongo" has a stray leading comma -- a cmdline is
            # extremely unlikely to contain ",mongo", so mongo is effectively
            # never detected; this was presumably meant to be "mongo".
            dblist= ["mysqld","postgresql","oracle","cassandra",",mongo"] ## add all workload process name in lower case
            if os.path.isdir("/proc"):
                pids = [pid for pid in os.listdir('/proc') if pid.isdigit()]
                for pid in pids:
                    pname = open(os.path.join('/proc', pid, 'cmdline'), 'rb').read()
                    for db in dblist :
                        if db in str(pname).lower() and db not in workloads :
                            self.log("workload running found with name : " + str(db))
                            workloads.append(db)
            return workloads
        except Exception as e:
            self.log("Unable to fetch running workloads" + str(e))
            return workloads

    def set_pre_post_enabled(self):
        self.pre_post_enabled = True

    def command_output_from_subprocess(self , args, process_wait_time):
        # Launch 'args', poll up to process_wait_time seconds, then read and
        # decode stdout. NOTE(review): stdout is read even when the deadline
        # expires with the process still running -- confirm intended.
        process_out = subprocess.Popen(args, stdout=subprocess.PIPE)
        while(process_wait_time > 0 and process_out.poll() is None):
            time.sleep(1)
            process_wait_time -= 1
        out = process_out.stdout.read().decode()
        out = str(out)
        return out

    def get_severity_level(self):
        # Read the event-log severity from main/LogSeverity.json, falling back
        # to LoggingConstants.DefaultEventLogLevel on any problem.
        logging_level = LoggingLevel(LoggingConstants.DefaultEventLogLevel)
        try:
            log_setting_file_path = os.path.join(os.getcwd(), "main", LoggingConstants.LogLevelSettingFile)
            if os.path.exists(log_setting_file_path):
                with open(log_setting_file_path, 'r') as file:
                    logging_level_input = json.load(file)
                    logging_level.__dict__.update(logging_level_input)
            else:
                self.log("Logging level setting file is not present.")
        except Exception as e:
            self.log("error in fetching the severity of logs " + str(e))
        return logging_level.EventLogLevel

    @staticmethod
    def split(logger,txt):
        # shlex.split with a plain str.split fallback when shlex chokes.
        result = None
        try:
            result = shlex.split(txt)
        except Exception as e:
            logger.log('Shlex.Split threw exception error: %s, stack trace: %s' % (str(e), traceback.format_exc()))
            result = txt.split()
        return result

    @staticmethod
    def convert_to_string(txt):
        # Decode bytes to str on Python 3; plain str() on Python 2.
        if sys.version_info > (3,):
            txt = str(txt, encoding='utf-8', errors="backslashreplace")
        else:
            txt = str(txt)
        return txt

    def redact_sensitive_encryption_details(self, request_body):
        # Replace the DiskEncryptionSettings value in snapshotMetadata with
        # "REDACTED" before the body is logged/sent onward.
        try:
            meta_list = getattr(request_body, "snapshotMetadata", None)
            # NOTE(review): when snapshotMetadata is absent, meta_list is None
            # and iterating raises TypeError -- handled by the except below.
            for meta in meta_list:
                if meta.get("Key") == "DiskEncryptionSettings":
                    # Redact the entire value of DiskEncryptionSettings
                    meta["Value"] = "REDACTED"
            return request_body
        except Exception as e:
            self.log("Error while redacting: {0}".format(str(e)), 'Error')
            return request_body

class ComplexEncoder(json.JSONEncoder):
    # JSON encoder for the status objects: prefers an object's own
    # convertToDictionary(), otherwise serializes its __dict__.
    def default(self, obj):
        if hasattr(obj,'convertToDictionary'):
            return obj.convertToDictionary()
        else:
            return obj.__dict__

================================================
FILE: VMBackup/main/Utils/HostSnapshotObjects.py
================================================
import json

class HostDoSnapshotRequestBody:
    # Request body for the host "do snapshot" call.
    def __init__(self, taskId, diskIds, settings, snapshotTaskToken, snapshotMetadata, instantAccessDurationMinutes = None):
        self.taskId = taskId
        self.diskIds = diskIds
        self.snapshotMetadata = snapshotMetadata
        self.snapshotTaskToken = snapshotTaskToken
        self.settings = settings
        self.instantAccessDurationMinutes = instantAccessDurationMinutes

    def convertToDictionary(self):
        # instantAccessDurationMinutes is only serialized when provided.
        result = dict(taskId = self.taskId, diskIds = self.diskIds, settings = self.settings, snapshotTaskToken = self.snapshotTaskToken, snapshotMetadata = self.snapshotMetadata)
        if self.instantAccessDurationMinutes is not None:
            result['instantAccessDurationMinutes'] = self.instantAccessDurationMinutes
        return result
__init__(self, taskId, snapshotTaskToken, preSnapshotSettings = None): self.taskId = taskId self.snapshotTaskToken = snapshotTaskToken if (preSnapshotSettings != None): self.preSnapshotSettings = preSnapshotSettings def convertToDictionary(self): result = dict(taskId=self.taskId, snapshotTaskToken=self.snapshotTaskToken) if hasattr(self, 'preSnapshotSettings'): result['preSnapshotSettings'] = self.preSnapshotSettings return result class BlobSnapshotInfo: def __init__(self, isSuccessful, snapshotUri, errorMessage, statusCode, ddSnapshotIdentifier = None): self.isSuccessful = isSuccessful self.snapshotUri = snapshotUri self.errorMessage = errorMessage self.statusCode = statusCode self.ddSnapshotIdentifier = ddSnapshotIdentifier def convertToDictionary(self): return dict(isSuccessful = self.isSuccessful, snapshotUri = self.snapshotUri, errorMessage = self.errorMessage, statusCode = self.statusCode, ddSnapshotIdentifier = self.ddSnapshotIdentifier) class DDSnapshotIdentifier: def __init__(self, creationTime, id, token, instantAccessDurationMinutes = None): self.creationTime = creationTime self.id = id self.token = token self.instantAccessDurationMinutes = instantAccessDurationMinutes def convertToDictionary(self): return dict(creationTime = self.creationTime, id = self.id, token = self.token, instantAccessDurationMinutes = self.instantAccessDurationMinutes) ================================================ FILE: VMBackup/main/Utils/LogHelper.py ================================================ import os import datetime import shutil import time class LoggingConstants: MaxDayAgeOfStaleFiles = -1 # We don't store unprocessed files beyond 1 day from current processing time LogFileWriteRetryAttempts = 3 LogFileWriteRetryTime = 500 # milliseconds MaxAttemptsForEventFileCreationWriteMove = 3 MinEventProcesingInterval = 10 # 10 seconds ThreadSleepDuration = 10 # 10 seconds MaxEventDirectorySize = 39981250 # ~= 39Mb MaxEventsPerRun = 300 MaxMessageLenLimit = 2900 # 3072 to be 
# ================ FILE: VMBackup/main/Utils/LogHelper.py ================
import os
import datetime
import shutil
import time

class LoggingConstants:
    """Tunable constants for event logging and log/status file retention."""
    MaxDayAgeOfStaleFiles = -1 # We don't store unprocessed files beyond 1 day from current processing time
    LogFileWriteRetryAttempts = 3
    LogFileWriteRetryTime = 500 # milliseconds
    MaxAttemptsForEventFileCreationWriteMove = 3
    MinEventProcesingInterval = 10 # 10 seconds
    ThreadSleepDuration = 10 # 10 seconds
    MaxEventDirectorySize = 39981250 # ~= 39Mb
    MaxEventsPerRun = 300
    MaxMessageLenLimit = 2900 # 3072 to be precise
    MaxMessageLengthPerEvent = 3000 # 3072 to be precise
    DefaultEventTaskName = "Enable" # ToDo: The third param-TaskName is by default set to "Enable". We can add a mechanism to send the program file name
    LogLevelSettingFile = "LogSeverity.json"
    DefaultEventLogLevel = 2
    AllLogEnabledLevel = 0

class LoggingLevel:
    """Mutable holder for the effective event-log severity level."""
    def __init__(self, event_log_level):
        self.EventLogLevel = event_log_level

class FileHelpers:
    """Static helpers for file/directory cleanup used by the logging and
    status-reporting pipelines. All methods log their outcome via print()."""

    @staticmethod
    def getSizeOfDir(path):
        """Return the total size in bytes of all files under 'path' (recursive)."""
        total_size = 0
        for root, dirs, files in os.walk(path):
            for file in files:
                file_path = os.path.join(root, file)
                total_size += os.path.getsize(file_path)
        return total_size

    @staticmethod
    def deleteFile(file_path):
        """Delete one file; logs a warning on failure, an error if missing."""
        if os.path.exists(file_path):
            try:
                os.remove(file_path)
                print("Information: Successfully deleted file: {0}".format(file_path))
            except Exception as ex:
                print("Warning: Failed to delete file {0}. Exception: {1}".format(file_path, str(ex)))
        else:
            print("Error: Attempted to delete non-existent file: {0}".format(file_path))

    @staticmethod
    def deleteDirectory(directory_path):
        """Recursively delete a directory; logs a warning on failure, an error if missing."""
        if os.path.exists(directory_path):
            try:
                shutil.rmtree(directory_path)
                print("Information: Successfully deleted directory: {0}".format(directory_path))
            except Exception as ex:
                print("Warning: Failed to delete directory {0}. Exception: {1}".format(directory_path, str(ex)))
        else:
            print("Error: Attempted to delete non-existent directory: {0}".format(directory_path))

    @staticmethod
    def clearOldJsonFilesInDirectory(file_path):
        """Delete every file under 'file_path' last modified before
        now + MaxDayAgeOfStaleFiles days (the constant is negative, so the
        cutoff lies in the past).

        Fixes vs. previous revision:
        * the per-file warning used two format placeholders but passed only
          one argument, so a failed deletion raised IndexError instead of
          printing the warning (aborting the sweep via the outer except);
        * the loop variable shadowed the 'file_path' parameter, so the
          summary line reported the last visited file instead of the
          directory, and the file count only covered the last subdirectory.
        """
        try:
            current_time = datetime.datetime.now()
            max_day_age = datetime.timedelta(days=LoggingConstants.MaxDayAgeOfStaleFiles)
            files_deleted = 0
            files_seen = 0
            for root, dirs, files in os.walk(file_path):
                for file in files:
                    files_seen += 1
                    stale_path = os.path.join(root, file)
                    last_write_time = datetime.datetime.fromtimestamp(os.path.getmtime(stale_path))
                    if last_write_time < current_time + max_day_age:
                        try:
                            os.remove(stale_path)
                            files_deleted += 1
                        except Exception as ex:
                            print("Warning: Failed to delete old JSON file {0}. Exception: {1}".format(stale_path, str(ex)))
            print("Information: Cleared {0} day old JSON files in directory at path {1}, NumberOfFilesRemoved/NumberOfJSONFilesPresent = {2}/{3}".format(LoggingConstants.MaxDayAgeOfStaleFiles, file_path, files_deleted, files_seen))
        except Exception as ex:
            print("Warning: Failed to delete old JSON files at path {0}. Exception: {1}".format(file_path, str(ex)))

    @staticmethod
    def clearOldFilesInDirectory(directory, extension, file_limit):
        """
        Deletes older files if the number of files with the given extension exceeds the file_limit.

        Parameters:
            directory (str): The directory to clean up.
            extension (str): The file extension to filter (e.g., ".status", ".settings").
            file_limit (int): Maximum allowed number of files with the given extension.
        """
        try:
            # Ensure the directory exists
            if not os.path.isdir(directory):
                print("Directory '{0}' does not exist.".format(directory))
                return
            # Collect all files with the specified extension
            files_with_ext = [
                os.path.join(directory, f)
                for f in os.listdir(directory)
                if f.endswith(extension) and os.path.isfile(os.path.join(directory, f))
            ]
            # Sort the files by modification time (oldest first)
            files_with_ext.sort(key=lambda f: os.path.getmtime(f))
            # Check if the number of files exceeds the limit
            if len(files_with_ext) > file_limit:
                files_to_delete = files_with_ext[:len(files_with_ext) - file_limit]
                # Delete the excess files
                for file in files_to_delete:
                    try:
                        os.remove(file)
                        print("Deleted: {0}".format(file))
                    except Exception as e:
                        print("Error deleting {0}: {1}".format(file, str(e)))
            else:
                print("No files need to be deleted. Total files ({0}) are within the limit.".format(len(files_with_ext)))
        except Exception as e:
            print("An unexpected error occurred while clearing old files: {0}".format(str(e)))
Total files ({0}) are within the limit.".format(len(files_with_ext))) except Exception as e: print("An unexpected error occurred while clearing old files: {0}".format(str(e))) def execute_with_retries(self, max_attempts, delay, success_msg, retry_msg, err_msg, operation): attempts = 0 while attempts < max_attempts: try: result = operation() print("Information: " + success_msg) return result except Exception as ex: attempts += 1 print("Warning: {0}, Exception: {1}".format(retry_msg, str(ex))) if attempts < max_attempts: time.sleep(delay) print("Warning: " + err_msg) return None ================================================ FILE: VMBackup/main/Utils/ResourceDiskUtil.py ================================================ import os import sys import re import subprocess import shlex from subprocess import * import traceback from Utils.DiskUtil import DiskUtil import Utils.HandlerUtil STORAGE_DEVICE_PATH = '/sys/bus/vmbus/devices/' GEN2_DEVICE_ID = 'f8b3781a-1e82-4818-a1c3-63d806ec15bb' def read_file(filepath): """ Read and return contents of 'filepath'. """ mode = 'rb' with open(filepath, mode) as in_file: data = in_file.read().decode('utf-8') return data class ResourceDiskUtil(object): def __init__(self,patching,logger): self.logger = logger self.disk_util = DiskUtil.get_instance(patching,logger) @staticmethod def _enumerate_device_id(): """ Enumerate all storage device IDs. Args: None Returns: Iterator[Tuple[str, str]]: VmBus and storage devices. """ if os.path.exists(STORAGE_DEVICE_PATH): for vmbus in os.listdir(STORAGE_DEVICE_PATH): deviceid = read_file(filepath=os.path.join(STORAGE_DEVICE_PATH, vmbus, "device_id")) guid = deviceid.strip('{}\n') yield vmbus, guid @staticmethod def search_for_resource_disk(gen1_device_prefix, gen2_device_id): """ Search the filesystem for a device by ID or prefix. Args: gen1_device_prefix (str): Gen1 resource disk prefix. gen2_device_id (str): Gen2 resource device ID. Returns: str: The found device. 
""" device = None # We have to try device IDs for both Gen1 and Gen2 VMs. #ResourceDiskUtil.logger.log('Searching gen1 prefix {0} or gen2 {1}'.format(gen1_device_prefix, gen2_device_id),True) try: # pylint: disable=R1702 for vmbus, guid in ResourceDiskUtil._enumerate_device_id(): if guid.startswith(gen1_device_prefix) or guid == gen2_device_id: for root, dirs, files in os.walk(STORAGE_DEVICE_PATH + vmbus): # pylint: disable=W0612 root_path_parts = root.split('/') # For Gen1 VMs we only have to check for the block dir in the # current device. But for Gen2 VMs all of the disks (sda, sdb, # sr0) are presented in this device on the same SCSI controller. # Because of that we need to also read the LUN. It will be: # 0 - OS disk # 1 - Resource disk # 2 - CDROM if root_path_parts[-1] == 'block' and ( # pylint: disable=R1705 guid != gen2_device_id or root_path_parts[-2].split(':')[-1] == '1'): device = dirs[0] return device else: # older distros for d in dirs: # pylint: disable=C0103 if ':' in d and "block" == d.split(':')[0]: device = d.split(':')[1] return device except (OSError, IOError) as exc: err_msg='Error getting device for %s or %s: %s , Stack Trace: %s' % (gen1_device_prefix, gen2_device_id, str(exc),traceback.format_exc()) return None def device_for_ide_port(self): """ Return device name attached to ide port 'n'. 
gen1 device prefix is the prefix of the file name in which the resource disk partition is stored eg sdb gen1 is for new distros In old distros the directory name which contains resource disk partition is assigned to gen2 device id """ g0 = "00000000" gen1_device_prefix = '{0}-0001'.format(g0) self.logger.log('Searching gen1 prefix {0} or gen2 {1}'.format(gen1_device_prefix, GEN2_DEVICE_ID),True) device = self.search_for_resource_disk( gen1_device_prefix=gen1_device_prefix, gen2_device_id=GEN2_DEVICE_ID ) self.logger.log('Found device: {0}'.format(device),True) return device def get_mount_point(self, mountlist, device): """ Example of mountlist: /dev/sda1 on / type ext4 (rw) proc on /proc type proc (rw) sysfs on /sys type sysfs (rw) devpts on /dev/pts type devpts (rw,gid=5,mode=620) tmpfs on /dev/shm type tmpfs (rw,rootcontext="system_u:object_r:tmpfs_t:s0") none on /proc/sys/fs/binfmt_misc type binfmt_misc (rw) /dev/sdb1 on /mnt/resource type ext4 (rw) """ if (mountlist and device): for entry in mountlist.split('\n'): if(re.search(device, entry)): tokens =Utils.HandlerUtil.HandlerUtility.split(self.logger, entry) #Return the 3rd column of this line return tokens[2] if len(tokens) > 2 else None return None def get_resource_disk_mount_point(self,option=1): # pylint: disable=R0912,R0914 try: """ if option = 0 then partition will be returned eg sdb1 if option = 1 then mount point will be returned eg /mnt/resource """ device = self.device_for_ide_port() if device is None: self.logger.log('unable to detect disk topology',True,'Error') if device is not None: partition = "{0}{1}".format(device,"1") #assuming only one resourde disk partition else: partition="" self.logger.log("Resource disk partition: {0} ".format(partition),True) if(option==0): return partition #p = Popen("mount", stdout=subprocess.PIPE, stderr=subprocess.PIPE) #mount_list, err = p.communicate() mount_list = self.disk_util.get_mount_output() if(mount_list is not None): mount_point = 
self.get_mount_point(mountlist = mount_list, device = device) self.logger.log("Resource disk {0} is mounted {1}".format(partition,mount_point),True) if mount_point: return mount_point return None except Exception as e: err_msg='Cannot get Resource disk partition, Exception %s, stack trace: %s' % (str(e), traceback.format_exc()) self.logger.log(err_msg, True, 'Error') return None ================================================ FILE: VMBackup/main/Utils/SizeCalculation.py ================================================ import os import os.path import sys try: import imp as imp except ImportError: import importlib as imp try: import ConfigParser as ConfigParsers except ImportError: import configparser as ConfigParsers import base64 import json import tempfile import time from Utils.DiskUtil import DiskUtil from Utils.ResourceDiskUtil import ResourceDiskUtil import Utils.HandlerUtil import traceback import subprocess import shlex from common import CommonVariables class SizeCalculation(object): def __init__(self,patching, hutil, logger,para_parser): self.patching = patching self.logger = logger self.hutil = hutil self.includedLunList = [] self.file_systems_info = [] self.non_physical_file_systems = ['fuse', 'nfs', 'cifs', 'overlay', 'aufs', 'lustre', 'secfs2', 'zfs', 'btrfs', 'iso'] self.known_fs = ['ext3', 'ext4', 'jfs', 'xfs', 'reiserfs', 'devtmpfs', 'tmpfs', 'rootfs', 'fuse', 'nfs', 'cifs', 'overlay', 'aufs', 'lustre', 'secfs2', 'zfs', 'btrfs', 'iso'] self.isOnlyOSDiskBackupEnabled = False try: if(para_parser.customSettings != None and para_parser.customSettings != ''): self.logger.log('customSettings : ' + str(para_parser.customSettings)) customSettings = json.loads(para_parser.customSettings) if("isOnlyOSDiskBackupEnabled" in customSettings): self.isOnlyOSDiskBackupEnabled = customSettings["isOnlyOSDiskBackupEnabled"] if(self.isOnlyOSDiskBackupEnabled == True): Utils.HandlerUtil.HandlerUtility.add_to_telemetery_data("billingType","os disk") else: 
class SizeCalculation(object):
    # Computes which disks/partitions count toward backup billing, based on
    # the LUN include-list and custom settings supplied by the backup service.
    def __init__(self,patching, hutil, logger,para_parser):
        self.patching = patching
        self.logger = logger
        self.hutil = hutil
        self.includedLunList = []
        self.file_systems_info = []
        self.non_physical_file_systems = ['fuse', 'nfs', 'cifs', 'overlay', 'aufs', 'lustre', 'secfs2', 'zfs', 'btrfs', 'iso']
        self.known_fs = ['ext3', 'ext4', 'jfs', 'xfs', 'reiserfs', 'devtmpfs', 'tmpfs', 'rootfs', 'fuse', 'nfs', 'cifs', 'overlay', 'aufs', 'lustre', 'secfs2', 'zfs', 'btrfs', 'iso']
        self.isOnlyOSDiskBackupEnabled = False
        try:
            if(para_parser.customSettings != None and para_parser.customSettings != ''):
                self.logger.log('customSettings : ' + str(para_parser.customSettings))
                customSettings = json.loads(para_parser.customSettings)
                if("isOnlyOSDiskBackupEnabled" in customSettings):
                    self.isOnlyOSDiskBackupEnabled = customSettings["isOnlyOSDiskBackupEnabled"]
                if(self.isOnlyOSDiskBackupEnabled == True):
                    Utils.HandlerUtil.HandlerUtility.add_to_telemetery_data("billingType","os disk")
                else:
                    Utils.HandlerUtil.HandlerUtility.add_to_telemetery_data("billingType","none")
            self.logger.log("isOnlyOSDiskBackupEnabled : {0}".format(str(self.isOnlyOSDiskBackupEnabled)))
        except Exception as e:
            errMsg = 'Failed to serialize customSettings with error: %s, stack trace: %s' % (str(e), traceback.format_exc())
            self.logger.log(errMsg, True, 'Error')
            self.isOnlyOSDiskBackupEnabled = False
        self.disksToBeIncluded = []
        self.root_devices = []
        self.root_mount_points = ['/' , '/boot/efi']
        self.devicesToInclude = [] #partitions to be included
        self.device_mount_points = []
        self.isAnyDiskExcluded = False
        self.LunListEmpty = False
        self.logicalVolume_to_bill = []
        self.sudo_off = 0 # run the commands using sudo
        if(para_parser.includedDisks != None and para_parser.includedDisks != '' and CommonVariables.isAnyDiskExcluded in para_parser.includedDisks.keys() and para_parser.includedDisks[CommonVariables.isAnyDiskExcluded] != None ):
            self.isAnyDiskExcluded = para_parser.includedDisks[CommonVariables.isAnyDiskExcluded]
            self.logger.log("isAnyDiskExcluded {0}".format(self.isAnyDiskExcluded))
        if( para_parser.includeLunList != None and para_parser.includeLunList != ''):
            self.includedLunList = para_parser.includeLunList
            self.logger.log("includedLunList {0}".format(self.includedLunList))
        if(self.includedLunList == None or len(self.includedLunList) == 0):
            self.LunListEmpty = True
            self.logger.log("As the LunList is empty including all disks")

    def get_lsscsi_list(self):
        # Populate self.lsscsi_list with the lines of `lsscsi` output
        # (used to map LUN numbers to device names); empty list on failure.
        command = "lsscsi"
        if (self.sudo_off == 0):
            command = "sudo " + command
        try:
            self.logger.log("executing command {0}".format(command))
            self.lsscsi_list = (os.popen(command).read()).splitlines()
        except Exception as e:
            error_msg = "Failed to execute the command \"%s\" because of error %s , stack trace: %s" % (command, str(e), traceback.format_exc())
            self.logger.log(error_msg, True ,'Error')
            self.lsscsi_list = []

    def get_lsblk_list(self):
        # Populate self.output_lsblk with `lsblk -n --list` name/mountpoint
        # pairs; empty list on failure.
        try:
            self.output_lsblk = os.popen("lsblk -n --list --output name,mountpoint").read().strip().splitlines()
        except Exception as e:
            error_msg = "Failed to execute the command lsblk -n --list --output name,mountpoint because of error %s , stack trace: %s" % (str(e), traceback.format_exc())
            self.logger.log(error_msg, True ,'Error')
            self.output_lsblk = []

    def get_pvs_list(self):
        # Populate self.pvs_output with `pvs` output minus the header row;
        # empty list on failure.
        try:
            command = "pvs"
            if (self.sudo_off == 0):
                command = "sudo " + command
            self.pvs_output = os.popen(command).read().strip().split("\n")
            self.pvs_output = self.pvs_output[1:]
        except Exception as e:
            error_msg = "Failed to execute the command \"%s\" because of error %s , stack trace: %s" % (command, str(e), traceback.format_exc())
            self.logger.log(error_msg, True ,'Error')
            self.pvs_output = []

    def get_loop_devices(self):
        # Return the loop-device entries from the mounted file-system list.
        # NOTE(review): also initializes the module-global 'disk_util' used by
        # device_list_for_billing -- call order matters.
        global disk_util
        disk_util = DiskUtil.get_instance(patching = self.patching,logger = self.logger)
        if len(self.file_systems_info) == 0 :
            self.file_systems_info = disk_util.get_mount_file_systems()
        self.logger.log("file_systems list : ",True)
        self.logger.log(str(self.file_systems_info),True)
        disk_loop_devices_file_systems = []
        for file_system_info in self.file_systems_info:
            if 'loop' in file_system_info[0]:
                disk_loop_devices_file_systems.append(file_system_info[0])
        return disk_loop_devices_file_systems

    def disk_list_for_billing(self):
        # From lsscsi output, pick the disks whose LUN number appears in the
        # include-list (OS disk is forced to LUN -1 first); marks size
        # calculation failed when lsscsi produced nothing.
        if(len(self.lsscsi_list) != 0):
            for item in self.lsscsi_list:
                idxOfColon = item.rindex(':',0,item.index(']'))# to get the index of last ':'
                idxOfColon += 1
                lunNumber = int(item[idxOfColon:item.index(']')])
                # item_split is the list of elements present in the one row of the cmd sudo lsscsi
                self.item_split = item.split()
                #storing the corresponding device name from the list
                device_name = self.item_split[len(self.item_split)-1]
                for device in self.root_devices :
                    if device_name in device :
                        lunNumber = -1 # Changing the Lun# of OS Disk to -1
                if lunNumber in self.includedLunList :
                    self.disksToBeIncluded.append(device_name)
                self.logger.log("LUN Number {0}, disk {1}".format(lunNumber,device_name))
            self.logger.log("Disks to be included {0}".format(self.disksToBeIncluded))
        else:
            self.size_calc_failed = True
            self.logger.log("There is some glitch in executing the command 'lsscsi' and therefore size calculation is marked as failed.")
be included {0}".format(self.disksToBeIncluded)) else: self.size_calc_failed = True self.logger.log("There is some glitch in executing the command 'lsscsi' and therefore size calculation is marked as failed.") def get_logicalVolumes_for_billing(self): try: self.pvs_dict = {} for pvs_item in self.pvs_output: pvs_item_split = pvs_item.strip().split() if(len(pvs_item_split) > 2): physicalVolume = pvs_item_split[0] logicalVolumeGroup = pvs_item_split[1] if(logicalVolumeGroup in self.pvs_dict.keys()): self.pvs_dict.get(logicalVolumeGroup).append(physicalVolume) else: self.pvs_dict[logicalVolumeGroup] = [] self.pvs_dict.get(logicalVolumeGroup).append(physicalVolume) self.logger.log("The pvs_dict contains {0}".format(str(self.pvs_dict))) except Exception as e: errMsg = 'Failed to serialize pvs_output with error: %s, stack trace: %s' % (str(e), traceback.format_exc()) self.logger.log(errMsg, True, 'Error') for lvg in self.pvs_dict.keys(): count = 0 for disk in self.disksToBeIncluded: for pv in self.pvs_dict[lvg]: if(disk in pv): count = count+1 if(count == len(self.pvs_dict[lvg])): lvg = "/dev/mapper/" + lvg self.logicalVolume_to_bill.append(lvg) else: self.logger.log("Partial snapshotting for the logical volume group {0} can't be taken".format(lvg)) self.logger.log("the lvm list to bill are {0}".format(self.logicalVolume_to_bill)) def device_list_for_billing(self): self.logger.log("In device_list_for_billing",True) devices_to_bill = [] #list to store device names to be billed device_items = disk_util.get_device_items(None) for device_item in device_items : if str(device_item.name).startswith("sd"): devices_to_bill.append("/dev/{0}".format(str(device_item.name))) else: self.logger.log("Not adding device {0} as it does not start with sd".format(str(device_item.name))) self.logger.log("Initial billing items {0}".format(devices_to_bill)) ''' Sample output for file_systems_info [('sysfs', 'sysfs', '/sys'), ('proc', 'proc', '/proc'), ('udev', 'devtmpfs', '/dev'),..] 
Since root devices are at mount points '/' and '/boot/efi' we use file_system_info to find the root_devices based on the mount points. ''' # check if user off the sudo usage in commands self.sudo_off = self.hutil.get_intvalue_from_configfile("sudo_off", self.sudo_off) self.logger.log("sudo flag is {0}".format(self.sudo_off)) # The command lsscsi is used for mapping the LUN numbers to the disk_names self.get_lsscsi_list() #populates self.lsscsi_list self.get_lsblk_list() #populates self.lsblk_list self.get_pvs_list()#populates pvs list for file_system in self.file_systems_info: if(file_system[2] in self.root_mount_points): self.root_devices.append(file_system[0]) self.logger.log("root_devices {0}".format(str(self.root_devices))) self.logger.log("lsscsi_list {0}".format(self.lsscsi_list)) ''' Sample output of the lsscsi command [1:0:0:15] disk Msft Virtual Disk 1.0 /dev/sda [1:0:0:18] disk Msft Virtual Disk 1.0 /dev/sdc ''' self.disk_list_for_billing() self.get_logicalVolumes_for_billing() self.logger.log("lsblk o/p {0}".format(self.output_lsblk)) self.logger.log("lvm {0}".format(self.logicalVolume_to_bill)) ''' NAME MOUNTPOINT sda sda1 /boot/efi sda2 /boot sda3 sda4 rootvg-tmplv /tmp rootvg-usrlv /usr rootvg-optlv /opt rootvg-homelv /home rootvg-varlv /var rootvg-rootlv / sdb sdb1 /mnt/resource sdc sdc1 /mnt/sdc1 sdc2 /mnt/sdc2 sdc3 sde sde1 /mnt/sde1 sdf sdg ''' if(len(self.output_lsblk) == 0): self.size_calc_failed = True self.logger.log("There is some glitch in executing the command 'lsblk -n --list --output name,mountpoint' and therefore size calculation is marked as failed.") for item in self.output_lsblk: item_split = item.split() if(len(item_split)==2): device = item_split[0] mount_point = item_split[1] else: mount_point = "" device = "" if device != '' and mount_point != '': device = '/dev/' + device for disk in self.disksToBeIncluded : if disk in device and device not in self.devicesToInclude: self.devicesToInclude.append(device) 
# NOTE(review): indentation below is reconstructed from a whitespace-mangled
# source; statement grouping inside nested blocks should be verified against
# the original file. This chunk opens mid-way through device_list_for_billing().
                        self.device_mount_points.append(mount_point)
                        break
        self.logger.log("devices_to_bill: {0}".format(str(self.devicesToInclude)),True)
        self.logger.log("The mountpoints of devices to bill: {0}".format(str(self.device_mount_points)), True)
        self.logger.log("exiting device_list_for_billing",True)
        return devices_to_bill

    def get_total_used_size(self):
        """Compute the used disk space (in bytes) that should be billed.

        Parses `df -k` (or `df -kl` when the onlyLocalFilesystems config flag
        is set), classifies every mounted device (network share, temporary /
        resource disk, RAM disk, loop device, gluster snapshot mount, LVM
        volume, plain disk) and sums only the space on devices selected for
        billing according to the LUN include list.

        Returns:
            tuple (total_used_bytes, size_calc_failed) -- size_calc_failed is
            True when df output could not be parsed or a failure was detected;
            in that case total may be 0.
        """
        try:
            self.size_calc_failed = False
            onlyLocalFilesystems = self.hutil.get_strvalue_from_configfile(CommonVariables.onlyLocalFilesystems, "False")
            # df command gives the information of all the devices which have mount points
            if onlyLocalFilesystems in ['True', 'true']:
                df = subprocess.Popen(["df" , "-kl"], stdout=subprocess.PIPE)
            else:
                df = subprocess.Popen(["df" , "-k"], stdout=subprocess.PIPE)
            '''
            Sample output of the df command
            Filesystem Type 1K-blocks Used Avail Use% Mounted on
            /dev/sda2 xfs 52155392 3487652 48667740 7% /
            devtmpfs devtmpfs 7170976 0 7170976 0% /dev
            tmpfs tmpfs 7180624 0 7180624 0% /dev/shm
            /dev/sda1 ext4 245679 151545 76931 67% /boot
            /dev/sdb1 ext4 28767204 2142240 25140628 8% /mnt/resource
            /dev/mapper/mygroup-thinv1 xfs 1041644 33520 1008124 4% /bricks/brick1
            /dev/mapper/mygroup-85197c258a54493da7880206251f5e37_0 xfs ... /run/gluster/snaps/.../brick2
            //Centos72test/cifs_test cifs 52155392 4884620 47270772 10% /mnt/cifs_test2
            '''
            output = ""
            process_wait_time = 300
            # poll for up to ~300 seconds for df to finish
            while(df is not None and process_wait_time >0 and df.poll() is None):
                time.sleep(1)
                process_wait_time -= 1
            self.logger.log("df command executed for process wait time value" + str(process_wait_time), True)
            if(df is not None and df.poll() is not None):
                self.logger.log("df return code "+str(df.returncode), True)
                output = df.stdout.read().decode()
            if sys.version_info > (3,):
                try:
                    output = str(output, encoding='utf-8', errors="backslashreplace")
                except:
                    output = str(output)
            else:
                output = str(output)
            output = output.strip().split("\n")
            self.logger.log("output of df : {0}".format(str(output)),True)
            disk_loop_devices_file_systems = self.get_loop_devices()
            self.logger.log("outside loop device", True)
            # per-category accumulators; all values are in KB until the final return
            total_used = 0
            total_used_network_shares = 0
            total_used_gluster = 0
            total_used_loop_device=0
            total_used_temporary_disks = 0
            total_used_ram_disks = 0
            total_used_unknown_fs = 0
            actual_temp_disk_used = 0
            total_sd_size = 0
            network_fs_types = []
            unknown_fs_types = []
            excluded_disks_used = 0
            totalSpaceUsed = 0
            device_list = []
            if len(self.file_systems_info) == 0 :
                self.file_systems_info = disk_util.get_mount_file_systems()
            output_length = len(output)
            index = 1
            self.resource_disk = ResourceDiskUtil(patching = self.patching, logger = self.logger)
            resource_disk_device = self.resource_disk.get_resource_disk_mount_point(0)
            self.logger.log("resource_disk_device: {0}".format(resource_disk_device),True)
            resource_disk_device = "/dev/{0}".format(resource_disk_device)
            self.logger.log("ResourceDisk is excluded in billing as it represents the Actual Temporary disk")
            if(self.LunListEmpty != True and self.isAnyDiskExcluded == True):
                device_list = self.device_list_for_billing()
            #new logic: calculate the disk size for billing
            while index < output_length:
                if(len(Utils.HandlerUtil.HandlerUtility.split(self.logger, output[index])) < 6 ):
                    #when a row is divided in 2 lines
                    index = index+1
                    if(index < output_length and len(Utils.HandlerUtil.HandlerUtility.split(self.logger, output[index-1])) + len(Utils.HandlerUtil.HandlerUtility.split(self.logger, output[index])) == 6):
                        output[index] = output[index-1] + output[index]
                    else:
                        self.logger.log("Output of df command is not in desired format",True)
                        total_used = 0
                        self.size_calc_failed = True
                        break
                device, size, used, available, percent, mountpoint =Utils.HandlerUtil.HandlerUtility.split(self.logger, output[index])
                fstype = ''
                isNetworkFs = False
                isKnownFs = False
                if int(used) < 0 :
                    self.logger.log("The used space is negative, so marking the size computation as failed and returning zero")
                    self.size_calc_failed = True
                    return 0,self.size_calc_failed
                # resolve the filesystem type from the mount table
                for file_system_info in self.file_systems_info:
                    if device == file_system_info[0] and mountpoint == file_system_info[2]:
                        fstype = file_system_info[1]
                self.logger.log("index :{0} Device name : {1} fstype : {2} size : {3} used space in KB : {4} available space : {5} mountpoint : {6}".format(index,device,fstype,size,used,available,mountpoint),True)
                for nonPhysicaFsType in self.non_physical_file_systems:
                    if nonPhysicaFsType in fstype.lower():
                        isNetworkFs = True
                        break
                for knownFs in self.known_fs:
                    if knownFs in fstype.lower():
                        isKnownFs = True
                        break
                if device == resource_disk_device and self.isOnlyOSDiskBackupEnabled == False :
                    # adding log to check difference in billing of temp disk
                    self.logger.log("Actual temporary disk, Device name : {0} used space in KB : {1} fstype : {2}".format(device,used,fstype),True)
                    actual_temp_disk_used= int(used)
                if device in device_list and device != resource_disk_device :
                    self.logger.log("Adding sd* partition, Device name : {0} used space in KB : {1} fstype : {2}".format(device,used,fstype),True)
                    total_sd_size = total_sd_size + int(used) #calcutale total sd* size just skip temp disk
                if not (isKnownFs or fstype == '' or fstype == None):
                    unknown_fs_types.append(fstype)
                if isNetworkFs :
                    if fstype not in network_fs_types :
                        network_fs_types.append(fstype)
                    self.logger.log("Not Adding network-drive, Device name : {0} used space in KB : {1} fstype : {2}".format(device,used,fstype),True)
                    total_used_network_shares = total_used_network_shares + int(used)
                elif device == "/dev/sdb1" and self.isOnlyOSDiskBackupEnabled == False :
                    # in some cases root is mounted on /dev/sdb1
                    self.logger.log("Not Adding temporary disk, Device name : {0} used space in KB : {1} fstype : {2}".format(device,used,fstype),True)
                    total_used_temporary_disks = total_used_temporary_disks + int(used)
                elif "tmpfs" in fstype.lower() or "devtmpfs" in fstype.lower() or "ramdiskfs" in fstype.lower() or "rootfs" in fstype.lower():
                    self.logger.log("Not Adding RAM disks, Device name : {0} used space in KB : {1} fstype : {2}".format(device,used,fstype),True)
                    total_used_ram_disks = total_used_ram_disks + int(used)
                elif 'loop' in device and device not in disk_loop_devices_file_systems:
                    self.logger.log("Not Adding Loop Device , Device name : {0} used space in KB : {1} fstype : {2}".format(device,used,fstype),True)
                    total_used_loop_device = total_used_loop_device + int(used)
                elif (mountpoint.startswith('/run/gluster/snaps/')):
                    self.logger.log("Not Adding Gluster Device , Device name : {0} used space in KB : {1} mount point : {2}".format(device,used,mountpoint),True)
                    total_used_gluster = total_used_gluster + int(used)
                elif device.startswith( '\\\\' ) or device.startswith( '//' ):
                    self.logger.log("Not Adding network-drive as it starts with slahes, Device name : {0} used space in KB : {1} fstype : {2}".format(device,used,fstype),True)
                    total_used_network_shares = total_used_network_shares + int(used)
                else:
                    #Only when OS disk is included
                    if(self.isOnlyOSDiskBackupEnabled == True):
                        if(mountpoint == '/'):
                            total_used = total_used + int(used)
                            self.logger.log("Adding only root device to size calculation. Device name : {0} used space in KB : {1} mount point : {2} fstype : {3}".format(device,used,mountpoint,fstype),True)
                            self.logger.log("Total Used Space: {0}".format(total_used),True)
                    #Handling a case where LunList is empty for UnmanagedVM's and failures if occurred( as we will billing for all the non resource disks)
                    elif( (self.size_calc_failed == True or self.LunListEmpty == True) and device != resource_disk_device):
                        self.logger.log("Adding Device name : {0} for billing used space in KB : {1} mount point : {2} fstype : {3}".format(device,used,mountpoint,fstype),True)
                        total_used = total_used + int(used) #return in KB
                    #LunList is empty but the device is an actual temporary disk so excluding it
                    elif( (self.size_calc_failed == True or self.LunListEmpty == True) and device == resource_disk_device):
                        self.logger.log("Device {0} is not included for billing as it is a resource disk, used space in KB : {1} mount point : {2} fstype :{3}".format(device,used,mountpoint,fstype),True)
                        excluded_disks_used = excluded_disks_used + int(used)
                    #Including only the disks which are asked to include (Here LunList can't be empty this case is handled at the CRP end)
                    else:
                        if self.isAnyDiskExcluded == False and device != resource_disk_device:
                            #No disk has been excluded So can include every non resource disk
                            self.logger.log("Adding Device name : {0} for billing used space in KB : {1} mount point : {2} fstype : {3}".format(device,used,mountpoint,fstype),True)
                            total_used = total_used + int(used) #return in KB
                        elif self.isAnyDiskExcluded == False and device == resource_disk_device:
                            #excluding resource disk even in the case where all disks are included as it is the actual temporary disk
                            self.logger.log("Device {0} is not included for billing as it is a resource disk, used space in KB : {1} mount point : {2} fstype : {3}".format(device,used,mountpoint,fstype),True)
                            excluded_disks_used = excluded_disks_used + int(used)
                        elif mountpoint in self.device_mount_points and device != resource_disk_device:
                            self.logger.log("Adding Device name : {0} for billing used space in KB : {1} mount point : {2} fstype : {3}".format(device,used,mountpoint,fstype),True)
                            total_used = total_used + int(used) #return in KB
                        elif device != resource_disk_device and -1 in self.includedLunList:
                            # LUN -1 means the OS disk is in the include list
                            if mountpoint in self.root_mount_points :
                                self.logger.log("Adding Device name : {0} for billing used space in KB : {1} mount point : {2} fstype : {3}".format(device,used,mountpoint,fstype),True)
                                total_used = total_used + int(used) #return in KB
                            else:
                                #check for logicalVolumes
                                templgv = device.split("-")
                                if(len(templgv) > 1 and templgv[0] in self.logicalVolume_to_bill):
                                    self.logger.log("Adding Device name : {0} for billing used space in KB : {1} mount point : {2} fstype : {3}".format(device,used,mountpoint,fstype),True)
                                    total_used = total_used + int(used) #return in KB
                                else:
                                    self.logger.log("Device {0} is not included for billing as it is not part of the disks to be included, used space in KB : {1} mount point : {2} fstype : {3}".format(device,used,mountpoint,fstype),True)
                                    excluded_disks_used = excluded_disks_used + int(used)
                        else:
                            # check for logicalVolumes even if os disk is not included
                            templgv = device.split("-")
                            if(len(templgv) > 1 and templgv[0] in self.logicalVolume_to_bill):
                                self.logger.log("Adding Device name : {0} for billing used space in KB : {1} mount point : {2} fstype : {3}".format(device,used,mountpoint,fstype),True)
                                total_used = total_used + int(used) #return in KB
                            else:
                                self.logger.log("Device {0} is not included for billing as it is not part of the disks to be included, used space in KB : {1} mount point : {2} fstype : {3}".format(device,used,mountpoint,fstype),True)
                                excluded_disks_used = excluded_disks_used + int(used)
                if not (isKnownFs or fstype == '' or fstype == None):
                    total_used_unknown_fs = total_used_unknown_fs + int(used)
                index = index + 1
            # emit telemetry for each non-empty exclusion category
            if not len(unknown_fs_types) == 0:
                Utils.HandlerUtil.HandlerUtility.add_to_telemetery_data("unknownFSTypeInDf",str(unknown_fs_types))
                Utils.HandlerUtil.HandlerUtility.add_to_telemetery_data("totalUsedunknownFS",str(total_used_unknown_fs))
                self.logger.log("Total used space in Bytes of unknown FSTypes : {0}".format(total_used_unknown_fs * 1024),True)
            if total_used_temporary_disks != actual_temp_disk_used :
                self.logger.log("Billing differenct because of incorrect temp disk: {0}".format(str(total_used_temporary_disks - actual_temp_disk_used)))
            if not len(network_fs_types) == 0:
                Utils.HandlerUtil.HandlerUtility.add_to_telemetery_data("networkFSTypeInDf",str(network_fs_types))
                Utils.HandlerUtil.HandlerUtility.add_to_telemetery_data("totalUsedNetworkShare",str(total_used_network_shares))
                self.logger.log("Total used space in Bytes of network shares : {0}".format(total_used_network_shares * 1024),True)
            if total_used_gluster !=0 :
                Utils.HandlerUtil.HandlerUtility.add_to_telemetery_data("glusterFSSize",str(total_used_gluster))
            if total_used_temporary_disks !=0:
                Utils.HandlerUtil.HandlerUtility.add_to_telemetery_data("tempDisksSize",str(total_used_temporary_disks))
            if total_used_ram_disks != 0:
                Utils.HandlerUtil.HandlerUtility.add_to_telemetery_data("ramDisksSize",str(total_used_ram_disks))
            if total_used_loop_device != 0 :
                Utils.HandlerUtil.HandlerUtility.add_to_telemetery_data("loopDevicesSize",str(total_used_loop_device))
            totalSpaceUsed = total_used + excluded_disks_used
            self.logger.log("TotalUsedSpace ( both included and excluded disks ) in Bytes : {0} , TotalUsedSpaceAfterExcludeLUN in Bytes : {1} , TotalLUNExcludedUsedSpace in Bytes : {2} ".format(totalSpaceUsed *1024 , total_used * 1024 , excluded_disks_used *1024 ),True)
            if total_sd_size != 0 :
                Utils.HandlerUtil.HandlerUtility.add_to_telemetery_data("totalsdSize",str(total_sd_size))
                self.logger.log("Total sd* used space in Bytes : {0}".format(total_sd_size * 1024),True)
            self.logger.log("SizeComputationFailedFlag {0}".format(self.size_calc_failed))
            return total_used * 1024,self.size_calc_failed #Converting into Bytes
        except Exception as e:
            # (this statement continues in the next chunk)
            errMsg =
# NOTE(review): this chunk begins mid-way through the except-handler of
# get_total_used_size(); the first lines below complete the 'errMsg =' that
# starts in the previous chunk.
'Unable to fetch total used space with error: %s, stack trace: %s' % (str(e), traceback.format_exc())
            self.logger.log(errMsg,True)
            self.size_calc_failed = True
            return 0,self.size_calc_failed

# ================================================
# FILE: VMBackup/main/Utils/Status.py
# ================================================
import json

class TopLevelStatus:
    """Top-level extension status document: version, UTC timestamp, and status payload."""
    def __init__(self, version, timestampUTC, status):
        self.version = version
        self.timestampUTC = timestampUTC
        self.status = status

    def convertToDictionary(self):
        """Return a plain dict suitable for JSON serialization."""
        return dict(version = self.version, timestampUTC = self.timestampUTC, status = self.status)

class StatusObj:
    """Status body for a backup operation, including telemetry, storage and snapshot details."""
    def __init__(self, name, operation, status, substatus, code, formattedMessage, telemetrydata, storageDetails, uniqueMachineId, taskId, commandStartTimeUTCTicks, snapshotInfo, vmHealthInfo):
        self.name = name
        self.operation = operation
        self.status = status
        self.substatus = substatus
        self.code = code
        self.formattedMessage = formattedMessage
        self.telemetryData = telemetrydata
        self.storageDetails = storageDetails
        self.uniqueMachineId = uniqueMachineId
        self.taskId = taskId
        self.commandStartTimeUTCTicks = commandStartTimeUTCTicks
        self.snapshotInfo = snapshotInfo
        self.vmHealthInfo = vmHealthInfo

    def convertToDictionary(self):
        """Return a plain dict suitable for JSON serialization."""
        return dict(name = self.name, operation = self.operation, status = self.status, substatus = self.substatus, code = self.code, taskId = self.taskId, formattedMessage = self.formattedMessage, storageDetails = self.storageDetails, commandStartTimeUTCTicks = self.commandStartTimeUTCTicks, telemetryData = self.telemetryData, uniqueMachineId = self.uniqueMachineId, snapshotInfo = self.snapshotInfo, vmHealthInfo = self.vmHealthInfo)

class VmHealthInfoObj:
    """VM health state (ExtVmHealthStateEnum value) plus a status code."""
    def __init__(self, vmHealthState, vmHealthStatusCode):
        self.vmHealthState = vmHealthState
        self.vmHealthStatusCode = vmHealthStatusCode

    def convertToDictionary(self):
        """Return a plain dict suitable for JSON serialization."""
        return dict(vmHealthState = self.vmHealthState,vmHealthStatusCode = self.vmHealthStatusCode)

class SubstatusObj:
    """A single substatus entry within the status body."""
    def __init__(self, code, name, status, formattedMessage):
        self.code = code
        self.name = name
        self.status = status
        self.formattedMessage = formattedMessage

    def convertToDictionary(self):
        """Return a plain dict suitable for JSON serialization."""
        return dict(code = self.code, name = self.name, status = self.status, formattedMessage = self.formattedMessage)

class StorageDetails:
    """Disk-usage summary reported back for billing / size computation."""
    def __init__(self, partitionCount, totalUsedSizeInBytes, isStoragespacePresent, isSizeComputationFailed):
        self.partitionCount = partitionCount
        self.totalUsedSizeInBytes = totalUsedSizeInBytes
        self.isStoragespacePresent = isStoragespacePresent
        self.isSizeComputationFailed = isSizeComputationFailed

    def convertToDictionary(self):
        """Return a plain dict suitable for JSON serialization."""
        return dict(partitionCount = self.partitionCount, totalUsedSizeInBytes = self.totalUsedSizeInBytes, isStoragespacePresent = self.isStoragespacePresent, isSizeComputationFailed = self.isSizeComputationFailed)

class SnapshotInfoObj:
    """Per-disk snapshot result (covers both XStore blob and DirectDrive disks)."""
    def __init__(self, isSuccessful, snapshotUri, errorMessage, blobUri, directDriveSnapshotIdentifier = None):
        self.isSuccessful = isSuccessful
        self.snapshotUri = snapshotUri # snapshotUri is populated only for XStore disks (will be None for DD disks)
        self.errorMessage = errorMessage
        self.blobUri = blobUri # blobUri is populated for both XStore and DD disks (this is base blobUri, NOT snapshotUri)
        self.directDriveSnapshotIdentifier = directDriveSnapshotIdentifier # This is populated only for DD disks (will be None for XStore disks)

    def convertToDictionary(self):
        """Return a plain dict suitable for JSON serialization."""
        return dict(isSuccessful = self.isSuccessful, snapshotUri = self.snapshotUri, errorMessage = self.errorMessage, blobUri = self.blobUri, directDriveSnapshotIdentifier = self.directDriveSnapshotIdentifier)

class DirectDriveSnapshotIdentifier:
    """Identifier of a DirectDrive snapshot (creation time, id, access token)."""
    def __init__(self, creationTime, id, token, instantAccessDurationMinutes = None):
        self.creationTime = creationTime
        self.id = id
        self.token = token
        self.instantAccessDurationMinutes = instantAccessDurationMinutes # This is populated for DD disk with Instant Access snapshot

    def convertToDictionary(self):
        """Return a plain dict suitable for JSON serialization."""
        return dict(creationTime = self.creationTime, id = self.id, token = self.token, instantAccessDurationMinutes = self.instantAccessDurationMinutes)

class CreationTime:
    """Snapshot creation time as a DateTime plus an offset in minutes."""
    def __init__(self, DateTime, OffsetMinutes):
        self.DateTime = DateTime
        self.OffsetMinutes = OffsetMinutes

    def convertToDictionary(self):
        """Return a plain dict suitable for JSON serialization."""
        return dict(DateTime = self.DateTime, OffsetMinutes = self.OffsetMinutes)

class FormattedMessage:
    """Localized status message (language code + text)."""
    def __init__(self, lang, message):
        self.lang = lang
        self.message = message

class ExtVmHealthStateEnum():
    # health states reported via VmHealthInfoObj
    green = 0
    yellow = 128
    red = 256

class SnapshotConsistencyType():
    # consistency level achieved by a snapshot
    none = 0
    fileSystemConsistent = 1
    applicationConsistent = 2
    crashConsistent = 3

class ExtensionResponse:
    """Response returned by the extension: message, consistency level, job message, failure flag."""
    def __init__(self, messageStr, snapshotConsistency, jobMessage, failure_flag):
        self.messageStr = messageStr
        self.snapshotConsistency = snapshotConsistency
        self.jobMessage = jobMessage
        self.failure_flag = failure_flag

    def convertToDictionary(self):
        """Return a plain dict for serialization.

        Note: failure_flag is serialized under the key 'isSizeComputationFailed'.
        """
        return dict(messageStr = self.messageStr, snapshotConsistency = self.snapshotConsistency, jobMessage = self.jobMessage, isSizeComputationFailed = self.failure_flag)

# ================================================
# FILE: VMBackup/main/Utils/StringHelper.py
# ================================================
import datetime

class StringHelper:
    """Helper to format log lines with a UTC timestamp and severity prefix."""
    def resolve_string(self,severity_level, message):
        """Return a log line of the form '<utc-iso-timestamp> [<severity>]: <message>' ending in a newline.

        Returns None if formatting fails for any reason.
        """
        try:
            msg_body = datetime.datetime.utcnow().isoformat() + "\t" + "[" + severity_level + "]:\t"
            if message and message.strip():
                msg_body += message + " "
            msg_body += "\n"
            return msg_body
        except Exception as e:
            # NOTE(review): swallows every error and implicitly returns None;
            # callers must tolerate a None log line
            pass

# ================================================
# FILE: VMBackup/main/Utils/WAAgentUtil.py
# ================================================
# Wrapper module for waagent
#
# waagent is not written as a module. This wrapper module is created
# to use the waagent code as a module.
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

try:
    # For Python 3.5 and later, use importlib
    import importlib.util
    has_importlib_util = True
except ImportError:
    has_importlib_util = False

try:
    import imp as imp
    has_imp = True
except ImportError:
    has_imp = False

if not has_importlib_util and not has_imp:
    raise ImportError("Neither importlib.util nor imp module is available")

import os
import os.path

#
# The following code will search and load waagent code and expose
# it as a submodule of current module
#
def searchWAAgent():
    """Locate the bundled main/WaagentLib.py, else a 'waagent' file on PYTHONPATH; None if absent."""
    agentPath = os.path.join(os.getcwd(), "main/WaagentLib.py")
    if(os.path.isfile(agentPath)):
        return agentPath
    user_paths = os.environ['PYTHONPATH'].split(os.pathsep)
    for user_path in user_paths:
        agentPath = os.path.join(user_path, 'waagent')
        if(os.path.isfile(agentPath)):
            return agentPath
    return None

def searchWAAgentOld():
    """Locate the legacy system agent /usr/sbin/waagent, else a 'waagent' file on PYTHONPATH; None if absent."""
    agentPath = '/usr/sbin/waagent'
    if(os.path.isfile(agentPath)):
        return agentPath
    user_paths = os.environ['PYTHONPATH'].split(os.pathsep)
    for user_path in user_paths:
        agentPath = os.path.join(user_path, 'waagent')
        if(os.path.isfile(agentPath)):
            return agentPath
    return None

# pathUsed records which search succeeded: 1 = new/bundled path, 0 = legacy path
pathUsed = 1
try:
    agentPath = searchWAAgent()
    if agentPath is None:
        pathUsed = 0
        # Search for the old agent path if the new one is not found
        agentPath = searchWAAgentOld()
    if agentPath:
        if has_importlib_util:
            # For Python 3.5 and later, use importlib
            spec = importlib.util.spec_from_file_location('waagent', agentPath)
            waagent = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(waagent)
        elif has_imp:
            # For Python 3.4 and earlier, use imp module
            waagent = imp.load_source('waagent', agentPath)
        else:
            raise Exception("No suitable import mechanism available.")
    else:
        raise Exception("Can't load new or old waagent. Agent path not found.")
except Exception as e:
    raise Exception(str(e))

if not hasattr(waagent, "AddExtensionEvent"):
    """
    If AddExtensionEvent is not defined, provide a dummy impl.
    """
    def _AddExtensionEvent(*args, **kwargs):
        pass
    waagent.AddExtensionEvent = _AddExtensionEvent

if not hasattr(waagent, "WALAEventOperation"):
    class _WALAEventOperation:
        HeartBeat = "HeartBeat"
        Provision = "Provision"
        Install = "Install"
        # NOTE(review): attribute name typo ('UnIsntall') kept deliberately --
        # external callers may reference it by this name.
        UnIsntall = "UnInstall"
        Disable = "Disable"
        Enable = "Enable"
        Download = "Download"
        Upgrade = "Upgrade"
        Update = "Update"
    waagent.WALAEventOperation = _WALAEventOperation

__ExtensionName__ = None

def InitExtensionEventLog(name):
    """Record the extension name that AddExtensionEvent() should report under."""
    # BUG FIX: without this 'global' declaration the assignment created a
    # function-local variable, so the module-level __ExtensionName__ stayed
    # None and AddExtensionEvent() silently dropped every event.
    global __ExtensionName__
    __ExtensionName__ = name

def AddExtensionEvent(name=None, op=waagent.WALAEventOperation.Enable, isSuccess=False, message=None):
    """Forward a telemetry event to waagent; a no-op when no extension name is known.

    BUG FIX: the 'name' default was previously bound to __ExtensionName__ at
    function-definition time (always None); it is now resolved at call time so
    a name registered later via InitExtensionEventLog() is honoured.
    """
    if name is None:
        name = __ExtensionName__
    if name is not None:
        waagent.AddExtensionEvent(name=name, op=op, isSuccess=isSuccess, message=message)

def GetPathUsed():
    """Return 1 if the bundled/new agent path was loaded, 0 for the legacy path."""
    return pathUsed

# ================================================
# FILE: VMBackup/main/Utils/__init__.py
# ================================================
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
================================================ FILE: VMBackup/main/Utils/dhcpUtils.py ================================================ # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Requires Python 2.4+ and Openssl 1.0+ import sys import platform import os import subprocess import socket import array import struct from uuid import getnode as get_mac import json # Utils for sample.py def make_address(ip_address): ''' Returns the address for scheduled event endpoint from the IP address provided. ''' return 'http://' + ip_address + '/metadata/scheduledevents?api-version=2017-03-01' def check_ip_address(address, headers): ''' Checks whether the address of the scheduled event endpoint is valid. ''' try: response = get_scheduled_events(address, headers) return 'Events' in json.loads(response.read().decode('utf-8')) except (urllib.error.URLError, UnicodeDecodeError, json.decoder.JSONDecodeError) as _: return False def get_ip_address_reg_env(use_registry): ''' Get the IP address of scheduled event from registry or environment. Returns None if IP address is not provided or stored. ''' ip_address = None if use_registry and sys.platform == 'win32': import winreg as wreg # use ScheduledEventsIp in registry if available. 
# NOTE(review): indentation below is reconstructed from a whitespace-mangled
# source. This chunk opens mid-way through get_ip_address_reg_env(): the try
# block continues the 'if use_registry and sys.platform == win32' branch.
        try:
            key = wreg.OpenKey(wreg.HKEY_LOCAL_MACHINE, "Software\\ScheduledEvents")
            ip_address = wreg.QueryValueEx(key, 'ScheduledEventsIp')[0]
            key.Close()
        except FileNotFoundError:
            pass
    elif sys.platform == 'win32' or "linux" in sys.platform:
        # use SCHEDULEDEVENTSIP in system variables if available.
        ip_address = os.getenv('SCHEDULEDEVENTSIP')
    return ip_address

# Utils for discovery.py

def unpack(buf, offset, range):
    """
    Unpack bytes into python values.
    """
    # NOTE(review): parameter 'range' shadows the builtin; kept for
    # interface compatibility.
    result = 0
    for i in range:
        result = (result << 8) | str_to_ord(buf[offset + i])
    return result

def unpack_big_endian(buf, offset, length):
    """
    Unpack big endian bytes into python values.
    """
    return unpack(buf, offset, list(range(0, length)))

def hex_dump3(buf, offset, length):
    """
    Dump range of buf in formatted hex.
    """
    return ''.join(['%02X' % str_to_ord(char) for char in buf[offset:offset + length]])

def hex_dump2(buf):
    """
    Dump buf in formatted hex.
    """
    return hex_dump3(buf, 0, len(buf))

def hex_dump(buffer, size):
    """
    Return Hex formated dump of a 'buffer' of 'size'.
    """
    if size < 0:
        size = len(buffer)
    result = ""
    for i in range(0, size):
        if (i % 16) == 0:
            result += "%06X: " % i
        byte = buffer[i]
        if type(byte) == str:
            byte = ord(byte.decode('latin1'))
        result += "%02X " % byte
        if (i & 15) == 7:
            result += " "
        if ((i + 1) % 16) == 0 or (i + 1) == size:
            j = i
            # pad the hex column so the ASCII column aligns on short last rows
            while ((j + 1) % 16) != 0:
                result += " "
                if (j & 7) == 7:
                    result += " "
                j += 1
            result += " "
            for j in range(i - (i % 16), i + 1):
                byte = buffer[j]
                if type(byte) == str:
                    byte = str_to_ord(byte.decode('latin1'))
                k = '.'
                if is_printable(byte):
                    k = chr(byte)
                result += k
            if (i + 1) != size:
                result += "\n"
    return result

def str_to_ord(a):
    """
    Allows indexing into a string or an array of integers transparently.
    Generic utility function.
    """
    if type(a) == type(b'') or type(a) == type(u''):
        a = ord(a)
    return a

def compare_bytes(a, b, start, length):
    # Compare 'length' bytes of a and b beginning at 'start'; True when equal.
    for offset in range(start, start + length):
        if str_to_ord(a[offset]) != str_to_ord(b[offset]):
            return False
    return True

def int_to_ip4_addr(a):
    """
    Build DHCP request string.
    """
    return "%u.%u.%u.%u" % ((a >> 24) & 0xFF, (a >> 16) & 0xFF, (a >> 8) & 0xFF, (a) & 0xFF)

def hexstr_to_bytearray(a):
    """
    Return hex string packed into a binary struct.
    """
    b = b""
    for c in range(0, len(a) // 2):
        b += struct.pack("B", int(a[c * 2:c * 2 + 2], 16))
    return b

def is_printable(ch):
    """
    Return True if character is displayable.
    """
    return (is_in_range(ch, str_to_ord('A'), str_to_ord('Z')) or is_in_range(ch, str_to_ord('a'), str_to_ord('z')) or is_in_range(ch, str_to_ord('0'), str_to_ord('9')))

def is_in_range(a, low, high):
    """
    Return True if 'a' in 'low' <= a >= 'high'
    """
    return (a >= low and a <= high)

# Backport for Python < 2.7, which lacks subprocess.check_output.
if not hasattr(subprocess,'check_output'):
    def check_output(*popenargs, **kwargs):
        r"""Backport from subprocess module from python 2.7"""
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        output, unused_err = process.communicate()
        retcode = process.poll()
        if retcode:
            cmd = kwargs.get("args")
            if cmd is None:
                cmd = popenargs[0]
            raise subprocess.CalledProcessError(retcode, cmd, output=output)
        return output

    # Exception classes used by this module.
    class CalledProcessError(Exception):
        def __init__(self, returncode, cmd, output=None):
            self.returncode = returncode
            self.cmd = cmd
            self.output = output

        def __str__(self):
            return ("Command '{0}' returned non-zero exit status {1}").format(self.cmd, self.returncode)

    subprocess.check_output=check_output
    subprocess.CalledProcessError=CalledProcessError

""" Shell command util functions """

def run(cmd, chk_err=True):
    """
    Calls run_get_output on 'cmd', returning only the return code.
    If chk_err=True then errors will be reported in the log.
    If chk_err=False then errors will be suppressed from the log.
    """
    retcode,out=run_get_output(cmd,chk_err)
    return retcode

def run_get_output(cmd, chk_err=True, log_cmd=False):
    """
    Wrapper for subprocess.check_output.
    Execute 'cmd'. Returns return code and STDOUT, trapping expected exceptions.
    Reports exceptions to Error if chk_err parameter is True
    """
    try:
        output=subprocess.check_output(cmd,stderr=subprocess.STDOUT,shell=True)
        # NOTE(review): 'ustr' is not defined anywhere in this chunk -- it
        # presumably comes from the loaded waagent module; verify, otherwise
        # this raises NameError at runtime.
        output = ustr(output, encoding='utf-8', errors="backslashreplace")
    except subprocess.CalledProcessError as e :
        output = ustr(e.output, encoding='utf-8', errors="backslashreplace")
        return e.returncode, output
    return 0, output
# End shell command util functions

class DefaultOSUtil(object):
    """OS-level helpers for DHCP discovery (MAC address, firewall, interfaces)."""
    def __init__(self, logger):
        self.logger = logger

    def get_mac_in_bytes(self):
        """Return this host's MAC address as a bytearray."""
        mac = get_mac()
        machex = '%012x' % mac
        try:
            macb = bytearray.fromhex(machex)
        except TypeError:
            # Work-around for Python 2.6 bug
            # NOTE(review): 'unicode' exists only on Python 2; this branch is
            # dead (and would raise NameError) on Python 3.
            macb = bytearray.fromhex(unicode(machex))
        return macb

    def allow_dhcp_broadcast(self):
        #Open DHCP port if iptables is enabled.
        # We supress error logging on error.
        run("iptables -D INPUT -p udp --dport 68 -j ACCEPT", chk_err=False)
        run("iptables -I INPUT -p udp --dport 68 -j ACCEPT", chk_err=False)

    def get_first_if(self):
        """
        Return the interface name, and ip addr of the
        first active non-loopback interface.
        (This definition continues in the next chunk.)
        """
        iface=''
        expected=16 # how many devices should I expect...
struct_size=40 # for 64bit the size is 40 bytes sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) buff=array.array('B', b'\0' * (expected * struct_size)) param = struct.pack('iL', expected*struct_size, buff.buffer_info()[0]) ret = fcntl.ioctl(sock.fileno(), 0x8912, param) retsize=(struct.unpack('iL', ret)[0]) if retsize == (expected * struct_size): self.logger.log(('SIOCGIFCONF returned more than {0} up network interfaces.').format(expected)) sock = buff.tostring() primary = bytearray(self.get_primary_interface(), encoding='utf-8') for i in range(0, struct_size * expected, struct_size): iface=sock[i:i+16].split(b'\0', 1)[0] if len(iface) == 0 or self.is_loopback(iface) or iface != primary: # test the next one self.logger.log('interface [{0}] skipped'.format(iface)) continue else: # use this one self.logger.log('interface [{0}] selected'.format(iface)) break return iface.decode('latin-1'), socket.inet_ntoa(sock[i+20:i+24]) def get_primary_interface(self): """ Get the name of the primary interface, which is the one with the default route attached to it; if there are multiple default routes, the primary has the lowest Metric. 
:return: the interface which has the default route """ # from linux/route.h RTF_GATEWAY = 0x02 DEFAULT_DEST = "00000000" hdr_iface = "Iface" hdr_dest = "Destination" hdr_flags = "Flags" hdr_metric = "Metric" idx_iface = -1 idx_dest = -1 idx_flags = -1 idx_metric = -1 primary = None primary_metric = None self.logger.log("examine /proc/net/route for primary interface") with open('/proc/net/route') as routing_table: idx = 0 for header in list(filter(lambda h: len(h) > 0, routing_table.readline().strip(" \n").split("\t"))): if header == hdr_iface: idx_iface = idx elif header == hdr_dest: idx_dest = idx elif header == hdr_flags: idx_flags = idx elif header == hdr_metric: idx_metric = idx idx = idx + 1 for entry in routing_table.readlines(): route = entry.strip(" \n").split("\t") if route[idx_dest] == DEFAULT_DEST and int(route[idx_flags]) & RTF_GATEWAY == RTF_GATEWAY: metric = int(route[idx_metric]) iface = route[idx_iface] if primary is None or metric < primary_metric: primary = iface primary_metric = metric if primary is None: primary = '' self.logger.log('primary interface is [{0}]'.format(primary)) return primary def is_loopback(self, ifname): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) result = fcntl.ioctl(s.fileno(), 0x8913, struct.pack('256s', ifname[:15])) flags, = struct.unpack('H', result[16:18]) isloopback = flags & 8 == 8 self.logger.log('interface [{0}] has flags [{1}], is loopback [{2}]'.format(ifname, flags, isloopback)) return isloopback def get_ip4_addr(self): return self.get_first_if()[1] def start_network(self): pass def route_add(self, net, mask, gateway): """ Add specified route using /sbin/route add -net. """ cmd = ("/sbin/route add -net " "{0} netmask {1} gw {2}").format(net, mask, gateway) return run(cmd, chk_err=False) """ Add alies for python2 and python3 libs and fucntions. 
""" if sys.version_info[0]== 3: import http.client as httpclient from urllib.parse import urlparse """Rename Python3 str to ustr""" ustr = str bytebuffer = memoryview read_input = input elif sys.version_info[0] == 2: import httplib as httpclient from urlparse import urlparse """Rename Python2 unicode to ustr""" ustr = unicode bytebuffer = buffer read_input = raw_input else: raise ImportError("Unknown python version:{0}".format(sys.version_info)) ================================================ FILE: VMBackup/main/VMSnapshotPluginHost.conf ================================================ [pre_post] timeoutInSeconds: 1800 numberOfPlugins: 1 pluginName0: ScriptRunner pluginPath0: /etc/azure pluginConfigPath0: /etc/azure/VMSnapshotScriptPluginConfig.json ================================================ FILE: VMBackup/main/WaagentLib.py ================================================ #!/usr/bin/env python # # Azure Linux Agent # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Requires Python 2.6+ and Openssl 1.0+ # # Implements parts of RFC 2131, 1541, 1497 and # http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx # http://msdn.microsoft.com/en-us/library/cc227259%28PROT.13%29.aspx # # TODO: Many classes, methods, and imports in this file might not be needed by VM Backup Extension # and should be removed to reduce file size and eliminate unnecessary dependencies. 
# Future cleanup should analyze actual VMBackup usage and remove unused code. # Note: crypt module deprecated in Python 3.13+, but gen_password_hash() is not used by VMBackup try: import crypt except ImportError: # Python 3.13+ removed crypt module, but VMBackup doesn't use password functions crypt = None import random import base64 try: import httplib as httplibs except ImportError: import http.client as httplibs import os import os.path import platform import pwd import re import shutil import socket try: import SocketServer as SocketServers except ImportError: import socketserver as SocketServers import string import subprocess import sys import tempfile import textwrap import threading import time import traceback import xml.dom.minidom import inspect import zipfile import json import datetime import xml.sax.saxutils try: from packaging.version import Version as LooseVersion except ImportError: try: from distutils.version import LooseVersion except ImportError: # Fallback for environments without packaging or distutils class LooseVersion: """ Custom version comparison class that implements semantic versioning. 
Examples of version comparisons that work correctly: - LooseVersion("10.0") > LooseVersion("2.0") # True (10 > 2, not string "10.0" < "2.0") - LooseVersion("1.10") > LooseVersion("1.2") # True (10 > 2 in minor version) - LooseVersion("2.1.3") > LooseVersion("2.1") # True (2.1.3 > 2.1.0) - LooseVersion("1.0-alpha") < LooseVersion("1.0") # True (pre-release < release) - LooseVersion("1.0-beta") > LooseVersion("1.0-alpha") # True (beta > alpha) - LooseVersion("1.0-rc") > LooseVersion("1.0-beta") # True (rc > beta) How parsing works: - "2.1.3" → (2, 1, 3) - "1.0-alpha" → (1, 0, -1000) # alpha = -1000 for correct precedence - "1.0-beta" → (1, 0, -100) # beta = -100 - "1.0-rc" → (1, 0, -10) # rc = -10 - "1.0" → (1, 0) # release version (no negative suffix) Tuple comparison ensures: (1, 0, -1000) < (1, 0, -100) < (1, 0, -10) < (1, 0) """ def __init__(self, version_string): self.version = str(version_string) # Parse version into comparable parts self._parsed = self._parse_version(self.version) def _parse_version(self, version_str): """ Parse version string into comparable tuple of integers and strings. 
Parsing examples: - "2.1.3" → splits to ["2", "1", "3"] → converts to (2, 1, 3) - "1.0-alpha" → splits to ["1", "0", "alpha"] → converts to (1, 0, -1000) - "1.10.5-beta2" → splits to ["1", "10", "5", "beta2"] → converts to (1, 10, 5, "beta2") """ import re # Split by dots, hyphens, and underscores parts = re.split(r'[.\-_]', version_str.lower()) parsed = [] for part in parts: # Try to convert to int, otherwise keep as string try: parsed.append(int(part)) except ValueError: # Handle pre-release identifiers with negative values for correct precedence # This ensures: alpha < beta < rc < release if part in ('alpha', 'a'): parsed.append(-1000) # Lowest precedence elif part in ('beta', 'b'): parsed.append(-100) # Medium precedence elif part in ('rc', 'pre'): parsed.append(-10) # High precedence (but still < release) else: parsed.append(part) # Keep as string for mixed alphanumeric return tuple(parsed) def __str__(self): return self.version def __eq__(self, other): if isinstance(other, LooseVersion): return self._parsed == other._parsed return self._parsed == LooseVersion(other)._parsed def __lt__(self, other): if isinstance(other, LooseVersion): return self._parsed < other._parsed return self._parsed < LooseVersion(other)._parsed def __le__(self, other): if isinstance(other, LooseVersion): return self._parsed <= other._parsed return self._parsed <= LooseVersion(other)._parsed def __gt__(self, other): if isinstance(other, LooseVersion): return self._parsed > other._parsed return self._parsed > LooseVersion(other)._parsed def __ge__(self, other): if isinstance(other, LooseVersion): return self._parsed >= other._parsed return self._parsed >= LooseVersion(other)._parsed if not hasattr(subprocess, 'check_output'): def check_output(*popenargs, **kwargs): r"""Backport from subprocess module from python 2.7""" if 'stdout' in kwargs: raise ValueError('stdout argument not allowed, it will be overridden.') process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) 
output, unused_err = process.communicate() retcode = process.poll() if retcode: cmd = kwargs.get("args") if cmd is None: cmd = popenargs[0] raise subprocess.CalledProcessError(retcode, cmd, output=output) return output # Exception classes used by this module. class CalledProcessError(Exception): def __init__(self, returncode, cmd, output=None): self.returncode = returncode self.cmd = cmd self.output = output def __str__(self): return "Command '%s' returned non-zero exit status %d" % (self.cmd, self.returncode) subprocess.check_output = check_output subprocess.CalledProcessError = CalledProcessError GuestAgentName = "WALinuxAgent" GuestAgentLongName = "Azure Linux Agent" GuestAgentVersion = "WALinuxAgent-2.0.16" ProtocolVersion = "2012-11-30" # WARNING this value is used to confirm the correct fabric protocol. Config = None WaAgent = None DiskActivated = False Openssl = "openssl" Children = [] ExtensionChildren = [] VMM_STARTUP_SCRIPT_NAME = 'install' VMM_CONFIG_FILE_NAME = 'linuxosconfiguration.xml' global RulesFiles RulesFiles = ["/lib/udev/rules.d/75-persistent-net-generator.rules", "/etc/udev/rules.d/70-persistent-net.rules"] VarLibDhcpDirectories = ["/var/lib/dhclient", "/var/lib/dhcpcd", "/var/lib/dhcp"] EtcDhcpClientConfFiles = ["/etc/dhcp/dhclient.conf", "/etc/dhcp3/dhclient.conf"] global LibDir LibDir = "/var/lib/waagent" global provisioned provisioned = False global provisionError provisionError = None HandlerStatusToAggStatus = {"installed": "Installing", "enabled": "Ready", "unintalled": "NotReady", "disabled": "NotReady"} WaagentConf = """\ # # Azure Linux Agent Configuration # Role.StateConsumer=None # Specified program is invoked with the argument "Ready" when we report ready status # to the endpoint server. Role.ConfigurationConsumer=None # Specified program is invoked with XML file argument specifying role configuration. Role.TopologyConsumer=None # Specified program is invoked with XML file argument specifying role topology. 
Provisioning.Enabled=y # Provisioning.DeleteRootPassword=y # Password authentication for root account will be unavailable. Provisioning.RegenerateSshHostKeyPair=y # Generate fresh host key pair. Provisioning.SshHostKeyPairType=rsa # Supported values are "rsa", "dsa" and "ecdsa". Provisioning.MonitorHostName=y # Monitor host name changes and publish changes via DHCP requests. ResourceDisk.Format=y # Format if unformatted. If 'n', resource disk will not be mounted. ResourceDisk.Filesystem=ext4 # Typically ext3 or ext4. FreeBSD images should use 'ufs2' here. ResourceDisk.MountPoint=/mnt/resource # ResourceDisk.EnableSwap=n # Create and use swapfile on resource disk. ResourceDisk.SwapSizeMB=0 # Size of the swapfile. LBProbeResponder=y # Respond to load balancer probes if requested by Azure. Logs.Verbose=n # Enable verbose logs OS.RootDeviceScsiTimeout=300 # Root device timeout in seconds. OS.OpensslPath=None # If "None", the system default version is used. """ README_FILENAME = "DATALOSS_WARNING_README.txt" README_FILECONTENT = """\ WARNING: THIS IS A TEMPORARY DISK. Any data stored on this drive is SUBJECT TO LOSS and THERE IS NO WAY TO RECOVER IT. Please do not use this disk for storing any personal or application data. For additional details to please refer to the MSDN documentation at : http://msdn.microsoft.com/en-us/library/windowsazure/jj672979.aspx """ ############################################################ # BEGIN DISTRO CLASS DEFS ############################################################ ############################################################ # AbstractDistro ############################################################ class AbstractDistro(object): """ AbstractDistro defines a skeleton neccesary for a concrete Distro class. 
Generic methods and attributes are kept here, distribution specific attributes and behavior are to be placed in the concrete child named distroDistro, where distro is the string returned by calling python platform.linux_distribution()[0]. So for CentOS the derived class is called 'centosDistro'. """ def __init__(self): """ Generic Attributes go here. These are based on 'majority rules'. This __init__() may be called or overriden by the child. """ self.agent_service_name = os.path.basename(sys.argv[0]) self.selinux = None self.service_cmd = '/usr/sbin/service' self.ssh_service_restart_option = 'restart' self.ssh_service_name = 'ssh' self.ssh_config_file = '/etc/ssh/sshd_config' self.hostname_file_path = '/etc/hostname' self.dhcp_client_name = 'dhclient' self.requiredDeps = ['route', 'shutdown', 'ssh-keygen', 'useradd', 'usermod', 'openssl', 'sfdisk', 'fdisk', 'mkfs', 'sed', 'grep', 'sudo', 'parted'] self.init_script_file = '/etc/init.d/waagent' self.agent_package_name = 'WALinuxAgent' self.fileBlackList = ["/root/.bash_history", "/var/log/waagent.log", '/etc/resolv.conf'] self.agent_files_to_uninstall = ["/etc/waagent.conf", "/etc/logrotate.d/waagent"] self.grubKernelBootOptionsFile = '/etc/default/grub' self.grubKernelBootOptionsLine = 'GRUB_CMDLINE_LINUX_DEFAULT=' self.getpidcmd = 'pidof' self.mount_dvd_cmd = 'mount' self.sudoers_dir_base = '/etc' self.waagent_conf_file = WaagentConf self.shadow_file_mode = 0o600 self.shadow_file_path = "/etc/shadow" self.dhcp_enabled = False def isSelinuxSystem(self): """ Checks and sets self.selinux = True if SELinux is available on system. """ if self.selinux == None: if Run("which getenforce", chk_err=False): self.selinux = False else: self.selinux = True return self.selinux def isSelinuxRunning(self): """ Calls shell command 'getenforce' and returns True if 'Enforcing'. 
""" if self.isSelinuxSystem(): return RunGetOutput("getenforce")[1].startswith("Enforcing") else: return False def setSelinuxEnforce(self, state): """ Calls shell command 'setenforce' with 'state' and returns resulting exit code. """ if self.isSelinuxSystem(): if state: s = '1' else: s = '0' return Run("setenforce " + s) def setSelinuxContext(self, path, cn): """ Calls shell 'chcon' with 'path' and 'cn' context. Returns exit result. """ if self.isSelinuxSystem(): if not os.path.exists(path): Error("Path does not exist: {0}".format(path)) return 1 return Run('chcon ' + cn + ' ' + path) def setHostname(self, name): """ Shell call to hostname. Returns resulting exit code. """ return Run('hostname ' + name) def publishHostname(self, name): """ Set the contents of the hostname file to 'name'. Return 1 on failure. """ try: r = SetFileContents(self.hostname_file_path, name) for f in EtcDhcpClientConfFiles: if os.path.exists(f) and FindStringInFile(f, r'^[^#]*?send\s*host-name.*?(|gethostname[(,)])') == None: r = ReplaceFileContentsAtomic('/etc/dhcp/dhclient.conf', "send host-name \"" + name + "\";\n" + "\n".join(list(filter(lambda a: not a.startswith("send host-name"), GetFileContents('/etc/dhcp/dhclient.conf').split( '\n'))))) except: return 1 return r def installAgentServiceScriptFiles(self): """ Create the waagent support files for service installation. Called by registerAgentService() Abstract Virtual Function. Over-ridden in concrete Distro classes. """ pass def registerAgentService(self): """ Calls installAgentService to create service files. Shell exec service registration commands. (e.g. chkconfig --add waagent) Abstract Virtual Function. Over-ridden in concrete Distro classes. """ pass def uninstallAgentService(self): """ Call service subsystem to remove waagent script. Abstract Virtual Function. Over-ridden in concrete Distro classes. 
""" pass def unregisterAgentService(self): """ Calls self.stopAgentService and call self.uninstallAgentService() """ self.stopAgentService() self.uninstallAgentService() def startAgentService(self): """ Service call to start the Agent service """ return Run(self.service_cmd + ' ' + self.agent_service_name + ' start') def stopAgentService(self): """ Service call to stop the Agent service """ return Run(self.service_cmd + ' ' + self.agent_service_name + ' stop', False) def restartSshService(self): """ Service call to re(start) the SSH service """ sshRestartCmd = self.service_cmd + " " + self.ssh_service_name + " " + self.ssh_service_restart_option retcode = Run(sshRestartCmd) if retcode > 0: Error("Failed to restart SSH service with return code:" + str(retcode)) return retcode def checkPackageInstalled(self, p): """ Query package database for prescence of an installed package. Abstract Virtual Function. Over-ridden in concrete Distro classes. """ pass def checkPackageUpdateable(self, p): """ Online check if updated package of walinuxagent is available. Abstract Virtual Function. Over-ridden in concrete Distro classes. """ pass def deleteRootPassword(self): """ Generic root password removal. 
""" filepath = "/etc/shadow" ReplaceFileContentsAtomic(filepath, "root:*LOCK*:14600::::::\n" + "\n".join( list(filter(lambda a: not a.startswith("root:"), GetFileContents(filepath).split('\n'))))) os.chmod(filepath, self.shadow_file_mode) if self.isSelinuxSystem(): self.setSelinuxContext(filepath, 'system_u:object_r:shadow_t:s0') Log("Root password deleted.") return 0 def changePass(self, user, password): Log("Change user password") crypt_id = Config.get("Provisioning.PasswordCryptId") if crypt_id is None: crypt_id = "6" salt_len = Config.get("Provisioning.PasswordCryptSaltLength") try: salt_len = int(salt_len) if salt_len < 0 or salt_len > 10: salt_len = 10 except (ValueError, TypeError): salt_len = 10 return self.chpasswd(user, password, crypt_id=crypt_id, salt_len=salt_len) def chpasswd(self, username, password, crypt_id=6, salt_len=10): passwd_hash = self.gen_password_hash(password, crypt_id, salt_len) cmd = "usermod -p '{0}' {1}".format(passwd_hash, username) ret, output = RunGetOutput(cmd, log_cmd=False) if ret != 0: return "Failed to set password for {0}: {1}".format(username, output) def gen_password_hash(self, password, crypt_id, salt_len): if crypt is None: raise ImportError("crypt module not available (Python 3.13+). This function is not used by VMBackup.") collection = string.ascii_letters + string.digits salt = ''.join(random.choice(collection) for _ in range(salt_len)) salt = "${0}${1}".format(crypt_id, salt) return crypt.crypt(password, salt) def load_ata_piix(self): return WaAgent.TryLoadAtapiix() def unload_ata_piix(self): """ Generic function to remove ata_piix.ko. """ return WaAgent.TryUnloadAtapiix() def deprovisionWarnUser(self): """ Generic user warnings used at deprovision. """ print("WARNING! 
Nameserver configuration in /etc/resolv.conf will be deleted.") def deprovisionDeleteFiles(self): """ Files to delete when VM is deprovisioned """ for a in VarLibDhcpDirectories: Run("rm -f " + a + "/*") # Clear LibDir, remove nameserver and root bash history for f in os.listdir(LibDir) + self.fileBlackList: try: os.remove(f) except: pass return 0 def uninstallDeleteFiles(self): """ Files to delete when agent is uninstalled. """ for f in self.agent_files_to_uninstall: try: os.remove(f) except: pass return 0 def checkDependencies(self): """ Generic dependency check. Return 1 unless all dependencies are satisfied. """ if self.checkPackageInstalled('NetworkManager'): Error(GuestAgentLongName + " is not compatible with network-manager.") return 1 try: m = __import__('pyasn1') except ImportError: Error(GuestAgentLongName + " requires python-pyasn1 for your Linux distribution.") return 1 for a in self.requiredDeps: if Run("which " + a + " > /dev/null 2>&1", chk_err=False): Error("Missing required dependency: " + a) return 1 return 0 def packagedInstall(self, buildroot): """ Called from setup.py for use by RPM. Copies generated files waagent.conf, under the buildroot. 
""" if not os.path.exists(buildroot + '/etc'): os.mkdir(buildroot + '/etc') SetFileContents(buildroot + '/etc/waagent.conf', MyDistro.waagent_conf_file) if not os.path.exists(buildroot + '/etc/logrotate.d'): os.mkdir(buildroot + '/etc/logrotate.d') SetFileContents(buildroot + '/etc/logrotate.d/waagent', WaagentLogrotate) self.init_script_file = buildroot + self.init_script_file # this allows us to call installAgentServiceScriptFiles() if not os.path.exists(os.path.dirname(self.init_script_file)): os.mkdir(os.path.dirname(self.init_script_file)) self.installAgentServiceScriptFiles() def RestartInterface(self, iface, max_retry=3): for retry in range(1, max_retry + 1): ret = Run("ifdown " + iface + " && ifup " + iface) if ret == 0: return Log("Failed to restart interface: {0}, ret={1}".format(iface, ret)) if retry < max_retry: Log("Retry restart interface in 5 seconds") time.sleep(5) def CreateAccount(self, user, password, expiration, thumbprint): return CreateAccount(user, password, expiration, thumbprint) def DeleteAccount(self, user): return DeleteAccount(user) def Install(self): return Install() def mediaHasFilesystem(self, dsk): if len(dsk) == 0: return False if Run("LC_ALL=C fdisk -l " + dsk + " | grep Disk"): return False return True def mountDVD(self, dvd, location): return RunGetOutput(self.mount_dvd_cmd + ' ' + dvd + ' ' + location) def GetHome(self): return GetHome() def getDhcpClientName(self): return self.dhcp_client_name def initScsiDiskTimeout(self): """ Set the SCSI disk timeout when the agent starts running """ self.setScsiDiskTimeout() def setScsiDiskTimeout(self): """ Iterate all SCSI disks(include hot-add) and set their timeout if their value are different from the OS.RootDeviceScsiTimeout """ try: scsiTimeout = Config.get("OS.RootDeviceScsiTimeout") for diskName in [disk for disk in os.listdir("/sys/block") if disk.startswith("sd")]: self.setBlockDeviceTimeout(diskName, scsiTimeout) except: pass def setBlockDeviceTimeout(self, device, timeout): 
""" Set SCSI disk timeout by set /sys/block/sd*/device/timeout """ if timeout != None and device: filePath = "/sys/block/" + device + "/device/timeout" if (GetFileContents(filePath).splitlines()[0].rstrip() != timeout): SetFileContents(filePath, timeout) Log("SetBlockDeviceTimeout: Update the device " + device + " with timeout " + timeout) def waitForSshHostKey(self, path): """ Provide a dummy waiting, since by default, ssh host key is created by waagent and the key should already been created. """ if (os.path.isfile(path)): return True else: Error("Can't find host key: {0}".format(path)) return False def isDHCPEnabled(self): return self.dhcp_enabled def stopDHCP(self): """ Stop the system DHCP client so that the agent can bind on its port. If the distro has set dhcp_enabled to True, it will need to provide an implementation of this method. """ raise NotImplementedError('stopDHCP method missing') def startDHCP(self): """ Start the system DHCP client. If the distro has set dhcp_enabled to True, it will need to provide an implementation of this method. """ raise NotImplementedError('startDHCP method missing') def translateCustomData(self, data): """ Translate the custom data from a Base64 encoding. Default to no-op. 
""" decodeCustomData = Config.get("Provisioning.DecodeCustomData") if decodeCustomData != None and decodeCustomData.lower().startswith("y"): return base64.b64decode(data) return data def getConfigurationPath(self): return "/etc/waagent.conf" def getProcessorCores(self): return int(RunGetOutput("grep 'processor.*:' /proc/cpuinfo |wc -l")[1]) def getTotalMemory(self): return int(RunGetOutput("grep MemTotal /proc/meminfo |awk '{print $2}'")[1]) / 1024 def getInterfaceNameByMac(self, mac): ret, output = RunGetOutput("ifconfig -a") if ret != 0: raise Exception("Failed to get network interface info") output = output.replace('\n', '') match = re.search(r"(eth\d).*(HWaddr|ether) {0}".format(mac), output, re.IGNORECASE) if match is None: raise Exception("Failed to get ifname with mac: {0}".format(mac)) output = match.group(0) eths = re.findall(r"eth\d", output) if eths is None or len(eths) == 0: raise Exception("Failed to get ifname with mac: {0}".format(mac)) return eths[-1] def configIpV4(self, ifName, addr, netmask=24): ret, output = RunGetOutput("ifconfig {0} up".format(ifName)) if ret != 0: raise Exception("Failed to bring up {0}: {1}".format(ifName, output)) ret, output = RunGetOutput("ifconfig {0} {1}/{2}".format(ifName, addr, netmask)) if ret != 0: raise Exception("Failed to config ipv4 for {0}: {1}".format(ifName, output)) def setDefaultGateway(self, gateway): Run("/sbin/route add default gw" + gateway, chk_err=False) def routeAdd(self, net, mask, gateway): Run("/sbin/route add -net " + net + " netmask " + mask + " gw " + gateway, chk_err=False) ############################################################ # GentooDistro ############################################################ gentoo_init_file = """\ #!/sbin/runscript command=/usr/sbin/waagent pidfile=/var/run/waagent.pid command_args=-daemon command_background=true name="Azure Linux Agent" depend() { need localmount use logger network after bootmisc modules } """ class gentooDistro(AbstractDistro): """ Gentoo 
distro concrete class """ def __init__(self): # super(gentooDistro, self).__init__() self.service_cmd = '/sbin/service' self.ssh_service_name = 'sshd' self.hostname_file_path = '/etc/conf.d/hostname' self.dhcp_client_name = 'dhcpcd' self.shadow_file_mode = 0o640 self.init_file = gentoo_init_file def publishHostname(self, name): try: if (os.path.isfile(self.hostname_file_path)): r = ReplaceFileContentsAtomic(self.hostname_file_path, "hostname=\"" + name + "\"\n" + "\n".join(list(filter(lambda a: not a.startswith("hostname="), GetFileContents(self.hostname_file_path).split("\n"))))) except: return 1 return r def installAgentServiceScriptFiles(self): SetFileContents(self.init_script_file, self.init_file) os.chmod(self.init_script_file, 0o755) def registerAgentService(self): self.installAgentServiceScriptFiles() return Run('rc-update add ' + self.agent_service_name + ' default') def uninstallAgentService(self): return Run('rc-update del ' + self.agent_service_name + ' default') def unregisterAgentService(self): self.stopAgentService() return self.uninstallAgentService() def checkPackageInstalled(self, p): if Run('eix -I ^' + p + '$', chk_err=False): return 0 else: return 1 def checkPackageUpdateable(self, p): if Run('eix -u ^' + p + '$', chk_err=False): return 0 else: return 1 def RestartInterface(self, iface): Run("/etc/init.d/net." + iface + " restart") ############################################################ # SuSEDistro ############################################################ suse_init_file = """\ #! /bin/sh # # Azure Linux Agent sysV init script # # Copyright 2013 Microsoft Corporation # Copyright SUSE LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # /etc/init.d/waagent # # and symbolic link # # /usr/sbin/rcwaagent # # System startup script for the waagent # ### BEGIN INIT INFO # Provides: AzureLinuxAgent # Required-Start: $network sshd # Required-Stop: $network sshd # Default-Start: 3 5 # Default-Stop: 0 1 2 6 # Description: Start the AzureLinuxAgent ### END INIT INFO PYTHON=/usr/bin/python WAZD_BIN=/usr/sbin/waagent WAZD_CONF=/etc/waagent.conf WAZD_PIDFILE=/var/run/waagent.pid test -x "$WAZD_BIN" || { echo "$WAZD_BIN not installed"; exit 5; } test -e "$WAZD_CONF" || { echo "$WAZD_CONF not found"; exit 6; } . /etc/rc.status # First reset status of this service rc_reset # Return values acc. to LSB for all commands but status: # 0 - success # 1 - misc error # 2 - invalid or excess args # 3 - unimplemented feature (e.g. reload) # 4 - insufficient privilege # 5 - program not installed # 6 - program not configured # # Note that starting an already running service, stopping # or restarting a not-running service as well as the restart # with force-reload (in case signalling is not supported) are # considered a success. case "$1" in start) echo -n "Starting AzureLinuxAgent" ## Start daemon with startproc(8). If this fails ## the echo return value is set appropriate. startproc -f ${PYTHON} ${WAZD_BIN} -daemon rc_status -v ;; stop) echo -n "Shutting down AzureLinuxAgent" ## Stop daemon with killproc(8) and if this fails ## set echo the echo return value. killproc -p ${WAZD_PIDFILE} ${PYTHON} ${WAZD_BIN} rc_status -v ;; try-restart) ## Stop the service and if this succeeds (i.e. 
the ## service was running before), start it again. $0 status >/dev/null && $0 restart rc_status ;; restart) ## Stop the service and regardless of whether it was ## running or not, start it again. $0 stop sleep 1 $0 start rc_status ;; force-reload|reload) rc_status ;; status) echo -n "Checking for service AzureLinuxAgent " ## Check status with checkproc(8), if process is running ## checkproc will return with exit status 0. checkproc -p ${WAZD_PIDFILE} ${PYTHON} ${WAZD_BIN} rc_status -v ;; probe) ;; *) echo "Usage: $0 {start|stop|status|try-restart|restart|force-reload|reload}" exit 1 ;; esac rc_exit """ class SuSEDistro(AbstractDistro): """ SuSE Distro concrete class Put SuSE specific behavior here... """ def __init__(self): super(SuSEDistro, self).__init__() self.service_cmd = '/sbin/service' self.ssh_service_name = 'sshd' self.kernel_boot_options_file = '/boot/grub/menu.lst' self.hostname_file_path = '/etc/HOSTNAME' self.requiredDeps += ["/sbin/insserv"] self.init_file = suse_init_file self.dhcp_client_name = 'dhcpcd' if ((DistInfo(fullname=1)[0] == 'SUSE Linux Enterprise Server' and DistInfo()[1] >= '12') or \ (DistInfo(fullname=1)[0] == 'openSUSE' and DistInfo()[1] >= '13.2')): self.dhcp_client_name = 'wickedd-dhcp4' self.grubKernelBootOptionsFile = '/boot/grub/menu.lst' self.grubKernelBootOptionsLine = 'kernel' self.getpidcmd = 'pidof ' self.dhcp_enabled = True def checkPackageInstalled(self, p): if Run("rpm -q " + p, chk_err=False): return 0 else: return 1 def checkPackageUpdateable(self, p): if Run("zypper list-updates | grep " + p, chk_err=False): return 1 else: return 0 def installAgentServiceScriptFiles(self): try: SetFileContents(self.init_script_file, self.init_file) os.chmod(self.init_script_file, 0o744) except: pass def registerAgentService(self): self.installAgentServiceScriptFiles() return Run('insserv ' + self.agent_service_name) def uninstallAgentService(self): return Run('insserv -r ' + self.agent_service_name) def unregisterAgentService(self): 
self.stopAgentService() return self.uninstallAgentService() def startDHCP(self): Run("service " + self.dhcp_client_name + " start", chk_err=False) def stopDHCP(self): Run("service " + self.dhcp_client_name + " stop", chk_err=False) ############################################################ # redhatDistro ############################################################ redhat_init_file = """\ #!/bin/bash # # Init file for AzureLinuxAgent. # # chkconfig: 2345 60 80 # description: AzureLinuxAgent # # source function library . /etc/rc.d/init.d/functions RETVAL=0 FriendlyName="AzureLinuxAgent" WAZD_BIN=/usr/sbin/waagent start() { echo -n $"Starting $FriendlyName: " $WAZD_BIN -daemon & } stop() { echo -n $"Stopping $FriendlyName: " killproc -p /var/run/waagent.pid $WAZD_BIN RETVAL=$? echo return $RETVAL } case "$1" in start) start ;; stop) stop ;; restart) stop start ;; reload) ;; report) ;; status) status $WAZD_BIN RETVAL=$? ;; *) echo $"Usage: $0 {start|stop|restart|status}" RETVAL=1 esac exit $RETVAL """ class redhatDistro(AbstractDistro): """ Redhat Distro concrete class Put Redhat specific behavior here... 
""" def __init__(self): super(redhatDistro, self).__init__() self.service_cmd = '/sbin/service' self.ssh_service_restart_option = 'condrestart' self.ssh_service_name = 'sshd' self.hostname_file_path = None if DistInfo()[1] < '7.0' else '/etc/hostname' self.init_file = redhat_init_file self.grubKernelBootOptionsFile = '/boot/grub/menu.lst' self.grubKernelBootOptionsLine = 'kernel' def publishHostname(self, name): super(redhatDistro, self).publishHostname(name) if DistInfo()[1] < '7.0': filepath = "/etc/sysconfig/network" if os.path.isfile(filepath): ReplaceFileContentsAtomic(filepath, "HOSTNAME=" + name + "\n" + "\n".join( list(filter(lambda a: not a.startswith("HOSTNAME"), GetFileContents(filepath).split('\n'))))) ethernetInterface = MyDistro.GetInterfaceName() filepath = "/etc/sysconfig/network-scripts/ifcfg-" + ethernetInterface if os.path.isfile(filepath): ReplaceFileContentsAtomic(filepath, "DHCP_HOSTNAME=" + name + "\n" + "\n".join( list(filter(lambda a: not a.startswith("DHCP_HOSTNAME"), GetFileContents(filepath).split('\n'))))) return 0 def installAgentServiceScriptFiles(self): SetFileContents(self.init_script_file, self.init_file) os.chmod(self.init_script_file, 0o744) return 0 def registerAgentService(self): self.installAgentServiceScriptFiles() return Run('chkconfig --add waagent') def uninstallAgentService(self): return Run('chkconfig --del ' + self.agent_service_name) def unregisterAgentService(self): self.stopAgentService() return self.uninstallAgentService() def checkPackageInstalled(self, p): if Run("yum list installed " + p, chk_err=False): return 0 else: return 1 def checkPackageUpdateable(self, p): if Run("yum check-update | grep " + p, chk_err=False): return 1 else: return 0 def checkDependencies(self): """ Generic dependency check. Return 1 unless all dependencies are satisfied. 
""" if DistInfo()[1] < '7.0' and self.checkPackageInstalled('NetworkManager'): Error(GuestAgentLongName + " is not compatible with network-manager.") return 1 try: m = __import__('pyasn1') except ImportError: Error(GuestAgentLongName + " requires python-pyasn1 for your Linux distribution.") return 1 for a in self.requiredDeps: if Run("which " + a + " > /dev/null 2>&1", chk_err=False): Error("Missing required dependency: " + a) return 1 return 0 ############################################################ # centosDistro ############################################################ class centosDistro(redhatDistro): """ CentOS Distro concrete class Put CentOS specific behavior here... """ def __init__(self): super(centosDistro, self).__init__() ############################################################ # eulerosDistro ############################################################ class eulerosDistro(redhatDistro): """ EulerOS Distro concrete class Put EulerOS specific behavior here... """ def __init__(self): super(eulerosDistro, self).__init__() ############################################################ # oracleDistro ############################################################ class oracleDistro(redhatDistro): """ Oracle Distro concrete class Put Oracle specific behavior here... """ def __init__(self): super(oracleDistro, self).__init__() ############################################################ # asianuxDistro ############################################################ class asianuxDistro(redhatDistro): """ Asianux Distro concrete class Put Asianux specific behavior here... """ def __init__(self): super(asianuxDistro, self).__init__() ############################################################ # CoreOSDistro ############################################################ class CoreOSDistro(AbstractDistro): """ CoreOS Distro concrete class Put CoreOS specific behavior here... 
""" CORE_UID = 500 def __init__(self): super(CoreOSDistro, self).__init__() self.requiredDeps += ["/usr/bin/systemctl"] self.agent_service_name = 'waagent' self.init_script_file = '/etc/systemd/system/waagent.service' self.fileBlackList.append("/etc/machine-id") self.dhcp_client_name = 'systemd-networkd' self.getpidcmd = 'pidof ' self.shadow_file_mode = 0o640 self.waagent_path = '/usr/share/oem/bin' self.python_path = '/usr/share/oem/python/bin' self.dhcp_enabled = True if 'PATH' in os.environ: os.environ['PATH'] = "{0}:{1}".format(os.environ['PATH'], self.python_path) else: os.environ['PATH'] = self.python_path if 'PYTHONPATH' in os.environ: os.environ['PYTHONPATH'] = "{0}:{1}".format(os.environ['PYTHONPATH'], self.waagent_path) else: os.environ['PYTHONPATH'] = self.waagent_path def checkPackageInstalled(self, p): """ There is no package manager in CoreOS. Return 1 since it must be preinstalled. """ return 1 def checkDependencies(self): for a in self.requiredDeps: if Run("which " + a + " > /dev/null 2>&1", chk_err=False): Error("Missing required dependency: " + a) return 1 return 0 def checkPackageUpdateable(self, p): """ There is no package manager in CoreOS. Return 0 since it can't be updated via package. """ return 0 def startAgentService(self): return Run('systemctl start ' + self.agent_service_name) def stopAgentService(self): return Run('systemctl stop ' + self.agent_service_name) def restartSshService(self): """ SSH is socket activated on CoreOS. No need to restart it. """ return 0 def sshDeployPublicKey(self, fprint, path): """ We support PKCS8. """ if Run("ssh-keygen -i -m PKCS8 -f " + fprint + " >> " + path): return 1 else: return 0 def RestartInterface(self, iface): Run("systemctl restart systemd-networkd") def CreateAccount(self, user, password, expiration, thumbprint): """ Create a user account, with 'user', 'password', 'expiration', ssh keys and sudo permissions. Returns None if successful, error string on failure. 
""" userentry = None try: userentry = pwd.getpwnam(user) except: pass uidmin = None try: uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry != None and userentry[2] < uidmin and userentry[2] != self.CORE_UID: Error("CreateAccount: " + user + " is a system user. Will not set password.") return "Failed to set password for system user: " + user + " (0x06)." if userentry == None: command = "useradd --create-home --password '*' " + user if expiration != None: command += " --expiredate " + expiration.split('.')[0] if Run(command): Error("Failed to create user account: " + user) return "Failed to create user account: " + user + " (0x07)." else: Log("CreateAccount: " + user + " already exists. Will update password.") if password != None: self.changePass(user, password) try: if password == None: SetFileContents("/etc/sudoers.d/waagent", user + " ALL = (ALL) NOPASSWD: ALL\n") else: SetFileContents("/etc/sudoers.d/waagent", user + " ALL = (ALL) ALL\n") os.chmod("/etc/sudoers.d/waagent", 0o440) except: Error("CreateAccount: Failed to configure sudo access for user.") return "Failed to configure sudo privileges (0x08)." 
home = MyDistro.GetHome() if thumbprint != None: dir = home + "/" + user + "/.ssh" CreateDir(dir, user, 0o700) pub = dir + "/id_rsa.pub" prv = dir + "/id_rsa" Run("ssh-keygen -y -f " + thumbprint + ".prv > " + pub) SetFileContents(prv, GetFileContents(thumbprint + ".prv")) for f in [pub, prv]: os.chmod(f, 0o600) ChangeOwner(f, user) SetFileContents(dir + "/authorized_keys", GetFileContents(pub)) ChangeOwner(dir + "/authorized_keys", user) Log("Created user account: " + user) return None def startDHCP(self): Run("systemctl start " + self.dhcp_client_name, chk_err=False) def stopDHCP(self): Run("systemctl stop " + self.dhcp_client_name, chk_err=False) def translateCustomData(self, data): return base64.b64decode(data) def getConfigurationPath(self): return "/usr/share/oem/waagent.conf" ############################################################ # debianDistro ############################################################ debian_init_file = """\ #!/bin/sh ### BEGIN INIT INFO # Provides: AzureLinuxAgent # Required-Start: $network $syslog # Required-Stop: $network $syslog # Should-Start: $network $syslog # Should-Stop: $network $syslog # Default-Start: 2 3 4 5 # Default-Stop: 0 1 6 # Short-Description: AzureLinuxAgent # Description: AzureLinuxAgent ### END INIT INFO . /lib/lsb/init-functions OPTIONS="-daemon" WAZD_BIN=/usr/sbin/waagent WAZD_PID=/var/run/waagent.pid case "$1" in start) log_begin_msg "Starting AzureLinuxAgent..." pid=$( pidofproc $WAZD_BIN ) if [ -n "$pid" ] ; then log_begin_msg "Already running." log_end_msg 0 exit 0 fi start-stop-daemon --start --quiet --oknodo --background --exec $WAZD_BIN -- $OPTIONS log_end_msg $? ;; stop) log_begin_msg "Stopping AzureLinuxAgent..." start-stop-daemon --stop --quiet --oknodo --pidfile $WAZD_PID ret=$? rm -f $WAZD_PID log_end_msg $ret ;; force-reload) $0 restart ;; restart) $0 stop $0 start ;; status) status_of_proc $WAZD_BIN && exit 0 || exit $? 
;; *) log_success_msg "Usage: /etc/init.d/waagent {start|stop|force-reload|restart|status}" exit 1 ;; esac exit 0 """ class debianDistro(AbstractDistro): """ debian Distro concrete class Put debian specific behavior here... """ def __init__(self): super(debianDistro, self).__init__() self.requiredDeps += ["/usr/sbin/update-rc.d"] self.init_file = debian_init_file self.agent_package_name = 'walinuxagent' self.dhcp_client_name = 'dhclient' self.getpidcmd = 'pidof ' self.shadow_file_mode = 0o640 def checkPackageInstalled(self, p): """ Check that the package is installed. Return 1 if installed, 0 if not installed. This method of using dpkg-query allows wildcards to be present in the package name. """ if not Run("dpkg-query -W -f='${Status}\n' '" + p + "' | grep ' installed' 2>&1", chk_err=False): return 1 else: return 0 def checkDependencies(self): """ Debian dependency check. python-pyasn1 is NOT needed. Return 1 unless all dependencies are satisfied. NOTE: using network*manager will catch either package name in Ubuntu or debian. """ if self.checkPackageInstalled('network*manager'): Error(GuestAgentLongName + " is not compatible with network-manager.") return 1 for a in self.requiredDeps: if Run("which " + a + " > /dev/null 2>&1", chk_err=False): Error("Missing required dependency: " + a) return 1 return 0 def checkPackageUpdateable(self, p): if Run("apt-get update ; apt-get upgrade -us | grep " + p, chk_err=False): return 1 else: return 0 def installAgentServiceScriptFiles(self): """ If we are packaged - the service name is walinuxagent, do nothing. 
""" if self.agent_service_name == 'walinuxagent': return 0 try: SetFileContents(self.init_script_file, self.init_file) os.chmod(self.init_script_file, 0o744) except OSError as e: ErrorWithPrefix('installAgentServiceScriptFiles', 'Exception: ' + str(e) + ' occured creating ' + self.init_script_file) return 1 return 0 def registerAgentService(self): if self.installAgentServiceScriptFiles() == 0: return Run('update-rc.d waagent defaults') else: return 1 def uninstallAgentService(self): return Run('update-rc.d -f ' + self.agent_service_name + ' remove') def unregisterAgentService(self): self.stopAgentService() return self.uninstallAgentService() def sshDeployPublicKey(self, fprint, path): """ We support PKCS8. """ if Run("ssh-keygen -i -m PKCS8 -f " + fprint + " >> " + path): return 1 else: return 0 ############################################################ # KaliDistro - WIP # Functioning on Kali 1.1.0a so far ############################################################ class KaliDistro(debianDistro): """ Kali Distro concrete class Put Kali specific behavior here... """ def __init__(self): super(KaliDistro, self).__init__() ############################################################ # UbuntuDistro ############################################################ ubuntu_upstart_file = """\ #walinuxagent - start Azure agent description "walinuxagent" author "Ben Howard " start on (filesystem and started rsyslog) pre-start script WALINUXAGENT_ENABLED=1 [ -r /etc/default/walinuxagent ] && . /etc/default/walinuxagent if [ "$WALINUXAGENT_ENABLED" != "1" ]; then exit 1 fi if [ ! -x /usr/sbin/waagent ]; then exit 1 fi #Load the udf module modprobe -b udf end script exec /usr/sbin/waagent -daemon """ class UbuntuDistro(debianDistro): """ Ubuntu Distro concrete class Put Ubuntu specific behavior here... 
""" def __init__(self): super(UbuntuDistro, self).__init__() self.init_script_file = '/etc/init/waagent.conf' self.init_file = ubuntu_upstart_file self.fileBlackList = ["/root/.bash_history", "/var/log/waagent.log"] self.dhcp_client_name = None self.getpidcmd = 'pidof ' def registerAgentService(self): return self.installAgentServiceScriptFiles() def uninstallAgentService(self): """ If we are packaged - the service name is walinuxagent, do nothing. """ if self.agent_service_name == 'walinuxagent': return 0 os.remove('/etc/init/' + self.agent_service_name + '.conf') def unregisterAgentService(self): """ If we are packaged - the service name is walinuxagent, do nothing. """ if self.agent_service_name == 'walinuxagent': return self.stopAgentService() return self.uninstallAgentService() def deprovisionWarnUser(self): """ Ubuntu specific warning string from Deprovision. """ print("WARNING! Nameserver configuration in /etc/resolvconf/resolv.conf.d/{tail,original} will be deleted.") def deprovisionDeleteFiles(self): """ Ubuntu uses resolv.conf by default, so removing /etc/resolv.conf will break resolvconf. Therefore, we check to see if resolvconf is in use, and if so, we remove the resolvconf artifacts. """ if os.path.realpath('/etc/resolv.conf') != '/run/resolvconf/resolv.conf': Log("resolvconf is not configured. 
Removing /etc/resolv.conf") self.fileBlackList.append('/etc/resolv.conf') else: Log("resolvconf is enabled; leaving /etc/resolv.conf intact") resolvConfD = '/etc/resolvconf/resolv.conf.d/' self.fileBlackList.extend([resolvConfD + 'tail', resolvConfD + 'original']) for f in os.listdir(LibDir) + self.fileBlackList: try: os.remove(f) except: pass return 0 def getDhcpClientName(self): if self.dhcp_client_name != None: return self.dhcp_client_name if DistInfo()[1] == '12.04': self.dhcp_client_name = 'dhclient3' else: self.dhcp_client_name = 'dhclient' return self.dhcp_client_name def waitForSshHostKey(self, path): """ Wait until the ssh host key is generated by cloud init. """ for retry in range(0, 10): if (os.path.isfile(path)): return True time.sleep(1) Error("Can't find host key: {0}".format(path)) return False ############################################################ # LinuxMintDistro ############################################################ class LinuxMintDistro(UbuntuDistro): """ LinuxMint Distro concrete class Put LinuxMint specific behavior here... """ def __init__(self): super(LinuxMintDistro, self).__init__() ############################################################ # DefaultDistro ############################################################ class DefaultDistro(UbuntuDistro): """ Default Distro concrete class Put Default distro specific behavior here... """ def __init__(self): super(DefaultDistro, self).__init__() ############################################################ # fedoraDistro ############################################################ fedora_systemd_service = """\ [Unit] Description=Azure Linux Agent After=network.target After=sshd.service ConditionFileIsExecutable=/usr/sbin/waagent ConditionPathExists=/etc/waagent.conf [Service] Type=simple ExecStart=/usr/sbin/waagent -daemon [Install] WantedBy=multi-user.target """ class fedoraDistro(redhatDistro): """ FedoraDistro concrete class Put Fedora specific behavior here... 
""" def __init__(self): super(fedoraDistro, self).__init__() self.service_cmd = '/usr/bin/systemctl' self.hostname_file_path = '/etc/hostname' self.init_script_file = '/usr/lib/systemd/system/' + self.agent_service_name + '.service' self.init_file = fedora_systemd_service self.grubKernelBootOptionsFile = '/etc/default/grub' self.grubKernelBootOptionsLine = 'GRUB_CMDLINE_LINUX=' def publishHostname(self, name): SetFileContents(self.hostname_file_path, name + '\n') ethernetInterface = MyDistro.GetInterfaceName() filepath = "/etc/sysconfig/network-scripts/ifcfg-" + ethernetInterface if os.path.isfile(filepath): ReplaceFileContentsAtomic(filepath, "DHCP_HOSTNAME=" + name + "\n" + "\n".join( list(filter(lambda a: not a.startswith("DHCP_HOSTNAME"), GetFileContents(filepath).split('\n'))))) return 0 def installAgentServiceScriptFiles(self): SetFileContents(self.init_script_file, self.init_file) os.chmod(self.init_script_file, 0o644) return Run(self.service_cmd + ' daemon-reload') def registerAgentService(self): self.installAgentServiceScriptFiles() return Run(self.service_cmd + ' enable ' + self.agent_service_name) def uninstallAgentService(self): """ Call service subsystem to remove waagent script. 
""" return Run(self.service_cmd + ' disable ' + self.agent_service_name) def unregisterAgentService(self): """ Calls self.stopAgentService and call self.uninstallAgentService() """ self.stopAgentService() self.uninstallAgentService() def startAgentService(self): """ Service call to start the Agent service """ return Run(self.service_cmd + ' start ' + self.agent_service_name) def stopAgentService(self): """ Service call to stop the Agent service """ return Run(self.service_cmd + ' stop ' + self.agent_service_name, False) def restartSshService(self): """ Service call to re(start) the SSH service """ sshRestartCmd = self.service_cmd + " " + self.ssh_service_restart_option + " " + self.ssh_service_name retcode = Run(sshRestartCmd) if retcode > 0: Error("Failed to restart SSH service with return code:" + str(retcode)) return retcode def deleteRootPassword(self): return Run("/sbin/usermod root -p '!!'") def packagedInstall(self, buildroot): """ Called from setup.py for use by RPM. Copies generated files waagent.conf, under the buildroot. 
""" if not os.path.exists(buildroot + '/etc'): os.mkdir(buildroot + '/etc') SetFileContents(buildroot + '/etc/waagent.conf', MyDistro.waagent_conf_file) if not os.path.exists(buildroot + '/etc/logrotate.d'): os.mkdir(buildroot + '/etc/logrotate.d') SetFileContents(buildroot + '/etc/logrotate.d/WALinuxAgent', WaagentLogrotate) self.init_script_file = buildroot + self.init_script_file # this allows us to call installAgentServiceScriptFiles() if not os.path.exists(os.path.dirname(self.init_script_file)): os.mkdir(os.path.dirname(self.init_script_file)) self.installAgentServiceScriptFiles() def CreateAccount(self, user, password, expiration, thumbprint): super(fedoraDistro, self).CreateAccount(user, password, expiration, thumbprint) Run('/sbin/usermod ' + user + ' -G wheel') def DeleteAccount(self, user): Run('/sbin/usermod ' + user + ' -G ""') super(fedoraDistro, self).DeleteAccount(user) ############################################################ # FreeBSD ############################################################ FreeBSDWaagentConf = """\ # # Azure Linux Agent Configuration # Role.StateConsumer=None # Specified program is invoked with the argument "Ready" when we report ready status # to the endpoint server. Role.ConfigurationConsumer=None # Specified program is invoked with XML file argument specifying role configuration. Role.TopologyConsumer=None # Specified program is invoked with XML file argument specifying role topology. Provisioning.Enabled=y # Provisioning.DeleteRootPassword=y # Password authentication for root account will be unavailable. Provisioning.RegenerateSshHostKeyPair=y # Generate fresh host key pair. Provisioning.SshHostKeyPairType=rsa # Supported values are "rsa", "dsa" and "ecdsa". Provisioning.MonitorHostName=y # Monitor host name changes and publish changes via DHCP requests. ResourceDisk.Format=y # Format if unformatted. If 'n', resource disk will not be mounted. 
ResourceDisk.Filesystem=ufs2 # ResourceDisk.MountPoint=/mnt/resource # ResourceDisk.EnableSwap=n # Create and use swapfile on resource disk. ResourceDisk.SwapSizeMB=0 # Size of the swapfile. LBProbeResponder=y # Respond to load balancer probes if requested by Azure. Logs.Verbose=n # Enable verbose logs OS.RootDeviceScsiTimeout=300 # Root device timeout in seconds. OS.OpensslPath=None # If "None", the system default version is used. """ bsd_init_file = """\ #! /bin/sh # PROVIDE: waagent # REQUIRE: DAEMON cleanvar sshd # BEFORE: LOGIN # KEYWORD: nojail . /etc/rc.subr export PATH=$PATH:/usr/local/bin name="waagent" rcvar="waagent_enable" command="/usr/sbin/${name}" command_interpreter="/usr/local/bin/python" waagent_flags=" daemon &" pidfile="/var/run/waagent.pid" load_rc_config $name run_rc_command "$1" """ bsd_activate_resource_disk_txt = """\ #!/usr/bin/env python import os import sys import imp # waagent has no '.py' therefore create waagent module import manually. __name__='setupmain' #prevent waagent.__main__ from executing waagent=imp.load_source('waagent','/tmp/waagent') waagent.LoggerInit('/var/log/waagent.log','/dev/console') from waagent import RunGetOutput,Run Config=waagent.ConfigurationProvider(None) format = Config.get("ResourceDisk.Format") if format == None or format.lower().startswith("n"): sys.exit(0) device_base = 'da1' device = "/dev/" + device_base for entry in RunGetOutput("mount")[1].split(): if entry.startswith(device + "s1"): waagent.Log("ActivateResourceDisk: " + device + "s1 is already mounted.") sys.exit(0) mountpoint = Config.get("ResourceDisk.MountPoint") if mountpoint == None: mountpoint = "/mnt/resource" waagent.CreateDir(mountpoint, "root", 0755) fs = Config.get("ResourceDisk.Filesystem") if waagent.FreeBSDDistro().mediaHasFilesystem(device) == False : Run("newfs " + device + "s1") if Run("mount " + device + "s1 " + mountpoint): waagent.Error("ActivateResourceDisk: Failed to mount resource disk (" + device + "s1).") sys.exit(0) 
waagent.Log("Resource disk (" + device + "s1) is mounted at " + mountpoint + " with fstype " + fs) waagent.SetFileContents(os.path.join(mountpoint,waagent.README_FILENAME), waagent.README_FILECONTENT) swap = Config.get("ResourceDisk.EnableSwap") if swap == None or swap.lower().startswith("n"): sys.exit(0) sizeKB = int(Config.get("ResourceDisk.SwapSizeMB")) * 1024 if os.path.isfile(mountpoint + "/swapfile") and os.path.getsize(mountpoint + "/swapfile") != (sizeKB * 1024): os.remove(mountpoint + "/swapfile") if not os.path.isfile(mountpoint + "/swapfile"): Run("umask 0077 && dd if=/dev/zero of=" + mountpoint + "/swapfile bs=1024 count=" + str(sizeKB)) if Run("mdconfig -a -t vnode -f " + mountpoint + "/swapfile -u 0"): waagent.Error("ActivateResourceDisk: Configuring swap - Failed to create md0") if not Run("swapon /dev/md0"): waagent.Log("Enabled " + str(sizeKB) + " KB of swap at " + mountpoint + "/swapfile") else: waagent.Error("ActivateResourceDisk: Failed to activate swap at " + mountpoint + "/swapfile") """ class FreeBSDDistro(AbstractDistro): """ """ def __init__(self): """ Generic Attributes go here. These are based on 'majority rules'. This __init__() may be called or overriden by the child. 
""" super(FreeBSDDistro, self).__init__() self.agent_service_name = os.path.basename(sys.argv[0]) self.selinux = False self.ssh_service_name = 'sshd' self.ssh_config_file = '/etc/ssh/sshd_config' self.hostname_file_path = '/etc/hostname' self.dhcp_client_name = 'dhclient' self.requiredDeps = ['route', 'shutdown', 'ssh-keygen', 'pw' , 'openssl', 'fdisk', 'sed', 'grep', 'sudo'] self.init_script_file = '/etc/rc.d/waagent' self.init_file = bsd_init_file self.agent_package_name = 'WALinuxAgent' self.fileBlackList = ["/root/.bash_history", "/var/log/waagent.log", '/etc/resolv.conf'] self.agent_files_to_uninstall = ["/etc/waagent.conf"] self.grubKernelBootOptionsFile = '/boot/loader.conf' self.grubKernelBootOptionsLine = '' self.getpidcmd = 'pgrep -n' self.mount_dvd_cmd = 'dd bs=2048 count=33 skip=295 if=' # custom data max len is 64k self.sudoers_dir_base = '/usr/local/etc' self.waagent_conf_file = FreeBSDWaagentConf def installAgentServiceScriptFiles(self): SetFileContents(self.init_script_file, self.init_file) os.chmod(self.init_script_file, 0o777) AppendFileContents("/etc/rc.conf", "waagent_enable='YES'\n") return 0 def registerAgentService(self): self.installAgentServiceScriptFiles() return Run("services_mkdb " + self.init_script_file) def sshDeployPublicKey(self, fprint, path): """ We support PKCS8. """ if Run("ssh-keygen -i -m PKCS8 -f " + fprint + " >> " + path): return 1 else: return 0 def deleteRootPassword(self): """ BSD root password removal. 
""" filepath = "/etc/master.passwd" ReplaceStringInFile(filepath, r'root:.*?:', 'root::') # ReplaceFileContentsAtomic(filepath,"root:*LOCK*:14600::::::\n" # + "\n".join(filter(lambda a: not a.startswith("root:"),GetFileContents(filepath).split('\n')))) os.chmod(filepath, self.shadow_file_mode) if self.isSelinuxSystem(): self.setSelinuxContext(filepath, 'system_u:object_r:shadow_t:s0') RunGetOutput("pwd_mkdb -u root /etc/master.passwd") Log("Root password deleted.") return 0 def changePass(self, user, password): return RunSendStdin("pw usermod " + user + " -h 0 ", password, log_cmd=False) def load_ata_piix(self): return 0 def unload_ata_piix(self): return 0 def checkDependencies(self): """ FreeBSD dependency check. Return 1 unless all dependencies are satisfied. """ for a in self.requiredDeps: if Run("which " + a + " > /dev/null 2>&1", chk_err=False): Error("Missing required dependency: " + a) return 1 return 0 def packagedInstall(self, buildroot): pass def GetInterfaceName(self): """ Return the ip of the active ethernet interface. """ iface, inet, mac = self.GetFreeBSDEthernetInfo() return iface def RestartInterface(self, iface): Run("service netif restart") def GetIpv4Address(self): """ Return the ip of the active ethernet interface. """ iface, inet, mac = self.GetFreeBSDEthernetInfo() return inet def GetMacAddress(self): """ Return the ip of the active ethernet interface. """ iface, inet, mac = self.GetFreeBSDEthernetInfo() l = mac.split(':') r = [] for i in l: r.append(int(i, 16)) return r def GetFreeBSDEthernetInfo(self): """ There is no SIOCGIFCONF on freeBSD - just parse ifconfig. Returns strings: iface, inet4_addr, and mac or 'None,None,None' if unable to parse. We will sleep and retry as the network must be up. 
""" code, output = RunGetOutput("ifconfig", chk_err=False) Log(output) retries = 10 cmd = 'ifconfig | grep -A2 -B2 ether | grep -B3 inet | grep -A4 UP ' code = 1 while code > 0: if code > 0 and retries == 0: Error("GetFreeBSDEthernetInfo - Failed to detect ethernet interface") return None, None, None code, output = RunGetOutput(cmd, chk_err=False) retries -= 1 if code > 0 and retries > 0: Log("GetFreeBSDEthernetInfo - Error: retry ethernet detection " + str(retries)) if retries == 9: c, o = RunGetOutput("ifconfig | grep -A1 -B2 ether", chk_err=False) if c == 0: t = o.replace('\n', ' ') t = t.split() i = t[0][:-1] Log(RunGetOutput('id')[1]) Run('dhclient ' + i) time.sleep(10) j = output.replace('\n', ' ') j = j.split() iface = j[0][:-1] for i in range(len(j)): if j[i] == 'inet': inet = j[i + 1] elif j[i] == 'ether': mac = j[i + 1] return iface, inet, mac def CreateAccount(self, user, password, expiration, thumbprint): """ Create a user account, with 'user', 'password', 'expiration', ssh keys and sudo permissions. Returns None if successful, error string on failure. """ userentry = None try: userentry = pwd.getpwnam(user) except: pass uidmin = None try: if os.path.isfile("/etc/login.defs"): uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry != None and userentry[2] < uidmin: Error("CreateAccount: " + user + " is a system user. Will not set password.") return "Failed to set password for system user: " + user + " (0x06)." if userentry == None: command = "pw useradd " + user + " -m" if expiration != None: command += " -e " + expiration.split('.')[0] if Run(command): Error("Failed to create user account: " + user) return "Failed to create user account: " + user + " (0x07)." else: Log("CreateAccount: " + user + " already exists. 
Will update password.") if password != None: self.changePass(user, password) try: # for older distros create sudoers.d if not os.path.isdir(MyDistro.sudoers_dir_base + '/sudoers.d/'): # create the /etc/sudoers.d/ directory os.mkdir(MyDistro.sudoers_dir_base + '/sudoers.d') # add the include of sudoers.d to the /etc/sudoers SetFileContents(MyDistro.sudoers_dir_base + '/sudoers', GetFileContents( MyDistro.sudoers_dir_base + '/sudoers') + '\n#includedir ' + MyDistro.sudoers_dir_base + '/sudoers.d\n') if password == None: SetFileContents(MyDistro.sudoers_dir_base + "/sudoers.d/waagent", user + " ALL = (ALL) NOPASSWD: ALL\n") else: SetFileContents(MyDistro.sudoers_dir_base + "/sudoers.d/waagent", user + " ALL = (ALL) ALL\n") os.chmod(MyDistro.sudoers_dir_base + "/sudoers.d/waagent", 0o440) except: Error("CreateAccount: Failed to configure sudo access for user.") return "Failed to configure sudo privileges (0x08)." home = MyDistro.GetHome() if thumbprint != None: dir = home + "/" + user + "/.ssh" CreateDir(dir, user, 0o700) pub = dir + "/id_rsa.pub" prv = dir + "/id_rsa" Run("ssh-keygen -y -f " + thumbprint + ".prv > " + pub) SetFileContents(prv, GetFileContents(thumbprint + ".prv")) for f in [pub, prv]: os.chmod(f, 0o600) ChangeOwner(f, user) SetFileContents(dir + "/authorized_keys", GetFileContents(pub)) ChangeOwner(dir + "/authorized_keys", user) Log("Created user account: " + user) return None def DeleteAccount(self, user): """ Delete the 'user'. Clear utmp first, to avoid error. Removes the /etc/sudoers.d/waagent file. """ userentry = None try: userentry = pwd.getpwnam(user) except: pass if userentry == None: Error("DeleteAccount: " + user + " not found.") return uidmin = None try: if os.path.isfile("/etc/login.defs"): uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry[2] < uidmin: Error("DeleteAccount: " + user + " is a system user. 
Will not delete account.") return Run("> /var/run/utmp") # Delete utmp to prevent error if we are the 'user' deleted pid = subprocess.Popen(['rmuser', '-y', user], stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE).pid try: os.remove(MyDistro.sudoers_dir_base + "/sudoers.d/waagent") except: pass return def ActivateResourceDiskNoThread(self): """ Format, mount, and if specified in the configuration set resource disk as swap. """ global DiskActivated Run('cp /usr/sbin/waagent /tmp/') SetFileContents('/tmp/bsd_activate_resource_disk.py', bsd_activate_resource_disk_txt) Run('chmod +x /tmp/bsd_activate_resource_disk.py') pid = subprocess.Popen(["/tmp/bsd_activate_resource_disk.py", ""]).pid Log("Spawning bsd_activate_resource_disk.py") DiskActivated = True return def Install(self): """ Install the agent service. Check dependencies. Create /etc/waagent.conf and move old version to /etc/waagent.conf.old Copy RulesFiles to /var/lib/waagent Create /etc/logrotate.d/waagent Set /etc/ssh/sshd_config ClientAliveInterval to 180 Call ApplyVNUMAWorkaround() """ if MyDistro.checkDependencies(): return 1 os.chmod(sys.argv[0], 0o755) SwitchCwd() for a in RulesFiles: if os.path.isfile(a): if os.path.isfile(GetLastPathElement(a)): os.remove(GetLastPathElement(a)) shutil.move(a, ".") Warn("Moved " + a + " -> " + LibDir + "/" + GetLastPathElement(a)) MyDistro.registerAgentService() if os.path.isfile("/etc/waagent.conf"): try: os.remove("/etc/waagent.conf.old") except: pass try: os.rename("/etc/waagent.conf", "/etc/waagent.conf.old") Warn("Existing /etc/waagent.conf has been renamed to /etc/waagent.conf.old") except: pass SetFileContents("/etc/waagent.conf", self.waagent_conf_file) if os.path.exists('/usr/local/etc/logrotate.d/'): SetFileContents("/usr/local/etc/logrotate.d/waagent", WaagentLogrotate) filepath = "/etc/ssh/sshd_config" ReplaceFileContentsAtomic(filepath, "\n".join(list(filter(lambda a: not a.startswith("ClientAliveInterval"), 
GetFileContents(filepath).split( '\n')))) + "\nClientAliveInterval 180\n") Log("Configured SSH client probing to keep connections alive.") # ApplyVNUMAWorkaround() return 0 def mediaHasFilesystem(self, dsk): if Run('LC_ALL=C fdisk -p ' + dsk + ' | grep "invalid fdisk partition table found" ', False): return False return True def mountDVD(self, dvd, location): # At this point we cannot read a joliet option udf DVD in freebsd10 - so we 'dd' it into our location retcode, out = RunGetOutput(self.mount_dvd_cmd + dvd + ' of=' + location + '/ovf-env.xml') if retcode != 0: return retcode, out ovfxml = (GetFileContents(location + "/ovf-env.xml", asbin=False)) if ord(ovfxml[0]) > 128 and ord(ovfxml[1]) > 128 and ord(ovfxml[2]) > 128: ovfxml = ovfxml[ 3:] # BOM is not stripped. First three bytes are > 128 and not unicode chars so we ignore them. ovfxml = ovfxml.strip(chr(0x00)) ovfxml = "".join(list(filter(lambda x: ord(x) < 128, ovfxml))) ovfxml = re.sub(r'.*\Z', '', ovfxml, 0, re.DOTALL) ovfxml += '' SetFileContents(location + "/ovf-env.xml", ovfxml) return retcode, out def GetHome(self): return '/home' def initScsiDiskTimeout(self): """ Set the SCSI disk timeout by updating the kernal config """ timeout = Config.get("OS.RootDeviceScsiTimeout") if timeout: Run("sysctl kern.cam.da.default_timeout=" + timeout) def setScsiDiskTimeout(self): return def setBlockDeviceTimeout(self, device, timeout): return def getProcessorCores(self): return int(RunGetOutput("sysctl hw.ncpu | awk '{print $2}'")[1]) def getTotalMemory(self): return int(RunGetOutput("sysctl hw.realmem | awk '{print $2}'")[1]) / 1024 def setDefaultGateway(self, gateway): Run("/sbin/route add default " + gateway, chk_err=False) def routeAdd(self, net, mask, gateway): Run("/sbin/route add -net " + net + " " + mask + " " + gateway, chk_err=False) class NSBSDDistro(FreeBSDDistro): """ Stormhield NS-BSD OS """ def __init__(self): super(NSBSDDistro, self).__init__() 
############################################################ # END DISTRO CLASS DEFS ############################################################ # This lets us index into a string or an array of integers transparently. def Ord(a): """ Allows indexing into a string or an array of integers transparently. Generic utility function. """ if type(a) == type("a"): a = ord(a) return a def IsLinux(): """ Returns True if platform is Linux. Generic utility function. """ return (platform.uname()[0] == "Linux") def GetLastPathElement(path): """ Similar to basename. Generic utility function. """ return path.rsplit('/', 1)[1] def GetFileContents(filepath, asbin=False): """ Read and return contents of 'filepath'. """ mode = 'r' if asbin: mode += 'b' c = None try: with open(filepath, mode) as F: c = F.read() except IOError as e: ErrorWithPrefix('GetFileContents', 'Reading from file ' + filepath + ' Exception is ' + str(e)) return None return c def SetFileContents(filepath, contents): """ Write 'contents' to 'filepath'. """ if type(contents) == str: contents = contents.encode('latin-1', 'ignore') try: with open(filepath, "wb+") as F: F.write(contents) except IOError as e: ErrorWithPrefix('SetFileContents', 'Writing to file ' + filepath + ' Exception is ' + str(e)) return None return 0 def AppendFileContents(filepath, contents): """ Append 'contents' to 'filepath'. """ if type(contents) == str: contents = contents.encode('latin-1') try: with open(filepath, "a+") as F: F.write(contents) except IOError as e: ErrorWithPrefix('AppendFileContents', 'Appending to file ' + filepath + ' Exception is ' + str(e)) return None return 0 def ReplaceFileContentsAtomic(filepath, contents): """ Write 'contents' to 'filepath' by creating a temp file, and replacing original. 
""" handle, temp = tempfile.mkstemp(dir=os.path.dirname(filepath)) if type(contents) == str: contents = contents.encode('latin-1') try: os.write(handle, contents) except IOError as e: ErrorWithPrefix('ReplaceFileContentsAtomic', 'Writing to file ' + filepath + ' Exception is ' + str(e)) return None finally: os.close(handle) try: os.rename(temp, filepath) return None except IOError as e: ErrorWithPrefix('ReplaceFileContentsAtomic', 'Renaming ' + temp + ' to ' + filepath + ' Exception is ' + str(e)) try: os.remove(filepath) except IOError as e: ErrorWithPrefix('ReplaceFileContentsAtomic', 'Removing ' + filepath + ' Exception is ' + str(e)) try: os.rename(temp, filepath) except IOError as e: ErrorWithPrefix('ReplaceFileContentsAtomic', 'Removing ' + filepath + ' Exception is ' + str(e)) return 1 return 0 def GetLineStartingWith(prefix, filepath): """ Return line from 'filepath' if the line startswith 'prefix' """ for line in GetFileContents(filepath).split('\n'): if line.startswith(prefix): return line return None def Run(cmd, chk_err=True): """ Calls RunGetOutput on 'cmd', returning only the return code. If chk_err=True then errors will be reported in the log. If chk_err=False then errors will be suppressed from the log. """ retcode, out = RunGetOutput(cmd, chk_err) return retcode def RunGetOutput(cmd, chk_err=True, log_cmd=True): """ Wrapper for subprocess.check_output. Execute 'cmd'. Returns return code and STDOUT, trapping expected exceptions. Reports exceptions to Error if chk_err parameter is True """ if log_cmd: LogIfVerbose(cmd) try: output = subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True) if isinstance(output, bytes): output = output.decode('latin-1') except subprocess.CalledProcessError as e: if chk_err and log_cmd: Error('CalledProcessError. Error Code is ' + str(e.returncode)) Error('CalledProcessError. Command string was ' + e.cmd) if isinstance(e.output[:-1], bytes): Error('CalledProcessError. 
Command result was ' + (e.output[:-1]).decode('latin-1')) else: Error('CalledProcessError. Command result was ' + (e.output[:-1])) if isinstance(e.output, bytes): return_value = e.output.decode('latin-1') else: return_value = e.output return e.returncode, return_value return 0, output def RunSendStdin(cmd, input, chk_err=True, log_cmd=True): """ Wrapper for subprocess.Popen. Execute 'cmd', sending 'input' to STDIN of 'cmd'. Returns return code and STDOUT, trapping expected exceptions. Reports exceptions to Error if chk_err parameter is True """ if log_cmd: LogIfVerbose(cmd + input) try: me = subprocess.Popen([cmd], shell=True, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, stdout=subprocess.PIPE) output = me.communicate(input) except OSError as e: if chk_err and log_cmd: Error('CalledProcessError. Error Code is ' + str(me.returncode)) Error('CalledProcessError. Command string was ' + cmd) Error('CalledProcessError. Command result was ' + output[0].decode('latin-1')) return 1, output[0].decode('latin-1') if me.returncode != 0 and chk_err is True and log_cmd: Error('CalledProcessError. Error Code is ' + str(me.returncode)) Error('CalledProcessError. Command string was ' + cmd) Error('CalledProcessError. Command result was ' + output[0].decode('latin-1')) return me.returncode, output[0].decode('latin-1') def GetNodeTextData(a): """ Filter non-text nodes from DOM tree """ for b in a.childNodes: if b.nodeType == b.TEXT_NODE: return b.data def GetHome(): """ Attempt to guess the $HOME location. Return the path string. """ home = None try: home = GetLineStartingWith("HOME", "/etc/default/useradd").split('=')[1].strip() except: pass if (home == None) or (home.startswith("/") == False): home = "/home" return home def ChangeOwner(filepath, user): """ Lookup user. Attempt chown 'filepath' to 'user'. 
""" p = None try: p = pwd.getpwnam(user) except: pass if p != None: if not os.path.exists(filepath): Error("Path does not exist: {0}".format(filepath)) else: os.chown(filepath, p[2], p[3]) def CreateDir(dirpath, user, mode): """ Attempt os.makedirs, catch all exceptions. Call ChangeOwner afterwards. """ try: os.makedirs(dirpath, mode) except: pass ChangeOwner(dirpath, user) def CreateAccount(user, password, expiration, thumbprint): """ Create a user account, with 'user', 'password', 'expiration', ssh keys and sudo permissions. Returns None if successful, error string on failure. """ userentry = None try: userentry = pwd.getpwnam(user) except: pass uidmin = None try: uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry != None and userentry[2] < uidmin: Error("CreateAccount: " + user + " is a system user. Will not set password.") return "Failed to set password for system user: " + user + " (0x06)." if userentry == None: command = "useradd -m " + user if expiration != None: command += " -e " + expiration.split('.')[0] if Run(command): Error("Failed to create user account: " + user) return "Failed to create user account: " + user + " (0x07)." else: Log("CreateAccount: " + user + " already exists. 
Will update password.") if password != None: MyDistro.changePass(user, password) try: # for older distros create sudoers.d if not os.path.isdir('/etc/sudoers.d/'): # create the /etc/sudoers.d/ directory os.mkdir('/etc/sudoers.d/') # add the include of sudoers.d to the /etc/sudoers SetFileContents('/etc/sudoers', GetFileContents('/etc/sudoers') + '\n#includedir /etc/sudoers.d\n') if password == None: SetFileContents("/etc/sudoers.d/waagent", user + " ALL = (ALL) NOPASSWD: ALL\n") else: SetFileContents("/etc/sudoers.d/waagent", user + " ALL = (ALL) ALL\n") os.chmod("/etc/sudoers.d/waagent", 0o440) except: Error("CreateAccount: Failed to configure sudo access for user.") return "Failed to configure sudo privileges (0x08)." home = MyDistro.GetHome() if thumbprint != None: dir = home + "/" + user + "/.ssh" CreateDir(dir, user, 0o700) pub = dir + "/id_rsa.pub" prv = dir + "/id_rsa" Run("ssh-keygen -y -f " + thumbprint + ".prv > " + pub) SetFileContents(prv, GetFileContents(thumbprint + ".prv")) for f in [pub, prv]: os.chmod(f, 0o600) ChangeOwner(f, user) SetFileContents(dir + "/authorized_keys", GetFileContents(pub)) ChangeOwner(dir + "/authorized_keys", user) Log("Created user account: " + user) return None def DeleteAccount(user): """ Delete the 'user'. Clear utmp first, to avoid error. Removes the /etc/sudoers.d/waagent file. """ userentry = None try: userentry = pwd.getpwnam(user) except: pass if userentry == None: Error("DeleteAccount: " + user + " not found.") return uidmin = None try: uidmin = int(GetLineStartingWith("UID_MIN", "/etc/login.defs").split()[1]) except: pass if uidmin == None: uidmin = 100 if userentry[2] < uidmin: Error("DeleteAccount: " + user + " is a system user. 
Will not delete account.") return Run("> /var/run/utmp") # Delete utmp to prevent error if we are the 'user' deleted Run("userdel -f -r " + user) try: os.remove("/etc/sudoers.d/waagent") except: pass return def IsInRangeInclusive(a, low, high): """ Return True if 'a' in 'low' <= a >= 'high' """ return (a >= low and a <= high) def IsPrintable(ch): """ Return True if character is displayable. """ return IsInRangeInclusive(ch, Ord('A'), Ord('Z')) or IsInRangeInclusive(ch, Ord('a'), Ord('z')) or IsInRangeInclusive(ch, Ord('0'), Ord('9')) def HexDump(buffer, size): """ Return Hex formated dump of a 'buffer' of 'size'. """ if size < 0: size = len(buffer) result = "" for i in range(0, size): if (i % 16) == 0: result += "%06X: " % i byte = buffer[i] if type(byte) == str: byte = ord(byte.decode('latin1')) result += "%02X " % byte if (i & 15) == 7: result += " " if ((i + 1) % 16) == 0 or (i + 1) == size: j = i while ((j + 1) % 16) != 0: result += " " if (j & 7) == 7: result += " " j += 1 result += " " for j in range(i - (i % 16), i + 1): byte = buffer[j] if type(byte) == str: byte = ord(byte.decode('latin1')) k = '.' if IsPrintable(byte): k = chr(byte) result += k if (i + 1) != size: result += "\n" return result def SimpleLog(file_path, message): if not file_path or len(message) < 1: return t = time.localtime() t = "%04u/%02u/%02u %02u:%02u:%02u " % (t.tm_year, t.tm_mon, t.tm_mday, t.tm_hour, t.tm_min, t.tm_sec) lines = re.sub(re.compile(r'^(.)', re.MULTILINE), t + r'\1', message) with open(file_path, "ab") as F: lines = "".join(list(filter(lambda x: x in string.printable, lines))) F.write((lines + "\n").encode('ascii','ignore')) class Logger(object): """ The Agent's logging assumptions are: For Log, and LogWithPrefix all messages are logged to the self.file_path and to the self.con_path. Setting either path parameter to None skips that log. If Verbose is enabled, messages calling the LogIfVerbose method will be logged to file_path yet not to con_path. 
Error and Warn messages are normal log messages with the 'ERROR:' or 'WARNING:' prefix added. """ def __init__(self, filepath, conpath, verbose=False): """ Construct an instance of Logger. """ self.file_path = filepath self.con_path = conpath self.verbose = verbose def ThrottleLog(self, counter): """ Log everything up to 10, every 10 up to 100, then every 100. """ return (counter < 10) or ((counter < 100) and ((counter % 10) == 0)) or ((counter % 100) == 0) def LogToFile(self, message): """ Write 'message' to logfile. """ if self.file_path: try: with open(self.file_path, "ab") as F: message = "".join(list(filter(lambda x: x in string.printable, message))) F.write((message + "\n").encode('ascii','ignore')) except IOError as e: ##print e pass def LogToCon(self, message): """ Write 'message' to /dev/console. This supports serial port logging if the /dev/console is redirected to ttys0 in kernel boot options. """ if self.con_path: try: with open(self.con_path, "wb") as C: message = "".join(list(filter(lambda x: x in string.printable, message))) C.write((message + "\n").encode('ascii','ignore')) except IOError as e: pass def Log(self, message): """ Standard Log function. Logs to self.file_path, and con_path """ self.LogWithPrefix("", message) def LogWithPrefix(self, prefix, message): """ Prefix each line of 'message' with current time+'prefix'. """ t = time.localtime() t = "%04u/%02u/%02u %02u:%02u:%02u " % (t.tm_year, t.tm_mon, t.tm_mday, t.tm_hour, t.tm_min, t.tm_sec) t += prefix for line in message.split('\n'): line = t + line self.LogToFile(line) self.LogToCon(line) def NoLog(self, message): """ Don't Log. """ pass def LogIfVerbose(self, message): """ Only log 'message' if global Verbose is True. """ self.LogWithPrefixIfVerbose('', message) def LogWithPrefixIfVerbose(self, prefix, message): """ Only log 'message' if global Verbose is True. Prefix each line of 'message' with current time+'prefix'. 
""" if self.verbose == True: t = time.localtime() t = "%04u/%02u/%02u %02u:%02u:%02u " % (t.tm_year, t.tm_mon, t.tm_mday, t.tm_hour, t.tm_min, t.tm_sec) t += prefix for line in message.split('\n'): line = t + line self.LogToFile(line) self.LogToCon(line) def Warn(self, message): """ Prepend the text "WARNING:" to the prefix for each line in 'message'. """ self.LogWithPrefix("WARNING:", message) def Error(self, message): """ Call ErrorWithPrefix(message). """ ErrorWithPrefix("", message) def ErrorWithPrefix(self, prefix, message): """ Prepend the text "ERROR:" to the prefix for each line in 'message'. Errors written to logfile, and /dev/console """ self.LogWithPrefix("ERROR:", message) def LoggerInit(log_file_path, log_con_path, verbose=False): """ Create log object and export its methods to global scope. """ global Log, LogWithPrefix, LogIfVerbose, LogWithPrefixIfVerbose, Error, ErrorWithPrefix, Warn, NoLog, ThrottleLog, myLogger l = Logger(log_file_path, log_con_path, verbose) Log, LogWithPrefix, LogIfVerbose, LogWithPrefixIfVerbose, Error, ErrorWithPrefix, Warn, NoLog, ThrottleLog, myLogger = l.Log, l.LogWithPrefix, l.LogIfVerbose, l.LogWithPrefixIfVerbose, l.Error, l.ErrorWithPrefix, l.Warn, l.NoLog, l.ThrottleLog, l class HttpResourceGoneError(Exception): pass class Util(object): """ Http communication class. Base of GoalState, and Agent classes. 
""" RetryWaitingInterval = 10 def __init__(self): self.Endpoint = None def _ParseUrl(self, url): secure = False host = self.Endpoint path = url port = None # "http[s]://hostname[:port][/]" if url.startswith("http://"): url = url[7:] if "/" in url: host = url[0: url.index("/")] path = url[url.index("/"):] else: host = url path = "/" elif url.startswith("https://"): secure = True url = url[8:] if "/" in url: host = url[0: url.index("/")] path = url[url.index("/"):] else: host = url path = "/" if host is None: raise ValueError("Host is invalid:{0}".format(url)) if (":" in host): pos = host.rfind(":") port = int(host[pos + 1:]) host = host[0:pos] return host, port, secure, path def GetHttpProxy(self, secure): """ Get http_proxy and https_proxy from environment variables. Username and password is not supported now. """ host = Config.get("HttpProxy.Host") port = Config.get("HttpProxy.Port") return (host, port) def _HttpRequest(self, method, host, path, port=None, data=None, secure=False, headers=None, proxyHost=None, proxyPort=None): resp = None conn = None try: if secure: port = 443 if port is None else port if proxyHost is not None and proxyPort is not None: conn = httplibs.HTTPSConnection(proxyHost, proxyPort, timeout=10) conn.set_tunnel(host, port) # If proxy is used, full url is needed. path = "https://{0}:{1}{2}".format(host, port, path) else: conn = httplibs.HTTPSConnection(host, port, timeout=10) else: port = 80 if port is None else port if proxyHost is not None and proxyPort is not None: conn = httplibs.HTTPConnection(proxyHost, proxyPort, timeout=10) # If proxy is used, full url is needed. 
path = "http://{0}:{1}{2}".format(host, port, path) else: conn = httplibs.HTTPConnection(host, port, timeout=10) if headers == None: conn.request(method, path, data) else: conn.request(method, path, data, headers) resp = conn.getresponse() except httplibs.HTTPException as e: Error('HTTPException {0}, args:{1}'.format(e, repr(e.args))) except IOError as e: Error('Socket IOError {0}, args:{1}'.format(e, repr(e.args))) return resp def HttpRequest(self, method, url, data=None, headers=None, maxRetry=3, chkProxy=False): """ Sending http request to server On error, sleep 10 and maxRetry times. Return the output buffer or None. """ LogIfVerbose("HTTP Req: {0} {1}".format(method, url)) LogIfVerbose("HTTP Req: Data={0}".format(data)) LogIfVerbose("HTTP Req: Header={0}".format(headers)) try: host, port, secure, path = self._ParseUrl(url) except ValueError as e: Error("Failed to parse url:{0}".format(url)) return None # Check proxy proxyHost, proxyPort = (None, None) if chkProxy: proxyHost, proxyPort = self.GetHttpProxy(secure) # If httplib module is not built with ssl support. Fallback to http if secure and not hasattr(httplibs, "HTTPSConnection"): Warn("httplib is not built with ssl support") secure = False proxyHost, proxyPort = self.GetHttpProxy(secure) # If httplib module doesn't support https tunnelling. 
Fallback to http if secure and \ proxyHost is not None and \ proxyPort is not None and \ not hasattr(httplibs.HTTPSConnection, "set_tunnel"): Warn("httplib doesn't support https tunnelling(new in python 2.7)") secure = False proxyHost, proxyPort = self.GetHttpProxy(secure) resp = self._HttpRequest(method, host, path, port=port, data=data, secure=secure, headers=headers, proxyHost=proxyHost, proxyPort=proxyPort) for retry in range(0, maxRetry): if resp is not None and \ (resp.status == httplibs.OK or \ resp.status == httplibs.CREATED or \ resp.status == httplibs.ACCEPTED): return resp if resp is not None and resp.status == httplibs.GONE: raise HttpResourceGoneError("Http resource gone.") Error("Retry={0}".format(retry)) Error("HTTP Req: {0} {1}".format(method, url)) Error("HTTP Req: Data={0}".format(data)) Error("HTTP Req: Header={0}".format(headers)) if resp is None: Error("HTTP Err: response is empty. {0}".format(retry)) else: Error("HTTP Err: Status={0}".format(resp.status)) Error("HTTP Err: Reason={0}".format(resp.reason)) Error("HTTP Err: Header={0}".format(resp.getheaders())) Error("HTTP Err: Body={0}".format(resp.read())) time.sleep(self.__class__.RetryWaitingInterval) resp = self._HttpRequest(method, host, path, port=port, data=data, secure=secure, headers=headers, proxyHost=proxyHost, proxyPort=proxyPort) return None def HttpGet(self, url, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("GET", url, headers=headers, maxRetry=maxRetry, chkProxy=chkProxy) def HttpHead(self, url, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("HEAD", url, headers=headers, maxRetry=maxRetry, chkProxy=chkProxy) def HttpPost(self, url, data, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("POST", url, data=data, headers=headers, maxRetry=maxRetry, chkProxy=chkProxy) def HttpPut(self, url, data, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("PUT", url, data=data, headers=headers, maxRetry=maxRetry, 
chkProxy=chkProxy) def HttpDelete(self, url, headers=None, maxRetry=3, chkProxy=False): return self.HttpRequest("DELETE", url, headers=headers, maxRetry=maxRetry, chkProxy=chkProxy) def HttpGetWithoutHeaders(self, url, maxRetry=3, chkProxy=False): """ Return data from an HTTP get on 'url'. """ resp = self.HttpGet(url, headers=None, maxRetry=maxRetry, chkProxy=chkProxy) return resp.read() if resp is not None else None def HttpGetWithHeaders(self, url, maxRetry=3, chkProxy=False): """ Return data from an HTTP get on 'url' with x-ms-agent-name and x-ms-version headers. """ resp = self.HttpGet(url, headers={ "x-ms-agent-name": GuestAgentName, "x-ms-version": ProtocolVersion }, maxRetry=maxRetry, chkProxy=chkProxy) return resp.read() if resp is not None else None def HttpSecureGetWithHeaders(self, url, transportCert, maxRetry=3, chkProxy=False): """ Return output of get using ssl cert. """ resp = self.HttpGet(url, headers={ "x-ms-agent-name": GuestAgentName, "x-ms-version": ProtocolVersion, "x-ms-cipher-name": "DES_EDE3_CBC", "x-ms-guest-agent-public-x509-cert": transportCert }, maxRetry=maxRetry, chkProxy=chkProxy) return resp.read() if resp is not None else None def HttpPostWithHeaders(self, url, data, maxRetry=3, chkProxy=False): headers = { "x-ms-agent-name": GuestAgentName, "Content-Type": "text/xml; charset=utf-8", "x-ms-version": ProtocolVersion } try: return self.HttpPost(url, data=data, headers=headers, maxRetry=maxRetry, chkProxy=chkProxy) except HttpResourceGoneError as e: Error("Failed to post: {0} {1}".format(url, e)) return None __StorageVersion = "2014-02-14" def GetBlobType(url): restutil = Util() # Check blob type LogIfVerbose("Check blob type.") timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) blobPropResp = restutil.HttpHead(url, { "x-ms-date": timestamp, 'x-ms-version': __StorageVersion }, chkProxy=True); blobType = None if blobPropResp is None: Error("Can't get status blob type.") return None blobType = 
blobPropResp.getheader("x-ms-blob-type") LogIfVerbose("Blob type={0}".format(blobType)) return blobType def PutBlockBlob(url, data): restutil = Util() LogIfVerbose("Upload block blob") timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) ret = restutil.HttpPut(url, data, { "x-ms-date": timestamp, "x-ms-blob-type": "BlockBlob", "Content-Length": str(len(data)), "x-ms-version": __StorageVersion }, chkProxy=True) if ret is None: Error("Failed to upload block blob for status.") return -1 return 0 def PutPageBlob(url, data): restutil = Util() LogIfVerbose("Replace old page blob") timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) # Align to 512 bytes pageBlobSize = ((len(data) + 511) / 512) * 512 ret = restutil.HttpPut(url, "", { "x-ms-date": timestamp, "x-ms-blob-type": "PageBlob", "Content-Length": "0", "x-ms-blob-content-length": str(pageBlobSize), "x-ms-version": __StorageVersion }, chkProxy=True) if ret is None: Error("Failed to clean up page blob for status") return -1 if url.index('?') < 0: url = "{0}?comp=page".format(url) else: url = "{0}&comp=page".format(url) LogIfVerbose("Upload page blob") pageMax = 4 * 1024 * 1024 # Max page size: 4MB start = 0 end = 0 while end < len(data): end = min(len(data), start + pageMax) contentSize = end - start # Align to 512 bytes pageEnd = ((end + 511) / 512) * 512 bufSize = pageEnd - start buf = bytearray(bufSize) buf[0: contentSize] = data[start: end] if sys.version_info > (3,): buffer = memoryview ret = restutil.HttpPut(url, buffer(buf), { "x-ms-date": timestamp, "x-ms-range": "bytes={0}-{1}".format(start, pageEnd - 1), "x-ms-page-write": "update", "x-ms-version": __StorageVersion, "Content-Length": str(pageEnd - start) }, chkProxy=True) if ret is None: Error("Failed to upload page blob for status") return -1 start = end return 0 def UploadStatusBlob(url, data): LogIfVerbose("Upload status blob") LogIfVerbose("Status={0}".format(data)) blobType = GetBlobType(url) if blobType == "BlockBlob": return 
PutBlockBlob(url, data) elif blobType == "PageBlob": return PutPageBlob(url, data) else: Error("Unknown blob type: {0}".format(blobType)) return -1 class TCPHandler(SocketServers.BaseRequestHandler): """ Callback object for LoadBalancerProbeServer. Recv and send LB probe messages. """ def __init__(self, lb_probe): super(TCPHandler, self).__init__() self.lb_probe = lb_probe def GetHttpDateTimeNow(self): """ Return formatted gmtime "Date: Fri, 25 Mar 2011 04:53:10 GMT" """ return time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()) def handle(self): """ Log LB probe messages, read the socket buffer, send LB probe response back to server. """ self.lb_probe.ProbeCounter = (self.lb_probe.ProbeCounter + 1) % 1000000 log = [NoLog, LogIfVerbose][ThrottleLog(self.lb_probe.ProbeCounter)] strCounter = str(self.lb_probe.ProbeCounter) if self.lb_probe.ProbeCounter == 1: Log("Receiving LB probes.") log("Received LB probe # " + strCounter) self.request.recv(1024) self.request.send( "HTTP/1.1 200 OK\r\nContent-Length: 2\r\nContent-Type: text/html\r\nDate: " + self.GetHttpDateTimeNow() + "\r\n\r\nOK") class LoadBalancerProbeServer(object): """ Threaded object to receive and send LB probe messages. Load Balancer messages but be recv'd by the load balancing server, or this node may be shut-down. 
""" def __init__(self, port): self.ProbeCounter = 0 self.server = SocketServers.TCPServer((self.get_ip(), port), TCPHandler) self.server_thread = threading.Thread(target=self.server.serve_forever) self.server_thread.setDaemon(True) self.server_thread.start() def shutdown(self): self.server.shutdown() def get_ip(self): for retry in range(1, 6): ip = MyDistro.GetIpv4Address() if ip == None: Log("LoadBalancerProbeServer: GetIpv4Address() returned None, sleeping 10 before retry " + str( retry + 1)) time.sleep(10) else: return ip class ConfigurationProvider(object): """ Parse amd store key:values in waagent.conf """ def __init__(self, walaConfigFile): self.values = dict() if 'MyDistro' not in globals(): global MyDistro MyDistro = GetMyDistro() if walaConfigFile is None: walaConfigFile = MyDistro.getConfigurationPath() if os.path.isfile(walaConfigFile) == False: raise Exception("Missing configuration in {0}".format(walaConfigFile)) try: for line in GetFileContents(walaConfigFile).split('\n'): if not line.startswith("#") and "=" in line: parts = line.split()[0].split('=') value = parts[1].strip("\" ") if value != "None": self.values[parts[0]] = value else: self.values[parts[0]] = None except: Error("Unable to parse {0}".format(walaConfigFile)) raise return def get(self, key): return self.values.get(key) class EnvMonitor(object): """ Montor changes to dhcp and hostname. If dhcp clinet process re-start has occurred, reset routes, dhcp with fabric. """ def __init__(self): self.shutdown = False self.HostName = socket.gethostname() self.server_thread = threading.Thread(target=self.monitor) self.server_thread.setDaemon(True) self.server_thread.start() self.published = False def monitor(self): """ Monitor dhcp client pid and hostname. If dhcp clinet process re-start has occurred, reset routes, dhcp with fabric. 
""" publish = Config.get("Provisioning.MonitorHostName") dhcpcmd = MyDistro.getpidcmd + ' ' + MyDistro.getDhcpClientName() dhcppid = RunGetOutput(dhcpcmd)[1] while not self.shutdown: for a in RulesFiles: if os.path.isfile(a): if os.path.isfile(GetLastPathElement(a)): os.remove(GetLastPathElement(a)) shutil.move(a, ".") Log("EnvMonitor: Moved " + a + " -> " + LibDir) MyDistro.setScsiDiskTimeout() if publish != None and publish.lower().startswith("y"): try: if socket.gethostname() != self.HostName: Log("EnvMonitor: Detected host name change: " + self.HostName + " -> " + socket.gethostname()) self.HostName = socket.gethostname() WaAgent.UpdateAndPublishHostName(self.HostName) dhcppid = RunGetOutput(dhcpcmd)[1] self.published = True except: pass else: self.published = True pid = "" if not os.path.isdir("/proc/" + dhcppid.strip()): pid = RunGetOutput(dhcpcmd)[1] if pid != "" and pid != dhcppid: Log("EnvMonitor: Detected dhcp client restart. Restoring routing table.") WaAgent.RestoreRoutes() dhcppid = pid for child in Children: if child.poll() != None: Children.remove(child) time.sleep(5) def SetHostName(self, name): """ Generic call to MyDistro.setHostname(name). Complian to Log on error. """ if socket.gethostname() == name: self.published = True elif MyDistro.setHostname(name): Error("Error: SetHostName: Cannot set hostname to " + name) return ("Error: SetHostName: Cannot set hostname to " + name) def IsHostnamePublished(self): """ Return self.published """ return self.published def ShutdownService(self): """ Stop server comminucation and join the thread to main thread. """ self.shutdown = True self.server_thread.join() class Certificates(object): """ Object containing certificates of host and provisioned user. Parses and splits certificates into files. """ # # 2010-12-15 # 2 # Pkcs7BlobWithPfxContents # MIILTAY... 
# # def __init__(self): self.reinitialize() def reinitialize(self): """ Reset the Role, Incarnation """ self.Incarnation = None self.Role = None def Parse(self, xmlText): """ Parse multiple certificates into seperate files. """ self.reinitialize() SetFileContents("Certificates.xml", xmlText) dom = xml.dom.minidom.parseString(xmlText) for a in ["CertificateFile", "Version", "Incarnation", "Format", "Data", ]: if not dom.getElementsByTagName(a): Error("Certificates.Parse: Missing " + a) return None node = dom.childNodes[0] if node.localName != "CertificateFile": Error("Certificates.Parse: root not CertificateFile") return None SetFileContents("Certificates.p7m", "MIME-Version: 1.0\n" + "Content-Disposition: attachment; filename=\"Certificates.p7m\"\n" + "Content-Type: application/x-pkcs7-mime; name=\"Certificates.p7m\"\n" + "Content-Transfer-Encoding: base64\n\n" + GetNodeTextData(dom.getElementsByTagName("Data")[0])) if Run( Openssl + " cms -decrypt -in Certificates.p7m -inkey TransportPrivate.pem -recip TransportCert.pem | " + Openssl + " pkcs12 -nodes -password pass: -out Certificates.pem"): Error("Certificates.Parse: Failed to extract certificates from CMS message.") return self # There may be multiple certificates in this package. Split them. 
file = open("Certificates.pem") pindex = 1 cindex = 1 output = open("temp.pem", "w") for line in file.readlines(): output.write(line) if re.match(r'[-]+END .*?(KEY|CERTIFICATE)[-]+$', line): output.close() if re.match(r'[-]+END .*?KEY[-]+$', line): os.rename("temp.pem", str(pindex) + ".prv") pindex += 1 else: os.rename("temp.pem", str(cindex) + ".crt") cindex += 1 output = open("temp.pem", "w") output.close() os.remove("temp.pem") keys = dict() index = 1 filename = str(index) + ".crt" while os.path.isfile(filename): thumbprint = \ (RunGetOutput(Openssl + " x509 -in " + filename + " -fingerprint -noout")[1]).rstrip().split('=')[ 1].replace(':', '').upper() pubkey = RunGetOutput(Openssl + " x509 -in " + filename + " -pubkey -noout")[1] keys[pubkey] = thumbprint os.rename(filename, thumbprint + ".crt") os.chmod(thumbprint + ".crt", 0o600) MyDistro.setSelinuxContext(thumbprint + '.crt', 'unconfined_u:object_r:ssh_home_t:s0') index += 1 filename = str(index) + ".crt" index = 1 filename = str(index) + ".prv" while os.path.isfile(filename): pubkey = RunGetOutput(Openssl + " rsa -in " + filename + " -pubout 2> /dev/null ")[1] os.rename(filename, keys[pubkey] + ".prv") os.chmod(keys[pubkey] + ".prv", 0o600) MyDistro.setSelinuxContext(keys[pubkey] + '.prv', 'unconfined_u:object_r:ssh_home_t:s0') index += 1 filename = str(index) + ".prv" return self class ExtensionsConfig(object): """ Parse ExtensionsConfig, downloading and unpacking them to /var/lib/waagent. Install if true, remove if it is set to false. 
    """
    # Example PluginSettings payload carried in the goal state (abridged;
    # protectedSettings is an encrypted CMS blob, publicSettings is clear JSON):
    # {"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"1BE9A13AA1321C7C515EF109746998BAB6D86FD1",
    # "protectedSettings":"MIIByAYJKoZIhvcNAQcDoIIBuTCCAbUCAQAxggFxMIIBbQIBADBVMEExPzA9BgoJkiaJk/IsZAEZFi9XaW5kb3dzIEF6dXJlIFNlcnZpY2UgTWFuYWdlbWVudCBmb3IgR
    # Xh0ZW5zaW9ucwIQZi7dw+nhc6VHQTQpCiiV2zANBgkqhkiG9w0BAQEFAASCAQCKr09QKMGhwYe+O4/a8td+vpB4eTR+BQso84cV5KCAnD6iUIMcSYTrn9aveY6v6ykRLEw8GRKfri2d6
    # tvVDggUrBqDwIgzejGTlCstcMJItWa8Je8gHZVSDfoN80AEOTws9Fp+wNXAbSuMJNb8EnpkpvigAWU2v6pGLEFvSKC0MCjDTkjpjqciGMcbe/r85RG3Zo21HLl0xNOpjDs/qqikc/ri43Y76E/X
    # v1vBSHEGMFprPy/Hwo3PqZCnulcbVzNnaXN3qi/kxV897xGMPPC3IrO7Nc++AT9qRLFI0841JLcLTlnoVG1okPzK9w6ttksDQmKBSHt3mfYV+skqs+EOMDsGCSqGSIb3DQEHATAUBggqh
    # kiG9w0DBwQITgu0Nu3iFPuAGD6/QzKdtrnCI5425fIUy7LtpXJGmpWDUA==","publicSettings":{"port":"3000"}}}]}
    #
    def __init__(self):
        # Delegate to reinitialize() so the object can be reset per goal state.
        self.reinitialize()

    def reinitialize(self):
        """
        Reset members.
        """
        # Extensions: <Extensions> DOM nodes; Plugins: <Plugin> DOM nodes;
        # Util: HTTP helper used for manifest/bundle downloads.
        self.Extensions = None
        self.Plugins = None
        self.Util = None

    def Parse(self, xmlText):
        """
        Write configuration to file ExtensionsConfig.xml.
        Log plugin specific activity to
        /var/log/azure/<name>/<version>/CommandExecution.log.
        If state is enabled:
            if the plugin is installed:
                if the new plugin's version is higher
                    if DisallowMajorVersionUpgrade is false or if true,
                    the version is a minor version do upgrade:
                        download the new archive
                        do the updateCommand.
                        disable the old plugin and remove
                        enable the new plugin
                if the new plugin's version is the same or lower:
                    create the new .settings file from the configuration received
                    do the enableCommand
            if the plugin is not installed:
                download/unpack archive and call the installCommand/Enable
        If state is disabled:
            call disableCommand
        If state is uninstall:
            call uninstallCommand
            remove old plugin directory.
""" self.reinitialize() self.Util = Util() dom = xml.dom.minidom.parseString(xmlText) LogIfVerbose(xmlText) self.plugin_log_dir = '/var/log/azure' if not os.path.exists(self.plugin_log_dir): os.mkdir(self.plugin_log_dir) try: self.Extensions = dom.getElementsByTagName("Extensions") pg = dom.getElementsByTagName("Plugins") if len(pg) > 0: self.Plugins = pg[0].getElementsByTagName("Plugin") else: self.Plugins = [] incarnation = self.Extensions[0].getAttribute("goalStateIncarnation") SetFileContents('ExtensionsConfig.' + incarnation + '.xml', xmlText) except Exception as e: Error('ERROR: Error parsing ExtensionsConfig: {0}.'.format(e)) return None for p in self.Plugins: if len(p.getAttribute("location")) < 1: # this plugin is inside the PluginSettings continue p.setAttribute('restricted', 'false') previous_version = None version = p.getAttribute("version") name = p.getAttribute("name") plog_dir = self.plugin_log_dir + '/' + name + '/' + version if not os.path.exists(plog_dir): os.makedirs(plog_dir) p.plugin_log = plog_dir + '/CommandExecution.log' handler = name + '-' + version if p.getAttribute("isJson") != 'true': Error("Plugin " + name + " version: " + version + " is not a JSON Extension. 
Skipping.") continue Log("Found Plugin: " + name + ' version: ' + version) if p.getAttribute("state") == 'disabled' or p.getAttribute("state") == 'uninstall': # disable zip_dir = LibDir + "/" + name + '-' + version mfile = None for root, dirs, files in os.walk(zip_dir): for f in files: if f in ('HandlerManifest.json'): mfile = os.path.join(root, f) if mfile != None: break if mfile == None: Error('HandlerManifest.json not found.') continue manifest = GetFileContents(mfile) p.setAttribute('manifestdata', manifest) if self.launchCommand(p.plugin_log, name, version, 'disableCommand') == None: self.SetHandlerState(handler, 'Enabled') Error('Unable to disable ' + name) SimpleLog(p.plugin_log, 'ERROR: Unable to disable ' + name) else: self.SetHandlerState(handler, 'Disabled') Log(name + ' is disabled') SimpleLog(p.plugin_log, name + ' is disabled') # uninstall if needed if p.getAttribute("state") == 'uninstall': if self.launchCommand(p.plugin_log, name, version, 'uninstallCommand') == None: self.SetHandlerState(handler, 'Installed') Error('Unable to uninstall ' + name) SimpleLog(p.plugin_log, 'Unable to uninstall ' + name) else: self.SetHandlerState(handler, 'NotInstalled') Log(name + ' uninstallCommand completed .') # remove the plugin Run('rm -rf ' + LibDir + '/' + name + '-' + version + '*') Log(name + '-' + version + ' extension files deleted.') SimpleLog(p.plugin_log, name + '-' + version + ' extension files deleted.') continue # state is enabled # if the same plugin exists and the version is newer or # does not exist then download and unzip the new plugin plg_dir = None latest_version_installed = LooseVersion("0.0") for item in os.listdir(LibDir): itemPath = os.path.join(LibDir, item) if os.path.isdir(itemPath) and name in item: try: # Split plugin dir name with '-' to get intalled plugin name and version sperator = item.rfind('-') if sperator < 0: continue installed_plg_name = item[0:sperator] installed_plg_version = LooseVersion(item[sperator + 1:]) # Check 
installed plugin name and compare installed version to get the latest version installed if installed_plg_name == name and installed_plg_version > latest_version_installed: plg_dir = itemPath previous_version = str(installed_plg_version) latest_version_installed = installed_plg_version except Exception as e: Warn("Invalid plugin dir name: {0} {1}".format(item, e)) continue if plg_dir == None or LooseVersion(version) > LooseVersion(previous_version): location = p.getAttribute("location") Log("Downloading plugin manifest: " + name + " from " + location) SimpleLog(p.plugin_log, "Downloading plugin manifest: " + name + " from " + location) self.Util.Endpoint = location.split('/')[2] Log("Plugin server is: " + self.Util.Endpoint) SimpleLog(p.plugin_log, "Plugin server is: " + self.Util.Endpoint) manifest = self.Util.HttpGetWithoutHeaders(location, chkProxy=True) if manifest == None: Error( "Unable to download plugin manifest" + name + " from primary location. Attempting with failover location.") SimpleLog(p.plugin_log, "Unable to download plugin manifest" + name + " from primary location. Attempting with failover location.") failoverlocation = p.getAttribute("failoverlocation") self.Util.Endpoint = failoverlocation.split('/')[2] Log("Plugin failover server is: " + self.Util.Endpoint) SimpleLog(p.plugin_log, "Plugin failover server is: " + self.Util.Endpoint) manifest = self.Util.HttpGetWithoutHeaders(failoverlocation, chkProxy=True) # if failoverlocation also fail what to do then? if manifest == None: AddExtensionEvent(name, WALAEventOperation.Download, False, 0, version, "Download mainfest fail " + failoverlocation) Log("Plugin manifest " + name + " downloading failed from failover location.") SimpleLog(p.plugin_log, "Plugin manifest " + name + " downloading failed from failover location.") filepath = LibDir + "/" + name + '.' 
+ incarnation + '.manifest' if os.path.splitext(location)[-1] == '.xml': # if this is an xml file we may have a BOM if ord(manifest[0]) > 128 and ord(manifest[1]) > 128 and ord(manifest[2]) > 128: manifest = manifest[3:] SetFileContents(filepath, manifest) # Get the bundle url from the manifest p.setAttribute('manifestdata', manifest) man_dom = xml.dom.minidom.parseString(manifest) bundle_uri = "" for mp in man_dom.getElementsByTagName("Plugin"): if GetNodeTextData(mp.getElementsByTagName("Version")[0]) == version: bundle_uri = GetNodeTextData(mp.getElementsByTagName("Uri")[0]) break if len(mp.getElementsByTagName("DisallowMajorVersionUpgrade")): if GetNodeTextData(mp.getElementsByTagName("DisallowMajorVersionUpgrade")[ 0]) == 'true' and previous_version != None and previous_version.split('.')[ 0] != version.split('.')[0]: Log('DisallowMajorVersionUpgrade is true, this major version is restricted from upgrade.') SimpleLog(p.plugin_log, 'DisallowMajorVersionUpgrade is true, this major version is restricted from upgrade.') p.setAttribute('restricted', 'true') continue if len(bundle_uri) < 1: Error("Unable to fetch Bundle URI from manifest for " + name + " v " + version) SimpleLog(p.plugin_log, "Unable to fetch Bundle URI from manifest for " + name + " v " + version) continue Log("Bundle URI = " + bundle_uri) SimpleLog(p.plugin_log, "Bundle URI = " + bundle_uri) # Download the zipfile archive and save as '.zip' bundle = self.Util.HttpGetWithoutHeaders(bundle_uri, chkProxy=True) if bundle == None: AddExtensionEvent(name, WALAEventOperation.Download, True, 0, version, "Download zip fail " + bundle_uri) Error("Unable to download plugin bundle" + bundle_uri) SimpleLog(p.plugin_log, "Unable to download plugin bundle" + bundle_uri) continue AddExtensionEvent(name, WALAEventOperation.Download, True, 0, version, "Download Success") b = bytearray(bundle) filepath = LibDir + "/" + os.path.basename(bundle_uri) + '.zip' SetFileContents(filepath, b) Log("Plugin bundle" + 
bundle_uri + "downloaded successfully length = " + str(len(bundle))) SimpleLog(p.plugin_log, "Plugin bundle" + bundle_uri + "downloaded successfully length = " + str(len(bundle))) # unpack the archive z = zipfile.ZipFile(filepath) zip_dir = LibDir + "/" + name + '-' + version z.extractall(zip_dir) Log('Extracted ' + bundle_uri + ' to ' + zip_dir) SimpleLog(p.plugin_log, 'Extracted ' + bundle_uri + ' to ' + zip_dir) # zip no file perms in .zip so set all the scripts to +x Run("find " + zip_dir + " -type f | xargs chmod u+x ") # write out the base64 config data so the plugin can process it. mfile = None for root, dirs, files in os.walk(zip_dir): for f in files: if f in ('HandlerManifest.json'): mfile = os.path.join(root, f) if mfile != None: break if mfile == None: Error('HandlerManifest.json not found.') SimpleLog(p.plugin_log, 'HandlerManifest.json not found.') continue manifest = GetFileContents(mfile) p.setAttribute('manifestdata', manifest) # create the status and config dirs Run('mkdir -p ' + root + '/status') Run('mkdir -p ' + root + '/config') # write out the configuration data to goalStateIncarnation.settings file in the config path. 
config = '' seqNo = '0' if len(dom.getElementsByTagName("PluginSettings")) != 0: pslist = dom.getElementsByTagName("PluginSettings")[0].getElementsByTagName("Plugin") for ps in pslist: if name == ps.getAttribute("name") and version == ps.getAttribute("version"): Log("Found RuntimeSettings for " + name + " V " + version) SimpleLog(p.plugin_log, "Found RuntimeSettings for " + name + " V " + version) config = GetNodeTextData(ps.getElementsByTagName("RuntimeSettings")[0]) seqNo = ps.getElementsByTagName("RuntimeSettings")[0].getAttribute("seqNo") break if config == '': Log("No RuntimeSettings for " + name + " V " + version) SimpleLog(p.plugin_log, "No RuntimeSettings for " + name + " V " + version) SetFileContents(root + "/config/" + seqNo + ".settings", config) # create HandlerEnvironment.json handler_env = '[{ "name": "' + name + '", "seqNo": "' + seqNo + '", "version": 1.0, "handlerEnvironment": { "logFolder": "' + os.path.dirname( p.plugin_log) + '", "configFolder": "' + root + '/config", "statusFolder": "' + root + '/status", "heartbeatFile": "' + root + '/heartbeat.log"}}]' SetFileContents(root + '/HandlerEnvironment.json', handler_env) self.SetHandlerState(handler, 'NotInstalled') cmd = '' getcmd = 'installCommand' if plg_dir != None and previous_version != None and LooseVersion(version) > LooseVersion( previous_version): previous_handler = name + '-' + previous_version if self.GetHandlerState(previous_handler) != 'NotInstalled': getcmd = 'updateCommand' # disable the old plugin if it exists if self.launchCommand(p.plugin_log, name, previous_version, 'disableCommand') == None: self.SetHandlerState(previous_handler, 'Enabled') Error('Unable to disable old plugin ' + name + ' version ' + previous_version) SimpleLog(p.plugin_log, 'Unable to disable old plugin ' + name + ' version ' + previous_version) else: self.SetHandlerState(previous_handler, 'Disabled') Log(name + ' version ' + previous_version + ' is disabled') SimpleLog(p.plugin_log, name + ' version ' + 
previous_version + ' is disabled') try: Log("Copy status file from old plugin dir to new") old_plg_dir = plg_dir new_plg_dir = os.path.join(LibDir, "{0}-{1}".format(name, version)) old_ext_status_dir = os.path.join(old_plg_dir, "status") new_ext_status_dir = os.path.join(new_plg_dir, "status") if os.path.isdir(old_ext_status_dir): for status_file in os.listdir(old_ext_status_dir): status_file_path = os.path.join(old_ext_status_dir, status_file) if os.path.isfile(status_file_path): shutil.copy2(status_file_path, new_ext_status_dir) mrseq_file = os.path.join(old_plg_dir, "mrseq") if os.path.isfile(mrseq_file): shutil.copy(mrseq_file, new_plg_dir) except Exception as e: Error("Failed to copy status file.") isupgradeSuccess = True if getcmd == 'updateCommand': if self.launchCommand(p.plugin_log, name, version, getcmd, previous_version) == None: Error('Update failed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Update failed for ' + name + '-' + version) isupgradeSuccess = False else: Log('Update complete' + name + '-' + version) SimpleLog(p.plugin_log, 'Update complete' + name + '-' + version) # if we updated - call unistall for the old plugin if self.launchCommand(p.plugin_log, name, previous_version, 'uninstallCommand') == None: self.SetHandlerState(previous_handler, 'Installed') Error('Uninstall failed for ' + name + '-' + previous_version) SimpleLog(p.plugin_log, 'Uninstall failed for ' + name + '-' + previous_version) isupgradeSuccess = False else: self.SetHandlerState(previous_handler, 'NotInstalled') Log('Uninstall complete' + previous_handler) SimpleLog(p.plugin_log, 'Uninstall complete' + name + '-' + previous_version) try: # rm old plugin dir if os.path.isdir(plg_dir): shutil.rmtree(plg_dir) Log(name + '-' + previous_version + ' extension files deleted.') SimpleLog(p.plugin_log, name + '-' + previous_version + ' extension files deleted.') except Exception as e: Error("Failed to remove old plugin directory") AddExtensionEvent(name, 
WALAEventOperation.Upgrade, isupgradeSuccess, 0, previous_version) else: # run install if self.launchCommand(p.plugin_log, name, version, getcmd) == None: self.SetHandlerState(handler, 'NotInstalled') Error('Installation failed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Installation failed for ' + name + '-' + version) else: self.SetHandlerState(handler, 'Installed') Log('Installation completed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Installation completed for ' + name + '-' + version) # end if plg_dir == none or version > = prev # change incarnation of settings file so it knows how to name status... zip_dir = LibDir + "/" + name + '-' + version mfile = None for root, dirs, files in os.walk(zip_dir): for f in files: if f in ('HandlerManifest.json'): mfile = os.path.join(root, f) if mfile != None: break if mfile == None: Error('HandlerManifest.json not found.') SimpleLog(p.plugin_log, 'HandlerManifest.json not found.') continue manifest = GetFileContents(mfile) p.setAttribute('manifestdata', manifest) config = '' seqNo = '0' if len(dom.getElementsByTagName("PluginSettings")) != 0: try: pslist = dom.getElementsByTagName("PluginSettings")[0].getElementsByTagName("Plugin") except: Error('Error parsing ExtensionsConfig.') SimpleLog(p.plugin_log, 'Error parsing ExtensionsConfig.') continue for ps in pslist: if name == ps.getAttribute("name") and version == ps.getAttribute("version"): Log("Found RuntimeSettings for " + name + " V " + version) SimpleLog(p.plugin_log, "Found RuntimeSettings for " + name + " V " + version) config = GetNodeTextData(ps.getElementsByTagName("RuntimeSettings")[0]) seqNo = ps.getElementsByTagName("RuntimeSettings")[0].getAttribute("seqNo") break if config == '': Error("No RuntimeSettings for " + name + " V " + version) SimpleLog(p.plugin_log, "No RuntimeSettings for " + name + " V " + version) SetFileContents(root + "/config/" + seqNo + ".settings", config) # state is still enable if (self.GetHandlerState(handler) == 
'NotInstalled'): # run install first if true if self.launchCommand(p.plugin_log, name, version, 'installCommand') == None: self.SetHandlerState(handler, 'NotInstalled') Error('Installation failed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Installation failed for ' + name + '-' + version) else: self.SetHandlerState(handler, 'Installed') Log('Installation completed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Installation completed for ' + name + '-' + version) if (self.GetHandlerState(handler) != 'NotInstalled'): if self.launchCommand(p.plugin_log, name, version, 'enableCommand') == None: self.SetHandlerState(handler, 'Installed') Error('Enable failed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Enable failed for ' + name + '-' + version) else: self.SetHandlerState(handler, 'Enabled') Log('Enable completed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Enable completed for ' + name + '-' + version) # this plugin processing is complete Log('Processing completed for ' + name + '-' + version) SimpleLog(p.plugin_log, 'Processing completed for ' + name + '-' + version) # end plugin processing loop Log('Finished processing ExtensionsConfig.xml') try: SimpleLog(p.plugin_log, 'Finished processing ExtensionsConfig.xml') except: pass return self def launchCommand(self, plugin_log, name, version, command, prev_version=None): commandToEventOperation = { "installCommand": WALAEventOperation.Install, "uninstallCommand": WALAEventOperation.UnIsntall, "updateCommand": WALAEventOperation.Upgrade, "enableCommand": WALAEventOperation.Enable, "disableCommand": WALAEventOperation.Disable, } isSuccess = True start = datetime.datetime.now() r = self.__launchCommandWithoutEventLog(plugin_log, name, version, command, prev_version) if r == None: isSuccess = False Duration = int((datetime.datetime.now() - start).seconds) if commandToEventOperation.get(command): AddExtensionEvent(name, commandToEventOperation[command], isSuccess, Duration, version) 
return r def __launchCommandWithoutEventLog(self, plugin_log, name, version, command, prev_version=None): # get the manifest and read the command mfile = None zip_dir = LibDir + "/" + name + '-' + version for root, dirs, files in os.walk(zip_dir): for f in files: if f in ('HandlerManifest.json'): mfile = os.path.join(root, f) if mfile != None: break if mfile == None: Error('HandlerManifest.json not found.') SimpleLog(plugin_log, 'HandlerManifest.json not found.') return None manifest = GetFileContents(mfile) try: jsn = json.loads(manifest) except: Error('Error parsing HandlerManifest.json.') SimpleLog(plugin_log, 'Error parsing HandlerManifest.json.') return None if type(jsn) == list: jsn = jsn[0] if jsn.has_key('handlerManifest'): cmd = jsn['handlerManifest'][command] else: Error('Key handlerManifest not found. Handler cannot be installed.') SimpleLog(plugin_log, 'Key handlerManifest not found. Handler cannot be installed.') if len(cmd) == 0: Error('Unable to read ' + command) SimpleLog(plugin_log, 'Unable to read ' + command) return None # for update we send the path of the old installation arg = '' if prev_version != None: arg = ' ' + LibDir + '/' + name + '-' + prev_version dirpath = os.path.dirname(mfile) LogIfVerbose('Command is ' + dirpath + '/' + cmd) # launch pid = None try: child = subprocess.Popen(dirpath + '/' + cmd + arg, shell=True, cwd=dirpath, stdout=subprocess.PIPE) except Exception as e: Error('Exception launching ' + cmd + str(e)) SimpleLog(plugin_log, 'Exception launching ' + cmd + str(e)) pid = child.pid if pid == None or pid < 1: ExtensionChildren.append((-1, root)) Error('Error launching ' + cmd + '.') SimpleLog(plugin_log, 'Error launching ' + cmd + '.') else: ExtensionChildren.append((pid, root)) Log("Spawned " + cmd + " PID " + str(pid)) SimpleLog(plugin_log, "Spawned " + cmd + " PID " + str(pid)) # wait until install/upgrade is finished timeout = 300 # 5 minutes retry = timeout / 5 while retry > 0 and child.poll() == None: 
LogIfVerbose(cmd + ' still running with PID ' + str(pid)) time.sleep(5) retry -= 1 if retry == 0: Error('Process exceeded timeout of ' + str(timeout) + ' seconds. Terminating process ' + str(pid)) SimpleLog(plugin_log, 'Process exceeded timeout of ' + str(timeout) + ' seconds. Terminating process ' + str(pid)) os.kill(pid, 9) return None code = child.wait() if code == None or code != 0: Error('Process ' + str(pid) + ' returned non-zero exit code (' + str(code) + ')') SimpleLog(plugin_log, 'Process ' + str(pid) + ' returned non-zero exit code (' + str(code) + ')') return None Log(command + ' completed.') SimpleLog(plugin_log, command + ' completed.') return 0 def ReportHandlerStatus(self): """ Collect all status reports. """ # { "version": "1.0", "timestampUTC": "2014-03-31T21:28:58Z", # "aggregateStatus": { # "guestAgentStatus": { "version": "2.0.4PRE", "status": "Ready", "formattedMessage": { "lang": "en-US", "message": "GuestAgent is running and accepting new configurations." } }, # "handlerAggregateStatus": [{ # "handlerName": "ExampleHandlerLinux", "handlerVersion": "1.0", "status": "Ready", "runtimeSettingsStatus": { # "sequenceNumber": "2", "settingsStatus": { "timestampUTC": "2014-03-31T23:46:00Z", "status": { "name": "ExampleHandlerLinux", "operation": "Command Execution Finished", "configurationAppliedTime": "2014-03-31T23:46:00Z", "status": "success", "formattedMessage": { "lang": "en-US", "message": "Finished executing command" }, # "substatus": [ # { "name": "StdOut", "status": "success", "formattedMessage": { "lang": "en-US", "message": "Goodbye world!" } }, # { "name": "StdErr", "status": "success", "formattedMessage": { "lang": "en-US", "message": "" } } # ] # } } } } # ] # }} try: incarnation = self.Extensions[0].getAttribute("goalStateIncarnation") except: Error('Error parsing attribute "goalStateIncarnation". 
Unable to send status reports') return -1 status = '' statuses = '' for p in self.Plugins: if p.getAttribute("state") == 'uninstall' or p.getAttribute("restricted") == 'true': continue version = p.getAttribute("version") name = p.getAttribute("name") if p.getAttribute("isJson") != 'true': LogIfVerbose("Plugin " + name + " version: " + version + " is not a JSON Extension. Skipping.") continue reportHeartbeat = False if len(p.getAttribute("manifestdata")) < 1: Error("Failed to get manifestdata.") else: reportHeartbeat = json.loads(p.getAttribute("manifestdata"))[0]['handlerManifest']['reportHeartbeat'] if len(statuses) > 0: statuses += ',' statuses += self.GenerateAggStatus(name, version, reportHeartbeat) tstamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) # header # agent state if provisioned == False: if provisionError == None: agent_state = 'Provisioning' agent_msg = 'Guest Agent is starting.' else: agent_state = 'Provisioning Error.' agent_msg = provisionError else: agent_state = 'Ready' agent_msg = 'GuestAgent is running and accepting new configurations.' status = '{"version":"1.0","timestampUTC":"' + tstamp + '","aggregateStatus":{"guestAgentStatus":{"version":"' + GuestAgentVersion + '","status":"' + agent_state + '","formattedMessage":{"lang":"en-US","message":"' + agent_msg + '"}},"handlerAggregateStatus":[' + statuses + ']}}' try: uri = GetNodeTextData(self.Extensions[0].getElementsByTagName("StatusUploadBlob")[0]).replace('&', '&') except: Error('Error parsing element "StatusUploadBlob". 
Unable to send status reports') return -1 LogIfVerbose('Status report ' + status + ' sent to ' + uri) return UploadStatusBlob(uri, status.encode("utf-8")) def GetCurrentSequenceNumber(self, plugin_base_dir): """ Get the settings file with biggest file number in config folder """ config_dir = os.path.join(plugin_base_dir, 'config') seq_no = 0 for subdir, dirs, files in os.walk(config_dir): for file in files: try: cur_seq_no = int(os.path.basename(file).split('.')[0]) if cur_seq_no > seq_no: seq_no = cur_seq_no except ValueError: continue return str(seq_no) def GenerateAggStatus(self, name, version, reportHeartbeat=False): """ Generate the status which Azure can understand by the status and heartbeat reported by extension """ plugin_base_dir = LibDir + '/' + name + '-' + version + '/' current_seq_no = self.GetCurrentSequenceNumber(plugin_base_dir) status_file = os.path.join(plugin_base_dir, 'status/', current_seq_no + '.status') heartbeat_file = os.path.join(plugin_base_dir, 'heartbeat.log') handler_state_file = os.path.join(plugin_base_dir, 'config', 'HandlerState') agg_state = 'NotReady' handler_state = None status_obj = None status_code = None formatted_message = None localized_message = None if os.path.exists(handler_state_file): handler_state = GetFileContents(handler_state_file).lower() if handler_state in HandlerStatusToAggStatus: agg_state = HandlerStatusToAggStatus[handler_state] if reportHeartbeat: if os.path.exists(heartbeat_file): d = int(time.time() - os.stat(heartbeat_file).st_mtime) if d > 600: # not updated for more than 10 min agg_state = 'Unresponsive' else: try: heartbeat = json.loads(GetFileContents(heartbeat_file))[0]["heartbeat"] agg_state = heartbeat.get("status") status_code = heartbeat.get("code") formatted_message = heartbeat.get("formattedMessage") localized_message = heartbeat.get("message") except: Error("Incorrect heartbeat file. Ignore it. 
") else: agg_state = 'Unresponsive' # get status file reported by extension if os.path.exists(status_file): # raw status generated by extension is an array, get the first item and remove the unnecessary element try: status_obj = json.loads(GetFileContents(status_file))[0] del status_obj["version"] except: Error("Incorrect status file. Will NOT settingsStatus in settings. ") agg_status_obj = {"handlerName": name, "handlerVersion": version, "status": agg_state, "runtimeSettingsStatus": {"sequenceNumber": current_seq_no}} if status_obj: agg_status_obj["runtimeSettingsStatus"]["settingsStatus"] = status_obj if status_code != None: agg_status_obj["code"] = status_code if formatted_message: agg_status_obj["formattedMessage"] = formatted_message if localized_message: agg_status_obj["message"] = localized_message agg_status_string = json.dumps(agg_status_obj) LogIfVerbose("Handler Aggregated Status:" + agg_status_string) return agg_status_string def SetHandlerState(self, handler, state=''): zip_dir = LibDir + "/" + handler mfile = None for root, dirs, files in os.walk(zip_dir): for f in files: if f in ('HandlerManifest.json'): mfile = os.path.join(root, f) if mfile != None: break if mfile == None: Error('SetHandlerState(): HandlerManifest.json not found, cannot set HandlerState.') return None Log("SetHandlerState: " + handler + ", " + state) return SetFileContents(os.path.dirname(mfile) + '/config/HandlerState', state) def GetHandlerState(self, handler): handlerState = GetFileContents(handler + '/config/HandlerState') if (handlerState): return handlerState.rstrip('\r\n') else: return 'NotInstalled' class HostingEnvironmentConfig(object): """ Parse Hosting enviromnet config and store in HostingEnvironmentConfig.xml """ # # # # # # # # # # # # # # # # # # # # # # # # # # def __init__(self): self.reinitialize() def reinitialize(self): """ Reset Members. 
""" self.StoredCertificates = None self.Deployment = None self.Incarnation = None self.Role = None self.HostingEnvironmentSettings = None self.ApplicationSettings = None self.Certificates = None self.ResourceReferences = None def Parse(self, xmlText): """ Parse and create HostingEnvironmentConfig.xml. """ self.reinitialize() SetFileContents("HostingEnvironmentConfig.xml", xmlText) dom = xml.dom.minidom.parseString(xmlText) for a in ["HostingEnvironmentConfig", "Deployment", "Service", "ServiceInstance", "Incarnation", "Role", ]: if not dom.getElementsByTagName(a): Error("HostingEnvironmentConfig.Parse: Missing " + a) return None node = dom.childNodes[0] if node.localName != "HostingEnvironmentConfig": Error("HostingEnvironmentConfig.Parse: root not HostingEnvironmentConfig") return None self.ApplicationSettings = dom.getElementsByTagName("Setting") self.Certificates = dom.getElementsByTagName("StoredCertificate") return self def DecryptPassword(self, e): """ Return decrypted password. """ SetFileContents("password.p7m", "MIME-Version: 1.0\n" + "Content-Disposition: attachment; filename=\"password.p7m\"\n" + "Content-Type: application/x-pkcs7-mime; name=\"password.p7m\"\n" + "Content-Transfer-Encoding: base64\n\n" + textwrap.fill(e, 64)) return RunGetOutput(Openssl + " cms -decrypt -in password.p7m -inkey Certificates.pem -recip Certificates.pem")[ 1] def ActivateResourceDisk(self): return MyDistro.ActivateResourceDisk() def Process(self): """ Execute ActivateResourceDisk in separate thread. Create the user account. Launch ConfigurationConsumer if specified in the config. 
""" no_thread = False if DiskActivated == False: for m in inspect.getmembers(MyDistro): if 'ActivateResourceDiskNoThread' in m: no_thread = True break if no_thread == True: MyDistro.ActivateResourceDiskNoThread() else: diskThread = threading.Thread(target=self.ActivateResourceDisk) diskThread.start() User = None Pass = None Expiration = None Thumbprint = None for b in self.ApplicationSettings: sname = b.getAttribute("name") svalue = b.getAttribute("value") if User != None and Pass != None: if User != "root" and User != "" and Pass != "": CreateAccount(User, Pass, Expiration, Thumbprint) else: Error("Not creating user account: " + User) for c in self.Certificates: csha1 = c.getAttribute("certificateId").split(':')[1].upper() if os.path.isfile(csha1 + ".prv"): Log("Private key with thumbprint: " + csha1 + " was retrieved.") if os.path.isfile(csha1 + ".crt"): Log("Public cert with thumbprint: " + csha1 + " was retrieved.") program = Config.get("Role.ConfigurationConsumer") if program != None: try: Children.append(subprocess.Popen([program, LibDir + "/HostingEnvironmentConfig.xml"])) except OSError as e: ErrorWithPrefix('HostingEnvironmentConfig.Process', 'Exception: ' + str(e) + ' occured launching ' + program) class WALAEvent(object): def __init__(self): self.providerId = "" self.eventId = 1 self.OpcodeName = "" self.KeywordName = "" self.TaskName = "" self.TenantName = "" self.RoleName = "" self.RoleInstanceName = "" self.ContainerId = "" self.ExecutionMode = "IAAS" self.OSVersion = "" self.GAVersion = "" self.RAM = 0 self.Processors = 0 def ToXml(self): strEventid = u''.format(self.eventId) strProviderid = u''.format(self.providerId) strRecordFormat = u'' strRecordNoQuoteFormat = u'' strMtStr = u'mt:wstr' strMtUInt64 = u'mt:uint64' strMtBool = u'mt:bool' strMtFloat = u'mt:float64' strEventsData = u"" for attName in self.__dict__: if attName in ["eventId", "filedCount", "providerId"]: continue attValue = self.__dict__[attName] if type(attValue) is int: strEventsData 
+= strRecordFormat.format(attName, attValue, strMtUInt64) continue if type(attValue) is str: attValue = xml.sax.saxutils.quoteattr(attValue) strEventsData += strRecordNoQuoteFormat.format(attName, attValue, strMtStr) continue if str(type(attValue)).count("'unicode'") > 0: attValue = xml.sax.saxutils.quoteattr(attValue) strEventsData += strRecordNoQuoteFormat.format(attName, attValue, strMtStr) continue if type(attValue) is bool: strEventsData += strRecordFormat.format(attName, attValue, strMtBool) continue if type(attValue) is float: strEventsData += strRecordFormat.format(attName, attValue, strMtFloat) continue Log("Warning: property " + attName + ":" + str(type(attValue)) + ":type" + str( type(attValue)) + "Can't convert to events data:" + ":type not supported") return u"{0}{1}{2}".format(strProviderid, strEventid, strEventsData) def Save(self): eventfolder = LibDir + "/events" if not os.path.exists(eventfolder): os.mkdir(eventfolder) os.chmod(eventfolder, 0o700) if len(os.listdir(eventfolder)) > 1000: raise Exception("WriteToFolder:Too many file under " + eventfolder + " exit") filename = os.path.join(eventfolder, str(int(time.time() * 1000000))) with open(filename + ".tmp", 'wb+') as hfile: hfile.write(self.ToXml().encode("utf-8")) os.rename(filename + ".tmp", filename + ".tld") class WALAEventOperation: HeartBeat = "HeartBeat" Provision = "Provision" Install = "Install" UnIsntall = "UnInstall" Disable = "Disable" Enable = "Enable" Download = "Download" Upgrade = "Upgrade" Update = "Update" def AddExtensionEvent(name, op, isSuccess, duration=0, version="1.0", message="", type="", isInternal=False): event = ExtensionEvent() event.Name = name event.Version = version event.IsInternal = isInternal event.Operation = op event.OperationSuccess = isSuccess event.Message = message event.Duration = duration event.ExtensionType = type try: event.Save() except: Error("Error " + traceback.format_exc()) class ExtensionEvent(WALAEvent): def __init__(self): 
WALAEvent.__init__(self) self.eventId = 1 self.providerId = "69B669B9-4AF8-4C50-BDC4-6006FA76E975" self.Name = "" self.Version = "" self.IsInternal = False self.Operation = "" self.OperationSuccess = True self.ExtensionType = "" self.Message = "" self.Duration = 0 class WALAEventMonitor(WALAEvent): def __init__(self, postMethod): WALAEvent.__init__(self) self.post = postMethod self.sysInfo = {} self.eventdir = LibDir + "/events" self.issysteminfoinitilized = False def StartEventsLoop(self): eventThread = threading.Thread(target=self.EventsLoop) eventThread.setDaemon(True) eventThread.start() def EventsLoop(self): LastReportHeartBeatTime = datetime.datetime.min try: while True: if (datetime.datetime.now() - LastReportHeartBeatTime) > \ datetime.timedelta(minutes=30): LastReportHeartBeatTime = datetime.datetime.now() AddExtensionEvent(op=WALAEventOperation.HeartBeat, name="WALA", isSuccess=True) self.postNumbersInOneLoop = 0 self.CollectAndSendWALAEvents() time.sleep(60) except: Error("Exception in events loop:" + traceback.format_exc()) def SendEvent(self, providerid, events): dataFormat = u'{1}' \ '' data = dataFormat.format(providerid, events) self.post("/machine/?comp=telemetrydata", data) def CollectAndSendWALAEvents(self): if not os.path.exists(self.eventdir): return # Throtting, can't send more than 3 events in 15 seconds eventSendNumber = 0 eventFiles = os.listdir(self.eventdir) events = {} for file in eventFiles: if not file.endswith(".tld"): continue with open(os.path.join(self.eventdir, file), "rb") as hfile: # if fail to open or delete the file, throw exception xmlStr = hfile.read().decode("utf-8", 'ignore') os.remove(os.path.join(self.eventdir, file)) params = "" eventid = "" providerid = "" # if exception happen during process an event, catch it and continue try: xmlStr = self.AddSystemInfo(xmlStr) for node in xml.dom.minidom.parseString(xmlStr.encode("utf-8")).childNodes[0].childNodes: if node.tagName == "Param": params += node.toxml() if node.tagName 
== "Event": eventid = node.getAttribute("id") if node.tagName == "Provider": providerid = node.getAttribute("id") except: Error(traceback.format_exc()) continue if len(params) == 0 or len(eventid) == 0 or len(providerid) == 0: Error("Empty filed in params:" + params + " event id:" + eventid + " provider id:" + providerid) continue eventstr = u''.format(eventid, params) if not events.get(providerid): events[providerid] = "" if len(events[providerid]) > 0 and len(events.get(providerid) + eventstr) >= 63 * 1024: eventSendNumber += 1 self.SendEvent(providerid, events.get(providerid)) if eventSendNumber % 3 == 0: time.sleep(15) events[providerid] = "" if len(eventstr) >= 63 * 1024: Error("Signle event too large abort " + eventstr[:300]) continue events[providerid] = events.get(providerid) + eventstr for key in events.keys(): if len(events[key]) > 0: eventSendNumber += 1 self.SendEvent(key, events[key]) if eventSendNumber % 3 == 0: time.sleep(15) def AddSystemInfo(self, eventData): if not self.issysteminfoinitilized: self.issysteminfoinitilized = True try: self.sysInfo["OSVersion"] = platform.system() + ":" + "-".join(DistInfo(1)) + ":" + platform.release() self.sysInfo["GAVersion"] = GuestAgentVersion self.sysInfo["RAM"] = MyDistro.getTotalMemory() self.sysInfo["Processors"] = MyDistro.getProcessorCores() sharedConfig = xml.dom.minidom.parse("/var/lib/waagent/SharedConfig.xml").childNodes[0] hostEnvConfig = xml.dom.minidom.parse("/var/lib/waagent/HostingEnvironmentConfig.xml").childNodes[0] gfiles = RunGetOutput("ls -t /var/lib/waagent/GoalState.*.xml")[1] goalStateConfi = xml.dom.minidom.parse(gfiles.split("\n")[0]).childNodes[0] self.sysInfo["TenantName"] = hostEnvConfig.getElementsByTagName("Deployment")[0].getAttribute("name") self.sysInfo["RoleName"] = hostEnvConfig.getElementsByTagName("Role")[0].getAttribute("name") self.sysInfo["RoleInstanceName"] = sharedConfig.getElementsByTagName("Instance")[0].getAttribute("id") self.sysInfo["ContainerId"] = 
goalStateConfi.getElementsByTagName("ContainerId")[0].childNodes[ 0].nodeValue except: Error(traceback.format_exc()) eventObject = xml.dom.minidom.parseString(eventData.encode("utf-8")).childNodes[0] for node in eventObject.childNodes: if node.tagName == "Param": name = node.getAttribute("Name") if self.sysInfo.get(name): node.setAttribute("Value", xml.sax.saxutils.escape(str(self.sysInfo[name]))) return eventObject.toxml() WaagentLogrotate = """\ /var/log/waagent.log { monthly rotate 6 notifempty missingok } """ def GetMountPoint(mountlist, device): """ Example of mountlist: /dev/sda1 on / type ext4 (rw) proc on /proc type proc (rw) sysfs on /sys type sysfs (rw) devpts on /dev/pts type devpts (rw,gid=5,mode=620) tmpfs on /dev/shm type tmpfs (rw,rootcontext="system_u:object_r:tmpfs_t:s0") none on /proc/sys/fs/binfmt_misc type binfmt_misc (rw) /dev/sdb1 on /mnt/resource type ext4 (rw) """ if (mountlist and device): for entry in mountlist.split('\n'): if (re.search(device, entry)): tokens = entry.split() # Return the 3rd column of this line return tokens[2] if len(tokens) > 2 else None return None def FindInLinuxKernelCmdline(option): """ Return match object if 'option' is present in the kernel boot options of the grub configuration. """ m = None matchs = r'^.*?' + MyDistro.grubKernelBootOptionsLine + r'.*?' + option + r'.*$' try: m = FindStringInFile(MyDistro.grubKernelBootOptionsFile, matchs) except IOError as e: Error( 'FindInLinuxKernelCmdline: Exception opening ' + MyDistro.grubKernelBootOptionsFile + 'Exception:' + str(e)) return m def AppendToLinuxKernelCmdline(option): """ Add 'option' to the kernel boot options of the grub configuration. """ if not FindInLinuxKernelCmdline(option): src = r'^(.*?' 
+ MyDistro.grubKernelBootOptionsLine + r')(.*?)("?)$' rep = r'\1\2 ' + option + r'\3' try: ReplaceStringInFile(MyDistro.grubKernelBootOptionsFile, src, rep) except IOError as e: Error( 'AppendToLinuxKernelCmdline: Exception opening ' + MyDistro.grubKernelBootOptionsFile + 'Exception:' + str( e)) return 1 Run("update-grub", chk_err=False) return 0 def RemoveFromLinuxKernelCmdline(option): """ Remove 'option' to the kernel boot options of the grub configuration. """ if FindInLinuxKernelCmdline(option): src = r'^(.*?' + MyDistro.grubKernelBootOptionsLine + r'.*?)(' + option + r')(.*?)("?)$' rep = r'\1\3\4' try: ReplaceStringInFile(MyDistro.grubKernelBootOptionsFile, src, rep) except IOError as e: Error( 'RemoveFromLinuxKernelCmdline: Exception opening ' + MyDistro.grubKernelBootOptionsFile + 'Exception:' + str( e)) return 1 Run("update-grub", chk_err=False) return 0 def FindStringInFile(fname, matchs): """ Return match object if found in file. """ try: ms = re.compile(matchs) for l in (open(fname, 'r')).readlines(): m = re.search(ms, l) if m: return m except: raise return None def ReplaceStringInFile(fname, src, repl): """ Replace 'src' with 'repl' in file. """ try: sr = re.compile(src) if FindStringInFile(fname, src): updated = '' for l in (open(fname, 'r')).readlines(): n = re.sub(sr, repl, l) updated += n ReplaceFileContentsAtomic(fname, updated) except: raise return def ApplyVNUMAWorkaround(): """ If kernel version has NUMA bug, add 'numa=off' to kernel boot options. """ VersionParts = platform.release().replace('-', '.').split('.') if int(VersionParts[0]) > 2: return if int(VersionParts[1]) > 6: return if int(VersionParts[2]) > 37: return if AppendToLinuxKernelCmdline("numa=off") == 0: Log("Your kernel version " + platform.release() + " has a NUMA-related bug: NUMA has been disabled.") else: "Error adding 'numa=off'. NUMA has not been disabled." def RevertVNUMAWorkaround(): """ Remove 'numa=off' from kernel boot options. 
""" if RemoveFromLinuxKernelCmdline("numa=off") == 0: Log('NUMA has been re-enabled') else: Log('NUMA has not been re-enabled') def Install(): """ Install the agent service. Check dependencies. Create /etc/waagent.conf and move old version to /etc/waagent.conf.old Copy RulesFiles to /var/lib/waagent Create /etc/logrotate.d/waagent Set /etc/ssh/sshd_config ClientAliveInterval to 180 Call ApplyVNUMAWorkaround() """ if MyDistro.checkDependencies(): return 1 os.chmod(sys.argv[0], 0o755) SwitchCwd() for a in RulesFiles: if os.path.isfile(a): if os.path.isfile(GetLastPathElement(a)): os.remove(GetLastPathElement(a)) shutil.move(a, ".") Warn("Moved " + a + " -> " + LibDir + "/" + GetLastPathElement(a)) MyDistro.registerAgentService() if os.path.isfile("/etc/waagent.conf"): try: os.remove("/etc/waagent.conf.old") except: pass try: os.rename("/etc/waagent.conf", "/etc/waagent.conf.old") Warn("Existing /etc/waagent.conf has been renamed to /etc/waagent.conf.old") except: pass SetFileContents("/etc/waagent.conf", MyDistro.waagent_conf_file) SetFileContents("/etc/logrotate.d/waagent", WaagentLogrotate) filepath = "/etc/ssh/sshd_config" ReplaceFileContentsAtomic(filepath, "\n".join(filter(lambda a: not a.startswith("ClientAliveInterval"), GetFileContents(filepath).split( '\n'))) + "\nClientAliveInterval 180\n") Log("Configured SSH client probing to keep connections alive.") ApplyVNUMAWorkaround() return 0 def GetMyDistro(dist_class_name=''): """ Return MyDistro object. NOTE: Logging is not initialized at this point. """ if dist_class_name == '': if 'Linux' in platform.system(): Distro = DistInfo()[0] else: # I know this is not Linux! 
if 'FreeBSD' in platform.system(): Distro = platform.system() if 'NS-BSD' in platform.system(): Distro = platform.system() Distro = Distro.replace("-", "") Distro = Distro.strip('"') Distro = Distro.strip(' ') dist_class_name = Distro + 'Distro' if dist_class_name not in globals(): if ('SuSE'.lower() in Distro.lower()): Distro = 'SuSE' elif ('Ubuntu'.lower() in Distro.lower()): Distro = 'Ubuntu' elif ('centos'.lower() in Distro.lower() or 'big-ip'.lower() in Distro.lower()): Distro = 'centos' elif ('debian'.lower() in Distro.lower()): Distro = 'debian' elif ('oracle'.lower() in Distro.lower()): Distro = 'oracle' elif ('redhat'.lower() in Distro.lower()): Distro = 'redhat' elif ('Kali'.lower() in Distro.lower()): Distro = 'Kali' elif ('FreeBSD'.lower() in Distro.lower() or 'gaia'.lower() in Distro.lower() or 'panos'.lower() in Distro.lower()): Distro = 'FreeBSD' else: Distro = 'Default' dist_class_name = Distro + 'Distro' else: Distro = dist_class_name if dist_class_name not in globals(): ##print Distro + ' is not a supported distribution.' return None return globals()[dist_class_name]() # the distro class inside this module. 
def DistInfo(fullname=0): try: if 'FreeBSD' in platform.system(): release = re.sub('\\-.*$', '', str(platform.release())) distinfo = ['FreeBSD', release] return distinfo if 'NS-BSD' in platform.system(): release = re.sub('\\-.*$', '', str(platform.release())) distinfo = ['NS-BSD', release] return distinfo if 'linux_distribution' in dir(platform): distinfo = list(platform.linux_distribution(full_distribution_name=0)) # remove trailing whitespace in distro name if(distinfo[0] == ''): osfile= open("/etc/os-release", "r") for line in osfile: lists=str(line).split("=") if(lists[0]== "NAME"): distname = lists[1].split("\"") distinfo[0] = distname[1] if(distinfo[0].lower() == "sles"): distinfo[0] = "SuSE" osfile.close() distinfo[0] = distinfo[0].strip() return distinfo if 'Linux' in platform.system(): distinfo = ["Default"] if "ubuntu" in platform.version().lower(): distinfo[0] = "Ubuntu" elif 'suse' in platform.version().lower(): distinfo[0] = "SuSE" elif 'centos' in platform.version().lower(): distinfo[0] = "centos" elif 'debian' in platform.version().lower(): distinfo[0] = "debian" elif 'oracle' in platform.version().lower(): distinfo[0] = "oracle" elif 'redhat' in platform.version().lower() or 'rhel' in platform.version().lower(): distinfo[0] = "redhat" elif 'kali' in platform.version().lower(): distinfo[0] = "Kali" return distinfo else: return platform.dist() except Exception as e: errMsg = 'Failed to retrieve the distinfo with error: %s, stack trace: %s' % (str(e), traceback.format_exc()) logger.log(errMsg) distinfo = ['Abstract','1.0'] return distinfo def PackagedInstall(buildroot): """ Called from setup.py for use by RPM. Generic implementation Creates directories and files /etc/waagent.conf, /etc/init.d/waagent, /usr/sbin/waagent, /etc/logrotate.d/waagent, /etc/sudoers.d/waagent under buildroot. Copies generated files waagent.conf, into place and exits. 
""" MyDistro = GetMyDistro() if MyDistro == None: sys.exit(1) MyDistro.packagedInstall(buildroot) def LibraryInstall(buildroot): pass def Uninstall(): """ Uninstall the agent service. Copy RulesFiles back to original locations. Delete agent-related files. Call RevertVNUMAWorkaround(). """ SwitchCwd() for a in RulesFiles: if os.path.isfile(GetLastPathElement(a)): try: shutil.move(GetLastPathElement(a), a) Warn("Moved " + LibDir + "/" + GetLastPathElement(a) + " -> " + a) except: pass MyDistro.unregisterAgentService() MyDistro.uninstallDeleteFiles() RevertVNUMAWorkaround() return 0 def Deprovision(force, deluser): """ Remove user accounts created by provisioning. Disables root password if Provisioning.DeleteRootPassword = 'y' Stop agent service. Remove SSH host keys if they were generated by the provision. Set hostname to 'localhost.localdomain'. Delete cached system configuration files in /var/lib and /var/lib/waagent. """ # Append blank line at the end of file, so the ctime of this file is changed every time Run("echo ''>>" + MyDistro.getConfigurationPath()) SwitchCwd() print("WARNING! The waagent service will be stopped.") print("WARNING! All SSH host key pairs will be deleted.") print("WARNING! Cached DHCP leases will be deleted.") MyDistro.deprovisionWarnUser() delRootPass = Config.get("Provisioning.DeleteRootPassword") if delRootPass != None and delRootPass.lower().startswith("y"): print("WARNING! root password will be disabled. You will not be able to login as root.") try: input = raw_input except NameError: pass if force == False and not input('Do you want to proceed (y/n)? 
').startswith('y'): return 1 MyDistro.stopAgentService() # Remove SSH host keys regenerateKeys = Config.get("Provisioning.RegenerateSshHostKeyPair") if regenerateKeys == None or regenerateKeys.lower().startswith("y"): Run("rm -f /etc/ssh/ssh_host_*key*") # Remove root password if delRootPass != None and delRootPass.lower().startswith("y"): MyDistro.deleteRootPassword() # Remove distribution specific networking configuration MyDistro.publishHostname('localhost.localdomain') MyDistro.deprovisionDeleteFiles() return 0 def SwitchCwd(): """ Switch to cwd to /var/lib/waagent. Create if not present. """ CreateDir(LibDir, "root", 0o700) os.chdir(LibDir) def Usage(): """ Print the arguments to waagent. """ print("usage: " + sys.argv[ 0] + " [-verbose] [-force] [-help|-install|-uninstall|-deprovision[+user]|-version|-serialconsole|-daemon]") return 0 def main(): """ Instantiate MyDistro, exit if distro class is not defined. Parse command-line arguments, exit with usage() on error. Instantiate ConfigurationProvider. Call appropriate non-daemon methods and exit. If daemon mode, enter Agent.Run() loop. """ if GuestAgentVersion == "": print("WARNING! 
This is a non-standard agent that does not include a valid version string.") if len(sys.argv) == 1: sys.exit(Usage()) LoggerInit('/var/log/waagent.log', '/dev/console') global LinuxDistro LinuxDistro = DistInfo()[0] global MyDistro MyDistro = GetMyDistro() if MyDistro == None: sys.exit(1) args = [] conf_file = None global force force = False for a in sys.argv[1:]: if re.match("^([-/]*)(help|usage|\\?)", a): sys.exit(Usage()) elif re.match("^([-/]*)version", a): print(GuestAgentVersion + " running on " + LinuxDistro) sys.exit(0) elif re.match("^([-/]*)verbose", a): myLogger.verbose = True elif re.match("^([-/]*)force", a): force = True elif re.match("^(?:[-/]*)conf=.+", a): conf_file = re.match("^(?:[-/]*)conf=(.+)", a).groups()[0] elif re.match("^([-/]*)(setup|install)", a): sys.exit(MyDistro.Install()) elif re.match("^([-/]*)(uninstall)", a): sys.exit(Uninstall()) else: args.append(a) global Config Config = ConfigurationProvider(conf_file) logfile = Config.get("Logs.File") if logfile is not None: myLogger.file_path = logfile logconsole = Config.get("Logs.Console") if logconsole is not None and logconsole.lower().startswith("n"): myLogger.con_path = None verbose = Config.get("Logs.Verbose") if verbose != None and verbose.lower().startswith("y"): myLogger.verbose = True global daemon daemon = False for a in args: if re.match("^([-/]*)deprovision\\+user", a): sys.exit(Deprovision(force, True)) elif re.match("^([-/]*)deprovision", a): sys.exit(Deprovision(force, False)) elif re.match("^([-/]*)daemon", a): daemon = True elif re.match("^([-/]*)serialconsole", a): AppendToLinuxKernelCmdline("console=ttyS0 earlyprintk=ttyS0") Log("Configured kernel to use ttyS0 as the boot console.") sys.exit(0) else: print("Invalid command line parameter:" + a) sys.exit(1) if daemon == False: sys.exit(Usage()) global modloaded modloaded = False while True: try: SwitchCwd() Log(GuestAgentLongName + " Version: " + GuestAgentVersion) if IsLinux(): Log("Linux Distribution Detected : " + 
LinuxDistro) except Exception as e: Error(traceback.format_exc()) Error("Exception: " + str(e)) Log("Restart agent in 15 seconds") time.sleep(15) if __name__ == '__main__': main() ================================================ FILE: VMBackup/main/__init__.py ================================================ # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: VMBackup/main/backuplogger.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import datetime
import os
import string
import time
import traceback
from blobwriter import BlobWriter
from Utils.WAAgentUtil import waagent
import sys


class Backuplogger(object):
    # Logger for the VMBackup extension.  Writes timestamped messages to
    # /dev/console and to the handler log via hutil; during an I/O freeze
    # window it buffers messages in self.msg instead of touching the disk.
    def __init__(self, hutil):
        self.msg = ''                           # in-memory buffer used while file logging is paused
        self.con_path = '/dev/console'          # console device written by log_to_con / log_to_con_py3
        self.enforced_local_flag_value = True   # True = normal logging; False = freeze window (buffer only)
        self.hutil = hutil
        self.prev_log = ''
        self.logging_off = False                # set from the 'LoggingOff' config key

    def enforce_local_flag(self, enforced_local):
        #Pause file logging during I/O freeze period by setting Enforced_local_flag_value to False
        #Enforced_local_flag_value is turned to False from True when Freeze Starts
        #Enforced_local_flag_value is turned to True from False when Freeze Ends
        if (self.hutil.get_intvalue_from_configfile('LoggingOff', 0) == 1):
            self.logging_off = True
        if (self.enforced_local_flag_value != False and enforced_local == False and self.logging_off == True):
            # Freeze is starting but logging is globally off: nothing to mark.
            pass
        elif (self.enforced_local_flag_value != False and enforced_local == False):
            self.msg = self.msg + "================== Logs during Freeze Start ==============" + "\n"
        elif (self.enforced_local_flag_value == False and enforced_local == True):
            # Freeze ended: close the marker and flush the buffered messages.
            self.msg = self.msg + "================== Logs during Freeze End ==============" + "\n"
            self.commit_to_local()
        self.enforced_local_flag_value = enforced_local

    """Log *msg* to the console and/or the in-memory buffer, honouring the freeze flag."""
    def log(self, msg, local=False, level='Info'):
        if(self.enforced_local_flag_value == False and self.logging_off == True):
            return
        WriteLog = self.hutil.get_strvalue_from_configfile('WriteLog','True')
        if (WriteLog == None or WriteLog == 'True'):
            log_msg = ""
            if sys.version_info > (3,):
                # py3 helper both formats and (when not frozen) writes to console.
                log_msg = self.log_to_con_py3(msg, level)
            else:
                # NOTE(review): nesting reconstructed from mangled source — the
                # console write is taken to be py2-only; confirm against repo.
                log_msg = "{0} {1} {2} \n".format(str(datetime.datetime.utcnow()) , level , msg)
                if(self.enforced_local_flag_value != False):
                    self.log_to_con(log_msg)
            if(self.enforced_local_flag_value == False):
                # Freeze window: buffer instead of writing to the handler log.
                self.msg += log_msg
        else:
            self.hutil.log(str(msg),level)

    def log_to_con(self, msg):
        # Python 2 console writer: strip non-printable characters and write
        # ascii bytes to /dev/console; all failures are deliberately ignored.
        try:
            with open(self.con_path, "wb") as C :
                message = "".join(list(filter(lambda x : x in string.printable, msg)))
                C.write(message.encode('ascii','ignore'))
        except IOError as e:
            pass
        except Exception as e:
            pass

    def log_to_con_py3(self, msg, level='Info'):
        # Python 3 console writer: returns the formatted line so the caller
        # can buffer it; writes to the console only when not in a freeze window.
        log_msg = ""
        try:
            if type(msg) is not str:
                msg = str(msg, errors="backslashreplace")
            # NOTE: local 'time' shadows the imported time module within this method.
            time = datetime.datetime.utcnow().strftime(u'%Y/%m/%d %H:%M:%S.%f')
            log_msg = u"{0} {1} {2} \n".format(time , level , msg)
            log_msg= str(log_msg.encode('ascii', "backslashreplace"), encoding="ascii")
            if(self.enforced_local_flag_value != False):
                with open(self.con_path, "w") as C :
                    C.write(log_msg)
        except IOError:
            pass
        except Exception as e:
            log_msg = "###### Exception in log_to_con_py3"
        return log_msg

    def commit(self, logbloburi):
        #commit to local file system first, then commit to the network.
        try:
            self.hutil.log(self.msg)
            self.msg = ''
        except Exception as e:
            pass
        try:
            self.commit_to_blob(logbloburi)
        except Exception as e:
            self.hutil.log('commit to blob failed')

    def commit_to_local(self):
        # Flush the freeze-window buffer into the handler log.
        self.hutil.log(self.msg)
        self.msg = ''

    def commit_to_blob(self, logbloburi):
        # Upload the accumulated log (prefixed with distro / agent version
        # info, suffixed with the waagent + shell-script log tails) to the
        # logs blob, unless uploads are disabled via config.
        UploadStatusAndLog = self.hutil.get_strvalue_from_configfile('UploadStatusAndLog','True')
        if (UploadStatusAndLog == None or UploadStatusAndLog == 'True'):
            log_to_blob = ""
            blobWriter = BlobWriter(self.hutil)
            # append the wala log at the end.
            try:
                # distro information
                if(self.hutil is not None and self.hutil.patching is not None and self.hutil.patching.distro_info is not None):
                    distro_str = ""
                    if(len(self.hutil.patching.distro_info)>1):
                        distro_str = self.hutil.patching.distro_info[0] + " " + self.hutil.patching.distro_info[1]
                    else:
                        distro_str = self.hutil.patching.distro_info[0]
                    self.msg = "Distro Info:" + distro_str + "\n" + self.msg
                self.msg = "Guest Agent Version is :" + waagent.GuestAgentVersion + "\n" + self.msg
                log_to_blob = str(self.hutil.fetch_log_message()) + "Tail of shell script log:" + str(self.hutil.get_shell_script_log())
            except Exception as e:
                errMsg = 'Failed to get the waagent log with error: %s, stack trace: %s' % (str(e), traceback.format_exc())
                self.hutil.log(errMsg)
            blobWriter.WriteBlob(log_to_blob, logbloburi)

    def set_prev_log(self):
        self.prev_log = self.hutil.get_prev_log()




================================================
FILE: VMBackup/main/blobwriter.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import datetime
import traceback
# urlparse moved to urllib.parse in Python 3; support both.
try:
    import urlparse
except ImportError:
    import urllib.parse as urlparse
from common import CommonVariables
from HttpUtil import HttpUtil
from Utils import HandlerUtil


class BlobProperties():
    # Minimal value object for the two blob properties this module needs:
    # the blob type ('BlockBlob' / 'PageBlob') and its Content-Length.
    def __init__(self, blobType, contentLength):
        self.blobType = blobType
        self.contentLength = contentLength

    def __str__(self):
        return ' blobType: ' + str(self.blobType) + ' contentLength: ' + str(self.contentLength)


class BlobWriter(object):
    # Class-level cache shared by all instances: blobUri -> whether the blob
    # was empty the first time it was checked (see IsEmptyBlob, out of view here).
    blobEmptyDetails = {}
    """Writes a message to an Azure storage blob via SAS URI, dispatching on blob type."""
    def __init__(self, hutil):
        self.hutil = hutil
    """
    network call should have retry.
    """
    def WriteBlob(self, msg, blobUri):
        # Entry point: refuses to overwrite a non-empty blob, then dispatches
        # to the page-blob or block-blob writer based on the blob's properties.
        # All failures are logged, never raised to the caller.
        try:
            # get the blob type
            if(blobUri is not None):
                if (self.IsEmptyBlob(blobUri) == False):
                    raise Exception("Cannot perform write operation on a non empty blob")
                blobProperties = self.GetBlobProperties(blobUri)
                # Default to page-blob when properties could not be fetched.
                blobType = "pageblob"
                if(blobProperties is not None):
                    blobType = blobProperties.blobType
                if (str(blobType).lower() == "pageblob"):
                    # Clear Page-Blob Contents
                    self.ClearPageBlob(blobUri, blobProperties)
                    # Write to Page-Blob
                    self.WritePageBlob(msg, blobUri, blobProperties)
                else:
                    self.WriteBlockBlob(msg, blobUri)
            else:
                self.hutil.log("bloburi is None")
        except Exception as e:
            self.hutil.log("Failed to committing the log with error: %s, stack trace: %s" % (str(e), traceback.format_exc()))

    def WriteBlockBlob(self, msg, blobUri):
        # PUT the whole message as a block blob, with up to 3 attempts.
        # retry_times is set to 0 to break out of the loop on success.
        retry_times = 3
        while(retry_times > 0):
            try:
                # get the blob type
                if(blobUri is not None):
                    http_util = HttpUtil(self.hutil)
                    sasuri_obj = urlparse.urlparse(blobUri)
                    headers = {}
                    headers["x-ms-blob-type"] = 'BlockBlob'
                    self.hutil.log(str(headers))
                    result = http_util.Call(method = 'PUT',
                                            sasuri_obj = sasuri_obj,
                                            data = msg,
                                            headers = headers,
                                            fallback_to_curl = True)
                    if(result == CommonVariables.success):
                        self.hutil.log("blob written succesfully")
                        retry_times = 0
                    else:
                        self.hutil.log("blob failed to write")
                        HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.statusBlobUploadError, "true")
                else:
self.hutil.log("bloburi is None") retry_times = 0 HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.statusBlobUploadError, "true") except Exception as e: HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.statusBlobUploadError, "true") self.hutil.log("Failed to committing the log with error: %s, stack trace: %s" % (str(e), traceback.format_exc())) self.hutil.log("retry times is " + str(retry_times)) retry_times = retry_times - 1 def WritePageBlob(self, message, blobUri, blobProperties): if(blobUri is not None): retry_times = 3 while(retry_times > 0): msg = message try: PAGE_SIZE_BYTES = 512 PAGE_UPLOAD_LIMIT_BYTES = 4194304 # 4 MB STATUS_BLOB_LIMIT_BYTES = 10485760 # 10 MB http_util = HttpUtil(self.hutil) sasuri_obj = urlparse.urlparse(blobUri + '&comp=page') # Get Blob-properties to know content-length blobContentLength = int(blobProperties.contentLength) self.hutil.log("WritePageBlob: contentLength:"+str(blobContentLength)) maxMsgLen = STATUS_BLOB_LIMIT_BYTES if (blobContentLength > STATUS_BLOB_LIMIT_BYTES): maxMsgLen = blobContentLength msgLen = len(msg) self.hutil.log("WritePageBlob: msg length:"+str(msgLen)) if(len(msg) > maxMsgLen): msg = msg[msgLen-maxMsgLen:msgLen] msgLen = len(msg) self.hutil.log("WritePageBlob: msg length after aligning to maxMsgLen:"+str(msgLen)) if((msgLen % PAGE_SIZE_BYTES) != 0): # Add padding to message to make its legth multiple of 512 paddedLen = msgLen + (512 - (msgLen % PAGE_SIZE_BYTES)) msg = msg.ljust(paddedLen) msgLen = len(msg) self.hutil.log("WritePageBlob: msg length after aligning to page-size(512):"+str(msgLen)) if(blobContentLength < msgLen): # Try to resize blob to increase its size isSuccessful = self.try_resize_page_blob(blobUri, msgLen) if(isSuccessful == True): self.hutil.log("WritePageBlob: page-blob resized successfully new size(blobContentLength):"+str(msgLen)) blobContentLength = msgLen else: self.hutil.log("WritePageBlob: page-blob resize failed") if(msgLen > blobContentLength): 
msg = msg[msgLen-blobContentLength:msgLen] msgLen = len(msg) self.hutil.log("WritePageBlob: msg length after aligning to blobContentLength:"+str(msgLen)) # Write Pages result = CommonVariables.error bytes_sent = 0 while (bytes_sent < msgLen): bytes_remaining = msgLen - bytes_sent pageContent = None if(bytes_remaining > PAGE_UPLOAD_LIMIT_BYTES): # more than 4 MB pageContent = msg[bytes_sent:bytes_sent+PAGE_UPLOAD_LIMIT_BYTES] else: pageContent = msg[bytes_sent:msgLen] self.hutil.log("WritePageBlob: pageContentLen:"+str(len(pageContent))) result = self.put_page_update(pageContent, blobUri, bytes_sent) if(result == CommonVariables.success): self.hutil.log("WritePageBlob: page written succesfully") else: self.hutil.log("WritePageBlob: page failed to write") break bytes_sent = bytes_sent + len(pageContent) if(result == CommonVariables.success): self.hutil.log("WritePageBlob: page-blob written succesfully") retry_times = 0 else: self.hutil.log("WritePageBlob: page-blob failed to write") HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.statusBlobUploadError, "true") except Exception as e: HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.statusBlobUploadError, "true") self.hutil.log("WritePageBlob: Failed to write to page-blob with error: %s, stack trace: %s" % (str(e), traceback.format_exc())) self.hutil.log("WritePageBlob: retry times is " + str(retry_times)) retry_times = retry_times - 1 else: self.hutil.log("WritePageBlob: bloburi is None") def ClearPageBlob(self, blobUri, blobProperties): if(blobUri is not None): retry_times = 3 while(retry_times > 0): try: http_util = HttpUtil(self.hutil) sasuri_obj = urlparse.urlparse(blobUri + '&comp=page') # Get Blob-properties to know content-length contentLength = int(blobProperties.contentLength) # Clear Pages if(contentLength > 0): result = self.put_page_clear(blobUri, 0, contentLength) if(result == CommonVariables.success): self.hutil.log("ClearPageBlob: page-blob cleared succesfully") 
retry_times = 0 else: self.hutil.log("ClearPageBlob: page-blob failed to clear") else: self.hutil.log("ClearPageBlob: page-blob contentLength is 0") retry_times = 0 except Exception as e: self.hutil.log("ClearPageBlob: Failed to clear to page-blob with error: %s, stack trace: %s" % (str(e), traceback.format_exc())) self.hutil.log("ClearPageBlob: retry times is " + str(retry_times)) retry_times = retry_times - 1 else: self.hutil.log("ClearPageBlob: bloburi is None") def GetBlobType(self, blobUri): blobType = "BlockBlob" if(blobUri is not None): # Get Blob Properties blobProperties = self.GetBlobProperties(blobUri) if(blobProperties is not None): blobType = blobProperties.blobType self.hutil.log("GetBlobType: Blob-Type :"+str(blobType)) return blobType def GetBlobProperties(self, blobUri): blobProperties = None if(blobUri is not None): retry_times = 3 while(retry_times > 0): try: http_util = HttpUtil(self.hutil) sasuri_obj = urlparse.urlparse(blobUri) headers = {} result, httpResp, errMsg = http_util.HttpCallGetResponse('GET', sasuri_obj, None, headers = headers) self.hutil.log("GetBlobProperties: HttpCallGetResponse : result :" + str(result) + ", errMsg :" + str(errMsg)) blobProperties = self.httpresponse_get_blob_properties(httpResp) self.hutil.log("GetBlobProperties: blobProperties :" + str(blobProperties)) retry_times = 0 except Exception as e: self.hutil.log("GetBlobProperties: Failed to get blob properties with error: %s, stack trace: %s" % (str(e), traceback.format_exc())) self.hutil.log("GetBlobProperties: retry times is " + str(retry_times)) retry_times = retry_times - 1 return blobProperties def put_page_clear(self, blobUri, pageBlobIndex, clearLength): http_util = HttpUtil(self.hutil) sasuri_obj = urlparse.urlparse(blobUri + '&comp=page') headers = {} headers["x-ms-page-write"] = 'clear' headers["x-ms-range"] = 'bytes={0}-{1}'.format(pageBlobIndex, pageBlobIndex + clearLength - 1) headers["Content-Length"] = 0 result = http_util.Call(method = 'PUT', 
sasuri_obj = sasuri_obj, data = None, headers = headers, fallback_to_curl = True) return result def put_page_update(self, pageContent, blobUri, pageBlobIndex): http_util = HttpUtil(self.hutil) sasuri_obj = urlparse.urlparse(blobUri + '&comp=page') headers = {} headers["x-ms-page-write"] = 'update' headers["x-ms-range"] = 'bytes={0}-{1}'.format(pageBlobIndex, pageBlobIndex + len(pageContent) - 1) headers["Content-Length"] = len(str(pageContent)) result = http_util.Call(method = 'PUT', sasuri_obj = sasuri_obj, data = pageContent, headers = headers, fallback_to_curl = True) return result def try_resize_page_blob(self, blobUri, size): isSuccessful = False if (size % 512 == 0): try: http_util = HttpUtil(self.hutil) sasuri_obj = urlparse.urlparse(blobUri + '&comp=properties') headers = {} headers["x-ms-blob-content-length"] = size headers["Content-Length"] = size result = http_util.Call(method = 'PUT', sasuri_obj = sasuri_obj, data = None, headers = headers, fallback_to_curl = True) if(result == CommonVariables.success): isSuccessful = True else: self.hutil.log("try_resize_page_blob: page-blob resize failed, size :"+str(size)+", result :"+str(result)) except Exception as e: self.hutil.log("try_resize_page_blob: failed to resize page-blob with error: %s, stack trace: %s" % (str(e), traceback.format_exc())) else: self.hutil.log("try_resize_page_blob: invalid size : " + str(size)) return isSuccessful def httpresponse_get_blob_properties(self, httpResp): blobProperties = None if(httpResp != None): self.hutil.log("httpresponse_get_blob_properties: Blob-properties response status:"+str(httpResp.status)) if(httpResp.status == 200): resp_headers = httpResp.getheaders() blobType = httpResp.getheader('x-ms-blob-type') contentLength = httpResp.getheader('Content-Length') blobProperties = BlobProperties(blobType, contentLength) return blobProperties def VerifyIfBlobIsEmpty(self, blobUri): try: if(blobUri is not None): blobProperties = self.GetBlobProperties(blobUri) if 
(str(blobProperties.blobType).lower() == "pageblob"): self.hutil.log("VerifyIfBlobIsEmpty: Skipping for page blob") return True self.hutil.log("VerifyIfBlobIsEmpty: Content Length of blob: " + str(blobProperties.contentLength)) if(int(blobProperties.contentLength) == 0): return True else: return False else: self.hutil.log("VerifyIfBlobIsEmpty: bloburi is None") except Exception as e: self.hutil.log("VerifyIfBlobIsEmpty: Failed to get the blob content length with error: %s, stack trace: %s" % (str(e), traceback.format_exc())) return True def IsEmptyBlob(self, blobUri): if (bool(BlobWriter.blobEmptyDetails) == True): if (blobUri in BlobWriter.blobEmptyDetails.keys()): return BlobWriter.blobEmptyDetails[blobUri] isEmptyBlob = self.VerifyIfBlobIsEmpty(blobUri) BlobWriter.blobEmptyDetails[blobUri] = isEmptyBlob return isEmptyBlob ================================================ FILE: VMBackup/main/common.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class CommonVariables:
    """Shared constants for the VMBackup extension: packaging identity,
    extension-input field names, CRP dynamic-setting keys, snapshot modes,
    status strings, error codes and consistency types."""

    # ---------------- packaging / extension identity ----------------
    azure_path = 'main/azure'
    utils_path_name = 'Utils'
    snapshot_service_path_name = "IaaSExtensionSnapshotService"
    extension_name = 'MyBackupTestLinuxInt'
    extension_version = "1.0.9120.0"
    extension_zip_version = "1"
    extension_type = extension_name
    extension_media_link = 'https://sopattna.blob.core.windows.net/extensions/' + extension_name + '-' + str(extension_version) + '.zip'
    extension_label = 'Windows Azure VMBackup Extension for Linux IaaS'
    extension_description = extension_label

    # ---------------- extension-input field names ----------------
    object_str = 'objectStr'
    logs_blob_uri = 'logsBlobUri'
    status_blob_uri = 'statusBlobUri'
    commandStartTimeUTCTicks = "commandStartTimeUTCTicks"
    task_id = 'taskId'
    command_to_execute = 'commandToExecute'
    iaas_vmbackup_command = 'snapshot'
    iaas_install_command = 'install'
    locale = 'locale'
    vmType = 'vmType'
    VmTypeV1 = 'microsoft.classiccompute/virtualmachines'
    VmTypeV2 = 'microsoft.compute/virtualmachines'
    customSettings = 'customSettings'
    statusBlobUploadError = 'statusBlobUploadError'
    TempStatusFileName = 'tempStatusFile.status'
    onlyLocalFilesystems = 'onlyLocalFilesystems'

    # -------------------- Dynamic Settings from CRP --------------------
    isSnapshotTtlEnabled = 'isSnapshotTtlEnabled'
    useMccfForLad = 'useMccfForLad'
    useMccfToFetchDsasForAllDisks = 'useMccfToFetchDsasForAllDisks'
    enableSnapshotExtensionPolling = "EnableVMSnapshotExtensionPolling"
    isVmmdBlobIncluded = 'isVmmdBlobIncluded'
    key = 'Key'
    value = 'Value'
    snapshotTtlHeader = 'x-ms-snapshot-ttl-expiry-hours'
    snapshotTaskToken = 'snapshotTaskToken'
    snapshotCreator = 'snapshotCreator'
    hostStatusCodePreSnapshot = 'hostStatusCodePreSnapshot'
    hostStatusCodeDoSnapshot = 'hostStatusCodeDoSnapshot'
    guestExtension = 'guestExtension'
    backupHostService = 'backupHostService'
    includedDisks = 'includedDisks'
    isAnyDiskExcluded = 'isAnyDiskExcluded'
    dataDiskLunList = 'dataDiskLunList'
    isOSDiskIncluded = 'isOSDiskIncluded'
    isVmgsBlobIncluded = 'isVmgsBlobIncluded'
    isVMADEEnabled = 'isVMADEEnabled'
    isOsDiskADEEncrypted = 'isOsDiskADEEncrypted'
    areDataDisksADEEncrypted = 'areDataDisksADEEncrypted'
    diskEncryptionSettings = 'DiskEncryptionSettings'
    isAnyWADiskIncluded = 'isAnyWADiskIncluded'
    isAnyDirectDriveDiskIncluded = 'isAnyDirectDriveDiskIncluded'
    diskEncryptionKey = "x-ms-meta-DiskEncryptionSettings"
    instantAccessDurationMinutes = 'instantAccessDurationMinutes'

    # ---------------- snapshot methods / modes ----------------
    onlyGuest = 'onlyGuest'
    firstGuestThenHost = 'firstGuestThenHost'
    firstHostThenGuest = 'firstHostThenGuest'
    onlyHost = 'onlyHost'
    SnapshotMethod = 'SnapshotMethod'
    IsAnySnapshotFailed = 'IsAnySnapshotFailed'
    SnapshotRateExceededFailureCount = 'SnapshotRateExceededFailureCount'

    # ---------------- status strings ----------------
    status_transitioning = 'transitioning'
    status_warning = 'warning'
    status_success = 'success'
    status_error = 'error'
    unable_to_open_err_string = 'file open failed for some mount'

    # ---------------- error code definitions ----------------
    success_appconsistent = 0
    success = 1
    error = 2
    SuccessAlreadyProcessedInput = 3
    ExtensionTempTerminalState = 4
    error_parameter = 11
    error_12 = 12
    error_wrong_time = 13
    error_same_taskid = 14
    error_http_failure = 15
    FailedHandlerGuestAgentCertificateNotFound = 16
    #error_upload_status_blob = 16
    FailedRetryableSnapshotFailedNoNetwork = 76
    FailedSnapshotLimitReached = 85
    FailedRetryableSnapshotRateExceeded = 173
    FailedFsFreezeFailed = 121
    FailedFsFreezeTimeout = 122
    FailedUnableToOpenMount = 123
    FailedSafeFreezeBinaryNotFound = 124
    FailedHostSnapshotRemoteServerError = 556

    # ---------------- Pre-Post plugin error code definitions ----------------
    PrePost_PluginStatus_Success = 0
    PrePost_ScriptStatus_Success = 0
    PrePost_ScriptStatus_Error = 1
    PrePost_ScriptStatus_Warning = 2
    FailedInvalidDataDiskLunList = 17
    FailedPrepostPreScriptFailed = 300
    FailedPrepostPostScriptFailed = 301
    FailedPrepostPreScriptNotFound = 302
    FailedPrepostPostScriptNotFound = 303
    FailedPrepostPluginhostConfigParsing = 304
    FailedPrepostPluginConfigParsing = 305
    FailedPrepostPreScriptPermissionError = 306
    FailedPrepostPostScriptPermissionError = 307
    FailedPrepostPreScriptTimeout = 308
    FailedPrepostPostScriptTimeout = 309
    FailedPrepostPluginhostPreTimeout = 310
    FailedPrepostPluginhostPostTimeout = 311
    FailedPrepostCheckSumMismatch = 312
    FailedPrepostPluginhostConfigNotFound = 313
    FailedPrepostPluginhostConfigPermissionError = 314
    FailedPrepostPluginhostConfigOwnershipError = 315
    FailedPrepostPluginConfigNotFound = 316
    FailedPrepostPluginConfigPermissionError = 317
    FailedPrepostPluginConfigOwnershipError = 318
    FailedGuestAgentInvokedCommandTooLate = 402

    # ---------------- workload pre/post error codes ----------------
    FailedWorkloadPreError = 500
    FailedWorkloadConfParsingError = 501
    FailedWorkloadInvalidRole = 502
    FailedWorkloadInvalidWorkloadName = 503
    FailedWorkloadPostError = 504
    FailedWorkloadAuthorizationMissing = 505
    FailedWorkloadConnectionError = 506
    FailedWorkloadIPCDirectoryMissing = 507
    FailedWorkloadDatabaseStatusChanged = 508
    FailedWorkloadQuiescingError = 509
    FailedWorkloadQuiescingTimeout = 510
    FailedWorkloadDatabaseInNoArchiveLog = 511
    FailedWorkloadLogModeChanged = 512

    # ---------------- consistency types ----------------
    consistency_none = 'none'
    consistency_crashConsistent = 'crashConsistent'
    consistency_fileSystemConsistent = 'fileSystemConsistent'
    consistency_applicationConsistent = 'applicationConsistent'

    @staticmethod
    def isTerminalStatus(status):
        """Return True when *status* is one of the two terminal states
        ('success' or 'error'); 'transitioning'/'warning' are not terminal."""
        return status in (CommonVariables.status_success,
                          CommonVariables.status_error)


class DeviceItem(object):
    """One block device row as reported by lsblk
    (NAME,TYPE,FSTYPE,MOUNTPOINT,LABEL,UUID,MODEL)."""

    def __init__(self):
        self.name = None
        self.type = None
        self.file_system = None
        self.mount_point = None
        self.label = None
        self.uuid = None
        self.model = None
        self.size = None

    def __str__(self):
        # uuid and size are not part of the string form, matching the
        # original representation (kept stable for anyone parsing logs).
        fields = (("name", self.name),
                  ("type", self.type),
                  ("fstype", self.file_system),
                  ("mountpoint", self.mount_point),
                  ("label", self.label),
                  ("model", self.model))
        return " ".join("{0}:{1}".format(tag, str(val)) for tag, val in fields)
# "License"); you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.4+ and Openssl 1.0+
#
# Output of running the script on Windows:
# The scheduled events endpoint IP address will be available as part of the system environment variable.
#
# Output of running the script on Linux:
# The scheduled events endpoint IP address will be available as part of the environment variable for all users.

import os
import socket
import struct
import array
import time
from Utils import dhcpUtils
import sys
if sys.platform == 'win32':
    import _winreg as wreg
from uuid import getnode as get_mac

""" Defines dhcp exception """


class BaseError(Exception):
    """
    Base error class: prefixes the message with an error number and
    optionally appends an inner (causing) error.
    """
    def __init__(self, errno, msg, inner=None):
        msg = u"({0}){1}".format(errno, msg)
        if inner is not None:
            msg = u"{0} \n inner error: {1}".format(msg, inner)
        super(BaseError, self).__init__(msg)


class DhcpError(BaseError):
    """
    Failed to handle dhcp response
    """
    def __init__(self, msg=None, inner=None):
        super(DhcpError, self).__init__('000006', msg, inner)


class DhcpHandler(object):
    """Sends a raw DHCPDISCOVER and parses the response to discover the
    Azure wireserver / scheduled-events endpoint (DHCP option 245), the
    default gateway (option 3) and static routes (option 249)."""

    def __init__(self, logger):
        # OS helpers (mac address, firewall/broadcast tweaks) come from dhcpUtils.
        self.osutil = dhcpUtils.DefaultOSUtil(logger)
        self.endpoint = None            # option 245: wireserver IP (string) once run() succeeds
        self.gateway = None             # option 3: default gateway
        self.routes = None              # option 249: list of (net, mask, gateway) tuples
        self._request_broadcast = False
        self.skip_cache = False
        self.logger = logger

    def getHostEndoint(self):
        # (sic: "Endoint") Run the DHCP exchange and return the discovered endpoint IP.
        self.run()
        return self.endpoint

    def run(self):
        """
        Send dhcp request
        """
        self.send_dhcp_req()

    def _send_dhcp_req(self, request):
        # Retry with increasing back-off; returns the raw response bytes or
        # None when every attempt failed.
        __waiting_duration__ = [0, 10, 30, 60, 60]
        for duration in __waiting_duration__:
            try:
                self.osutil.allow_dhcp_broadcast()
                response = self.socket_send(request)
                # NOTE(review): validate_dhcp_resp returns False for a
                # too-short response without raising, and that return value
                # is ignored here — only the DhcpError raises trigger a retry.
                self.validate_dhcp_resp(request, response)
                return response
            except DhcpError as e:
                self.logger.log("Failed to send DHCP request: " + str(e))
                time.sleep(duration)
        return None

    def send_dhcp_req(self):
        """
        Build dhcp request with mac addr
        Configure route to allow dhcp traffic
        Stop dhcp service if necessary
        """
        self.logger.log("Sending dhcp request")
        mac_addr = self.osutil.get_mac_in_bytes()
        req = self.build_dhcp_request(mac_addr, self._request_broadcast)
        resp = self._send_dhcp_req(req)
        if resp is None:
            raise DhcpError("Failed to receive dhcp response.")
        self.endpoint, self.gateway, self.routes = self.parse_dhcp_resp(resp)
        self.logger.log('Scheduled Events endpoint IP address:' + self.endpoint)

    def validate_dhcp_resp(self, request, response):
        # Returns False when the response is too short; raises DhcpError when
        # the cookie / transaction id / MAC address do not echo the request.
        bytes_recv = len(response)
        if bytes_recv < 0xF6:
            self.logger.log("HandleDhcpResponse: Too few bytes received: " + str(bytes_recv))
            return False
        self.logger.log("BytesReceived:{0}" + str(hex(bytes_recv)))
        #self.logger.log("DHCP response:{0}" + dhcpUtils.hex_dump(response, bytes_recv))
        # check transactionId, cookie, MAC address cookie should never mismatch
        # transactionId and MAC address may mismatch if we see a response
        # meant from another machine
        if not dhcpUtils.compare_bytes(request, response, 0xEC, 4):
            self.logger.log("Cookie not match:\nsend={0},\nreceive={1}".format(dhcpUtils.hex_dump3(request, 0xEC, 4), dhcpUtils.hex_dump3(response, 0xEC, 4)))
            raise DhcpError("Cookie in dhcp respones doesn't match the request")
        if not dhcpUtils.compare_bytes(request, response, 4, 4):
            self.logger.log("TransactionID not match:\nsend={0},\nreceive={1}".format(dhcpUtils.hex_dump3(request, 4, 4), dhcpUtils.hex_dump3(response, 4, 4)))
            raise DhcpError("TransactionID in dhcp respones " "doesn't match the request")
        if not dhcpUtils.compare_bytes(request, response, 0x1C, 6):
            self.logger.log("Mac Address not match:\nsend={0},\nreceive={1}".format(dhcpUtils.hex_dump3(request, 0x1C, 6), dhcpUtils.hex_dump3(response, 0x1C, 6)))
            raise DhcpError("Mac Addr in dhcp respones " "doesn't match the request")

    def parse_route(self, response, option, i, length, bytes_recv):
        # Decode option 249 (classless static routes) starting at offset i.
        # Each entry: 1 byte prefix length, ceil(bits/8) bytes of network,
        # then 4 bytes of gateway.
        # http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx
        self.logger.log("Routes at offset: {0} with length:{1}".format(hex(i), hex(length)))
        routes = []
        if length < 5:
            self.logger.log("Data too small for option:{0}" + str(option))
        j = i + 2
        while j < (i + length + 2):
            mask_len_bits = dhcpUtils.str_to_ord(response[j])
            mask_len_bytes = (((mask_len_bits + 7) & ~7) >> 3)
            mask = 0xFFFFFFFF & (0xFFFFFFFF << (32 - mask_len_bits))
            j += 1
            net = dhcpUtils.unpack_big_endian(response, j, mask_len_bytes)
            # left-align the truncated network prefix into a full 32-bit value
            net <<= (32 - mask_len_bytes * 8)
            net &= mask
            j += mask_len_bytes
            gateway = dhcpUtils.unpack_big_endian(response, j, 4)
            j += 4
            routes.append((net, mask, gateway))
        if j != (i + length + 2):
            self.logger.log("Unable to parse routes")
        return routes

    def parse_ip_addr(self, response, option, i, length, bytes_recv):
        # Decode a 4-byte IP option value at offset i; returns dotted-quad
        # string or None on malformed/truncated data.
        if i + 5 < bytes_recv:
            if length != 4:
                self.logger.log("Endpoint or Default Gateway not 4 bytes")
                return None
            addr = dhcpUtils.unpack_big_endian(response, i + 2, 4)
            ip_addr = dhcpUtils.int_to_ip4_addr(addr)
            return ip_addr
        else:
            self.logger.log("Data too small for option: " + str(option))
            return None

    def parse_dhcp_resp(self, response):
        """
        Parse DHCP response:
        Returns endpoint server or None on error.
        """
        self.logger.log('Parsing Dhcp response')
        bytes_recv = len(response)
        endpoint = None
        gateway = None
        routes = None
        # Walk all the returned options, parsing out what we need, ignoring the
        # others. We need the custom option 245 to find the endpoint we talk to
        # options 3 for default gateway and 249 for routes; 255 is end.
        i = 0xF0  # offset to first option
        while i < bytes_recv:
            option = dhcpUtils.str_to_ord(response[i])
            length = 0
            if (i + 1) < bytes_recv:
                length = dhcpUtils.str_to_ord(response[i + 1])
            self.logger.log("DHCP option {0} at offset:{1} with length:{2}".format(hex(option), hex(i), hex(length)))
            if option == 255:
                self.logger.log("DHCP packet ended at offset:{0}".format(hex(i)))
                break
            elif option == 249:
                routes = self.parse_route(response, option, i, length, bytes_recv)
            elif option == 3:
                gateway = self.parse_ip_addr(response, option, i, length, bytes_recv)
                self.logger.log("Default gateway:{0}, at {1}".format(gateway, hex(i)))
            elif option == 245:
                endpoint = self.parse_ip_addr(response, option, i, length, bytes_recv)
                self.logger.log("Azure scheduled events endpoint IP:{0}, at {1}".format(endpoint, hex(i)))
            else:
                self.logger.log("Skipping DHCP option:{0} at {1} with length {2}".format(hex(option), hex(i), hex(length)))
            i += length + 2
        return endpoint, gateway, routes

    def socket_send(self, request):
        # Broadcast the request from UDP port 68 to port 67 and wait (10 s)
        # for a single response datagram; raises DhcpError on socket errors.
        sock = None
        try:
            sock = socket.socket(socket.AF_INET,
                                 socket.SOCK_DGRAM,
                                 socket.IPPROTO_UDP)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(("0.0.0.0", 68))
            # NOTE(review): destination host is "" (empty) rather than
            # "<broadcast>" — presumably relies on SO_BROADCAST; confirm on
            # the supported distros.
            sock.sendto(request, ("", 67))
            sock.settimeout(10)
            self.logger.log("Send DHCP request: Setting socket.timeout=10, entering recv")
            response = sock.recv(1024)
            return response
        except IOError as e:
            raise DhcpError("{0}".format(e))
        finally:
            if sock is not None:
                sock.close()

    def build_dhcp_request(self, mac_addr, request_broadcast):
        """
        Build DHCP request string.
        """
        #
        # typedef struct _DHCP {
        #     UINT8 Opcode;                    /* op: BOOTREQUEST or BOOTREPLY */
        #     UINT8 HardwareAddressType;       /* htype: ethernet */
        #     UINT8 HardwareAddressLength;     /* hlen: 6 (48 bit mac address) */
        #     UINT8 Hops;                      /* hops: 0 */
        #     UINT8 TransactionID[4];          /* xid: random */
        #     UINT8 Seconds[2];                /* secs: 0 */
        #     UINT8 Flags[2];                  /* flags: 0 or 0x8000 for broadcast */
        #     UINT8 ClientIpAddress[4];        /* ciaddr: 0 */
        #     UINT8 YourIpAddress[4];          /* yiaddr: 0 */
        #     UINT8 ServerIpAddress[4];        /* siaddr: 0 */
        #     UINT8 RelayAgentIpAddress[4];    /* giaddr: 0 */
        #     UINT8 ClientHardwareAddress[16]; /* chaddr: 6 byte eth MAC address */
        #     UINT8 ServerName[64];            /* sname: 0 */
        #     UINT8 BootFileName[128];         /* file: 0 */
        #     UINT8 MagicCookie[4];            /* 99 130 83 99 / 0x63 0x82 0x53 0x63 */
        #     /* options -- hard code ours */
        #     UINT8 MessageTypeCode;           /* 53 */
        #     UINT8 MessageTypeLength;         /* 1 */
        #     UINT8 MessageType;               /* 1 for DISCOVER */
        #     UINT8 End;                       /* 255 */
        # } DHCP;
        #
        # tuple of 244 zeros
        # (struct.pack_into would be good here, but requires Python 2.5)
        request = [0] * 244
        trans_id = self.gen_trans_id()
        # Opcode = 1
        # HardwareAddressType = 1 (ethernet/MAC)
        # HardwareAddressLength = 6 (ethernet/MAC/48 bits)
        for a in range(0, 3):
            request[a] = [1, 1, 6][a]
        # fill in transaction id (random number to ensure response matches request)
        for a in range(0, 4):
            request[4 + a] = dhcpUtils.str_to_ord(trans_id[a])
        self.logger.log("BuildDhcpRequest: transactionId:{0},{1:04x}".format(dhcpUtils.hex_dump2(trans_id), dhcpUtils.unpack_big_endian(request, 4, 4)))
        if request_broadcast:
            # set broadcast flag to true to request the dhcp sever
            # to respond to a boradcast address,
            # this is useful when user dhclient fails.
            request[0x0A] = 0x80;
        # fill in ClientHardwareAddress
        for a in range(0, 6):
            request[0x1C + a] = dhcpUtils.str_to_ord(mac_addr[a])
        # DHCP Magic Cookie: 99, 130, 83, 99
        # MessageTypeCode = 53 DHCP Message Type
        # MessageTypeLength = 1
        # MessageType = DHCPDISCOVER
        # End = 255 DHCP_END
        for a in range(0, 8):
            request[0xEC + a] = [99, 130, 83, 99, 53, 1, 1, 255][a]
        return array.array("B", request)

    def gen_trans_id(self):
        # 4 random bytes used as the DHCP transaction id.
        return os.urandom(4)
import os try: import urlparse as urlparser except ImportError: import urllib.parse as urlparser import traceback import datetime try: import ConfigParser as ConfigParsers except ImportError: import configparser as ConfigParsers import multiprocessing as mp import time import json from common import CommonVariables from HttpUtil import HttpUtil from Utils import Status from Utils import HandlerUtil from fsfreezer import FsFreezer from guestsnapshotter import GuestSnapshotter from hostsnapshotter import HostSnapshotter from Utils import HostSnapshotObjects import ExtensionErrorCodeHelper # need to be implemented in next release #from dhcpHandler import DhcpHandler class FreezeSnapshotter(object): """description of class""" def __init__(self, logger, hutil , freezer, g_fsfreeze_on, para_parser, takeCrashConsistentSnapshot): self.logger = logger self.configfile = '/etc/azure/vmbackup.conf' self.hutil = hutil self.freezer = freezer self.g_fsfreeze_on = g_fsfreeze_on self.para_parser = para_parser if(para_parser.snapshotTaskToken == None): para_parser.snapshotTaskToken = '' #making snapshot string empty when snapshotTaskToken is null self.logger.log('snapshotTaskToken : ' + str(para_parser.snapshotTaskToken)) self.takeSnapshotFrom = CommonVariables.firstHostThenGuest self.isManaged = False self.taskId = self.para_parser.taskId self.hostIp = '168.63.129.16' self.additional_headers = [] self.extensionErrorCode = ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.success self.takeCrashConsistentSnapshot = takeCrashConsistentSnapshot self.logger.log('FreezeSnapshotter : takeCrashConsistentSnapshot = ' + str(self.takeCrashConsistentSnapshot)) #implement in next release ''' # fetching wireserver IP from DHCP self.dhcpHandlerObj = None try: self.dhcpHandlerObj = DhcpHandler(self.logger) self.hostIp = self.dhcpHandlerObj.getHostEndoint() except Exception as e: errorMsg = "Failed to get hostIp from DHCP with error: %s, stack trace: %s" % (str(e), traceback.format_exc()) 
self.logger.log(errorMsg, True, 'Error') self.hostIp = '168.63.129.16' ''' self.logger.log( "hostIp : " + self.hostIp) try: if(para_parser.customSettings != None and para_parser.customSettings != ''): self.logger.log('customSettings : ' + str(para_parser.customSettings)) customSettings = json.loads(para_parser.customSettings) snapshotMethodConfigValue = self.hutil.get_strvalue_from_configfile(CommonVariables.SnapshotMethod,customSettings['takeSnapshotFrom']) self.logger.log('snapshotMethodConfigValue : ' + str(snapshotMethodConfigValue)) if snapshotMethodConfigValue != None and snapshotMethodConfigValue != '': self.takeSnapshotFrom = snapshotMethodConfigValue else: self.takeSnapshotFrom = customSettings['takeSnapshotFrom'] self.isManaged = customSettings['isManagedVm'] if( "backupTaskId" in customSettings.keys()): self.taskId = customSettings["backupTaskId"] waDiskLunList= [] if "waDiskLunList" in customSettings.keys() and customSettings['waDiskLunList'] != None : waDiskLunList = customSettings['waDiskLunList'] self.logger.log('WA Disk Lun List ' + str(waDiskLunList)) if waDiskLunList!=None and waDiskLunList.count != 0 and para_parser.includeLunList!=None and para_parser.includeLunList.count!=0 : for crpLunNo in para_parser.includeLunList : if crpLunNo in waDiskLunList : self.logger.log('WA disk is present on the VM. 
Setting the snapshot mode to onlyHost.') self.takeSnapshotFrom = CommonVariables.onlyHost break else: self.logger.log('CustomSettings is null in extension input.') snapshotMethodConfigValue = self.hutil.get_strvalue_from_configfile(CommonVariables.SnapshotMethod,CommonVariables.firstHostThenGuest) self.logger.log('snapshotMethodConfigValue : ' + str(snapshotMethodConfigValue)) if snapshotMethodConfigValue != None and snapshotMethodConfigValue != '': self.takeSnapshotFrom = snapshotMethodConfigValue except Exception as e: errMsg = 'Failed to serialize customSettings with error: %s, stack trace: %s' % (str(e), traceback.format_exc()) self.logger.log(errMsg, True, 'Error') self.isManaged = True try: if(para_parser.includedDisks != None and CommonVariables.isAnyWADiskIncluded in para_parser.includedDisks.keys()): if (para_parser.includedDisks[CommonVariables.isAnyWADiskIncluded] == True): self.logger.log('WA disk is included. Setting the snapshot mode to onlyHost.') self.takeSnapshotFrom = CommonVariables.onlyHost if(para_parser.includedDisks != None and CommonVariables.isVmgsBlobIncluded in para_parser.includedDisks.keys()): if (para_parser.includedDisks[CommonVariables.isVmgsBlobIncluded] == True): self.logger.log('Vmgs Blob is included. Setting the snapshot mode to onlyHost.') self.takeSnapshotFrom = CommonVariables.onlyHost if(para_parser.includedDisks != None and CommonVariables.isAnyDirectDriveDiskIncluded in para_parser.includedDisks.keys()): if (para_parser.includedDisks[CommonVariables.isAnyDirectDriveDiskIncluded] == True): self.logger.log('DirectDrive Disk is included. 
Setting the snapshot mode to onlyHost.') self.takeSnapshotFrom = CommonVariables.onlyHost if(para_parser.includedDisks != None and CommonVariables.isAnyDiskExcluded in para_parser.includedDisks): # IsAnyDiskExcluded is true, but the included LUN list is empty in the extensions input if (para_parser.includedDisks[CommonVariables.isAnyDiskExcluded] == True and (para_parser.includeLunList == None or para_parser.includeLunList.count == 0)): # When the direct drive disk is part of the disks. so, failing the extension as snapshot can't be taken via Guest if( CommonVariables.isAnyDirectDriveDiskIncluded in para_parser.includedDisks and para_parser.includedDisks[CommonVariables.isAnyDirectDriveDiskIncluded] == True): errMsg = 'DirectDrive disk is included, so the host must create the snapshot. IsAnyDiskExcluded is true, but, the included LUN list is empty in the extension input, '\ 'which is not allowed for host DoSnapshot. Thus, failing the extension run.' self.logger.log(errMsg, True, 'Error') self.hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedInvalidDataDiskLunList) # When the VmgsBlob is part of the disks. so, failing the extension as snapshot can't be taken via Guest elif( CommonVariables.isVmgsBlobIncluded in para_parser.includedDisks and para_parser.includedDisks[CommonVariables.isVmgsBlobIncluded] == True): errMsg = 'VmgsBlob is included, so the host must create the snapshot. IsAnyDiskExcluded is true, but, the included LUN list is empty in the extension input, '\ 'which is not allowed for host DoSnapshot. Thus, failing the extension run.' self.logger.log(errMsg, True, 'Error') self.hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedInvalidDataDiskLunList) # When the WADisk is part of the disks. 
so, failing the extension as snapshot can't be taken via Guest elif( CommonVariables.isAnyWADiskIncluded in para_parser.includedDisks and para_parser.includedDisks[CommonVariables.isAnyWADiskIncluded] == True): errMsg = 'WADisk is included, so the host must create the snapshot. IsAnyDiskExcluded is true, but, the included LUN list is empty in the extension input, '\ 'which is not allowed for host DoSnapshot. Thus, failing the extension run.' self.logger.log(errMsg, True, 'Error') self.hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedInvalidDataDiskLunList) else: self.logger.log('Some disks are excluded from backup and LUN list is not present. Setting the snapshot mode to onlyGuest.') self.takeSnapshotFrom = CommonVariables.onlyGuest #Check if snapshot uri has special characters if self.hutil.UriHasSpecialCharacters(self.para_parser.blobs): self.logger.log('Some disk blob Uris have special characters.') except Exception as e: errMsg = 'Failed to process flags in includedDisks with error: %s, stack trace: %s' % (str(e), traceback.format_exc()) self.logger.log(errMsg, True, 'Error') self.logger.log('[FreezeSnapshotter] isManaged flag : ' + str(self.isManaged)) def doFreezeSnapshot(self): run_result = CommonVariables.success run_status = 'success' all_failed = False unable_to_sleep = False """ Do Not remove below HttpUtil object creation. 
This is to ensure HttpUtil singleton object is created before freeze.""" http_util = HttpUtil(self.logger) if(self.takeSnapshotFrom == CommonVariables.onlyGuest): run_result, run_status, blob_snapshot_info_array, all_failed, all_snapshots_failed, unable_to_sleep, is_inconsistent = self.takeSnapshotFromGuest() elif(self.takeSnapshotFrom == CommonVariables.firstGuestThenHost): run_result, run_status, blob_snapshot_info_array, all_failed, unable_to_sleep, is_inconsistent = self.takeSnapshotFromFirstGuestThenHost() elif(self.takeSnapshotFrom == CommonVariables.firstHostThenGuest): run_result, run_status, blob_snapshot_info_array, all_failed, unable_to_sleep, is_inconsistent = self.takeSnapshotFromFirstHostThenGuest() elif(self.takeSnapshotFrom == CommonVariables.onlyHost): run_result, run_status, blob_snapshot_info_array, all_failed, unable_to_sleep, is_inconsistent = self.takeSnapshotFromOnlyHost() else : self.logger.log('Snapshot method did not match any listed type, taking firstHostThenGuest as default') run_result, run_status, blob_snapshot_info_array, all_failed, unable_to_sleep, is_inconsistent = self.takeSnapshotFromFirstHostThenGuest() self.logger.log('doFreezeSnapshot : run_result - {0} run_status - {1} all_failed - {2} unable_to_sleep - {3} is_inconsistent - {4} values post snapshot'.format(str(run_result), str(run_status), str(all_failed), str(unable_to_sleep), str(is_inconsistent))) if (run_result == CommonVariables.success): run_result, run_status = self.updateErrorCode(blob_snapshot_info_array, all_failed, unable_to_sleep, is_inconsistent) snapshot_info_array = self.update_snapshotinfoarray(blob_snapshot_info_array) if not (run_result == CommonVariables.success): self.hutil.SetExtErrorCode(self.extensionErrorCode) return run_result, run_status, snapshot_info_array def update_snapshotinfoarray(self, blob_snapshot_info_array): snapshot_info_array = [] self.logger.log('updating snapshot info array from blob snapshot info') if blob_snapshot_info_array != None 
and blob_snapshot_info_array !=[]: for blob_snapshot_info in blob_snapshot_info_array: if blob_snapshot_info != None: self.logger.log("IsSuccessful:{0}, SnapshotUri:{1}, ErrorMessage:{2}".format(blob_snapshot_info.isSuccessful, blob_snapshot_info.snapshotUri, blob_snapshot_info.errorMessage)) # Sample SnapshotBlobUri Format # UltraDisk: https://md-dd-e470ba041280442aabc964b73060460b.z48.disk.storage.azure.net/disks/e470ba04-1280-442a-abc9-64b73060460b/snapshots?snapshotId=C8E4AC08-8BA6-46B6-973A-BD6C0BD22CD7 # Standard Disk: https://md-pbhlk3l5mb1q.z27.blob.storage.azure.net:443/zzvgfnxr4fgw/abcd?snapshot=2021-07-31T10:07:37.6596865Z blobUri = blob_snapshot_info.snapshotUri if(blob_snapshot_info.snapshotUri): endIndexOfBlobUri = blob_snapshot_info.snapshotUri.find('?') if(blob_snapshot_info.ddSnapshotIdentifier != None): endIndexOfBlobUri = blob_snapshot_info.snapshotUri.find("/snapshots") if(endIndexOfBlobUri != -1): blobUri = blobUri[0:endIndexOfBlobUri] else: self.logger.log("Unable to find end index of blobUri in snapshotUri. Assigning default snapshotUri to blobUri. This {0} a DirectDrive disk".format("is" if(blob_snapshot_info.ddSnapshotIdentifier != None) else "is not")) self.logger.log("blobUri : {0}".format(blobUri)) ddSnapshotIdentifierInfo = None if(blob_snapshot_info.ddSnapshotIdentifier != None): # snapshotUri is None for DD Disks. 
It is populated only for XStore disks blob_snapshot_info.snapshotUri = None creationTimeStr = '\\/Date(' + blob_snapshot_info.ddSnapshotIdentifier.creationTime + ')\\/' creationTimeObj = Status.CreationTime(creationTimeStr, 0) ddSnapshotIdentifierInfo = Status.DirectDriveSnapshotIdentifier(creationTimeObj, blob_snapshot_info.ddSnapshotIdentifier.id, blob_snapshot_info.ddSnapshotIdentifier.token, blob_snapshot_info.ddSnapshotIdentifier.instantAccessDurationMinutes) self.logger.log("DDSnapshotIdentifier Information to CRP- creationTime : {0}, id : {1}, token : {2}, instantAccessDurationMinutes : {3}".format( ddSnapshotIdentifierInfo.creationTime.DateTime, ddSnapshotIdentifierInfo.id, ddSnapshotIdentifierInfo.token, ddSnapshotIdentifierInfo.instantAccessDurationMinutes if ddSnapshotIdentifierInfo.instantAccessDurationMinutes is not None else 'Not Set')) else: self.logger.log("No DD Snapshot Identifier Found. Hence directDriveSnapshotIdentifier will be Null") snapshot_info_array.append(Status.SnapshotInfoObj(blob_snapshot_info.isSuccessful, blob_snapshot_info.snapshotUri, blob_snapshot_info.errorMessage, blobUri, ddSnapshotIdentifierInfo)) return snapshot_info_array def updateErrorCode(self, blob_snapshot_info_array, all_failed, unable_to_sleep, is_inconsistent): run_result = CommonVariables.success any_failed = False run_status = 'success' if unable_to_sleep: run_result = CommonVariables.error run_status = 'error' error_msg = 'T:S Machine unable to sleep' self.logger.log(error_msg, True, 'Error') elif is_inconsistent == True : run_result = CommonVariables.error run_status = 'error' error_msg = 'Snapshots are inconsistent' self.logger.log(error_msg, True, 'Error') elif blob_snapshot_info_array != None: for blob_snapshot_info in blob_snapshot_info_array: if blob_snapshot_info != None and blob_snapshot_info.errorMessage != None : if 'The rate of snapshot blob calls is exceeded' in blob_snapshot_info.errorMessage: run_result = 
CommonVariables.FailedRetryableSnapshotRateExceeded run_status = 'error' error_msg = 'Retrying when snapshot failed with SnapshotRateExceeded' self.extensionErrorCode = ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedRetryableSnapshotRateExceeded self.logger.log(error_msg, True, 'Error') break elif 'The snapshot count against this blob has been exceeded' in blob_snapshot_info.errorMessage: run_result = CommonVariables.FailedSnapshotLimitReached run_status = 'error' error_msg = 'T:S Enable failed with FailedSnapshotLimitReached errror' self.extensionErrorCode = ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedSnapshotLimitReached error_msg = error_msg + ExtensionErrorCodeHelper.ExtensionErrorCodeHelper.StatusCodeStringBuilder(self.extensionErrorCode) self.logger.log(error_msg, True, 'Error') break elif blob_snapshot_info.isSuccessful == False and not all_failed: any_failed = True elif blob_snapshot_info != None and blob_snapshot_info.isSuccessful == False: any_failed = True if all_failed: doSnapshot_status = HandlerUtil.HandlerUtility.get_telemetry_data(CommonVariables.hostStatusCodeDoSnapshot) preSnapshot_status = HandlerUtil.HandlerUtility.get_telemetry_data(CommonVariables.hostStatusCodePreSnapshot) if run_result == CommonVariables.success and doSnapshot_status == "556" and preSnapshot_status == "200": run_result = ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedHostSnapshotRemoteServerError error_msg = 'T:S Enable failed with FailedHostSnapshotRemoteServerError error' self.extensionErrorCode = ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedHostSnapshotRemoteServerError else: run_result = ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedRetryableSnapshotFailedNoNetwork error_msg = 'T:S Enable failed with FailedRetryableSnapshotFailedNoNetwork error' self.extensionErrorCode = ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedRetryableSnapshotFailedNoNetwork error_msg = error_msg + 
def freeze(self):
    """Freeze the filesystems through the safefreeze binary and translate the
    outcome into the (run_result, run_status) pair used by the backup flow.

    Returns:
        (run_result, run_status) where run_status is 'success' or 'error'.
    """
    try:
        # Freeze timeout is configurable; 60 seconds is the fallback.
        timeout = self.hutil.get_intvalue_from_configfile('timeout', 60)
        self.logger.log('T:S freeze, timeout value ' + str(timeout))

        freeze_start = datetime.datetime.now()
        freeze_result, timedout = self.freezer.freeze_safe(timeout)
        freeze_end = datetime.datetime.now()
        freeze_duration = freeze_end - freeze_start
        self.logger.log('T:S ***** freeze, time_before_freeze=' + str(freeze_start) + ", time_after_freeze=" + str(freeze_end) + ", freezeTimeTaken=" + str(freeze_duration))
        # NOTE(review): 5 seconds are subtracted before reporting telemetry —
        # presumably to discount a fixed internal wait; confirm against the binary.
        HandlerUtil.HandlerUtility.add_to_telemetery_data("FreezeTime", str(freeze_end - freeze_start - datetime.timedelta(seconds=5)))

        run_result = CommonVariables.success
        run_status = 'success'
        all_failed = False
        is_inconsistent = False
        self.logger.log('T:S freeze result ' + str(freeze_result) + ', timedout :' + str(timedout))

        freeze_had_errors = freeze_result is not None and len(freeze_result.errors) > 0
        if timedout == True:
            # Freeze did not finish inside the configured window.
            run_result = CommonVariables.FailedFsFreezeTimeout
            run_status = 'error'
            self.extensionErrorCode = ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedRetryableFsFreezeTimeout
            error_msg = 'T:S ###### Enable failed with error: freeze took longer than timeout'
            error_msg = error_msg + ExtensionErrorCodeHelper.ExtensionErrorCodeHelper.StatusCodeStringBuilder(self.extensionErrorCode)
            self.logger.log(error_msg, True, 'Error')
        elif freeze_had_errors and CommonVariables.unable_to_open_err_string in str(freeze_result):
            # The binary reported it could not even open one of the mounts.
            run_result = CommonVariables.FailedUnableToOpenMount
            run_status = 'error'
            self.extensionErrorCode = ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedRetryableUnableToOpenMount
            error_msg = 'T:S Enable failed with error: ' + str(freeze_result)
            error_msg = error_msg + ExtensionErrorCodeHelper.ExtensionErrorCodeHelper.StatusCodeStringBuilder(self.extensionErrorCode)
            self.logger.log(error_msg, True, 'Warning')
        elif freeze_had_errors:
            # Generic freeze failure.
            run_result = CommonVariables.FailedFsFreezeFailed
            run_status = 'error'
            self.extensionErrorCode = ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedRetryableFsFreezeFailed
            error_msg = 'T:S Enable failed with error: ' + str(freeze_result)
            error_msg = error_msg + ExtensionErrorCodeHelper.ExtensionErrorCodeHelper.StatusCodeStringBuilder(self.extensionErrorCode)
            self.logger.log(error_msg, True, 'Warning')
    except Exception as e:
        errMsg = 'Failed to do the freeze with error: %s, stack trace: %s' % (str(e), traceback.format_exc())
        self.logger.log(errMsg, True, 'Error')
        run_result = CommonVariables.error
        run_status = 'error'
    return run_result, run_status

def takeSnapshotFromGuest(self):
    """Take blob snapshots from inside the guest (extension-driven REST path).

    Returns:
        (run_result, run_status, blob_snapshot_info_array, all_failed,
         all_snapshots_failed, unable_to_sleep, is_inconsistent)
    """
    run_result = CommonVariables.success
    run_status = 'success'
    all_failed = False
    is_inconsistent = False
    unable_to_sleep = False
    blob_snapshot_info_array = None
    all_snapshots_failed = False
    try:
        # Without blob SAS URIs there is nothing to snapshot.
        if self.para_parser.blobs == None or len(self.para_parser.blobs) == 0:
            run_result = CommonVariables.FailedRetryableSnapshotFailedNoNetwork
            run_status = 'error'
            error_msg = 'T:S taking snapshot failed as blobs are empty or none'
            self.logger.log(error_msg, True, 'Error')
            all_failed = True
            all_snapshots_failed = True
            return run_result, run_status, blob_snapshot_info_array, all_failed, all_snapshots_failed, unable_to_sleep, is_inconsistent

        if self.para_parser.isVMADEEnabled == True and self.para_parser.blobs != None:
            # fetch the disk encryption details
            self.fetchDiskBlobMetadata()

        if self.g_fsfreeze_on:
            run_result, run_status = self.freeze()

        if self.para_parser is not None and self.is_command_timedout(self.para_parser) == True:
            # CRP already gave up on this command: thaw and report the timeout.
            self.hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedGuestAgentInvokedCommandTooLate)
            run_result = CommonVariables.FailedGuestAgentInvokedCommandTooLate
            run_status = 'error'
            all_failed = True
            all_snapshots_failed = True
            self.logger.log('T:S takeSnapshotFromGuest : Thawing as failing due to CRP timeout', True, 'Error')
            self.freezer.thaw_safe()
        elif run_result == CommonVariables.success or self.takeCrashConsistentSnapshot == True:
            HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.snapshotCreator, CommonVariables.guestExtension)
            guest_snapshotter = GuestSnapshotter(self.logger, self.hutil)
            self.logger.log('T:S doing snapshot now...')
            snap_start = datetime.datetime.now()
            snapshot_result, blob_snapshot_info_array, all_failed, is_inconsistent, unable_to_sleep, all_snapshots_failed = guest_snapshotter.snapshotall(self.para_parser, self.freezer, self.g_fsfreeze_on)
            snap_end = datetime.datetime.now()
            snapshotTimeTaken = snap_end - snap_start
            self.logger.log('T:S ***** takeSnapshotFromGuest, time_before_snapshot=' + str(snap_start) + ", time_after_snapshot=" + str(snap_end) + ", snapshotTimeTaken=" + str(snapshotTimeTaken))
            HandlerUtil.HandlerUtility.add_to_telemetery_data("snapshotTimeTaken", str(snapshotTimeTaken))
            self.logger.log('T:S snapshotall ends...', True)
    except Exception as e:
        errMsg = 'Failed to do the snapshot with error: %s, stack trace: %s' % (str(e), traceback.format_exc())
        self.logger.log(errMsg, True, 'Error')
        run_result = CommonVariables.error
        run_status = 'error'
    return run_result, run_status, blob_snapshot_info_array, all_failed, all_snapshots_failed, unable_to_sleep, is_inconsistent
def takeSnapshotFromFirstGuestThenHost(self):
    """Try the guest snapshot path first; fall back to the host path when
    every guest snapshot failed.

    Returns:
        (run_result, run_status, blob_snapshot_info_array, all_failed,
         unable_to_sleep, is_inconsistent)
    """
    run_result = CommonVariables.success
    run_status = 'success'
    all_failed = False
    is_inconsistent = False
    unable_to_sleep = False
    blob_snapshot_info_array = None
    all_snapshots_failed = False
    run_result, run_status, blob_snapshot_info_array, all_failed, all_snapshots_failed, unable_to_sleep, is_inconsistent = self.takeSnapshotFromGuest()
    if all_snapshots_failed:
        try:
            # to make sure binary is thawed
            self.logger.log('[takeSnapshotFromFirstGuestThenHost] : Thawing again post the guest snapshotting failure')
            self.freezer.thaw_safe()
        except Exception as e:
            self.logger.log('[takeSnapshotFromFirstGuestThenHost] : Exception in Thaw %s, stack trace: %s' % (str(e), traceback.format_exc()))
        # Retry the whole snapshot through the backup host service.
        run_result, run_status, blob_snapshot_info_array, all_failed, unable_to_sleep, is_inconsistent = self.takeSnapshotFromOnlyHost()
    return run_result, run_status, blob_snapshot_info_array, all_failed, unable_to_sleep, is_inconsistent

def takeSnapshotFromFirstHostThenGuest(self):
    """Probe the host service with a pre-snapshot call; use the host path when
    it answers 200/201, otherwise fall back to the guest path.

    Returns:
        (run_result, run_status, blob_snapshot_info_array, all_failed,
         unable_to_sleep, is_inconsistent)
    """
    run_result = CommonVariables.success
    run_status = 'success'
    all_failed = False
    is_inconsistent = False
    unable_to_sleep = False
    blob_snapshot_info_array = None
    host_snapshotter = HostSnapshotter(self.logger, self.hostIp)
    pre_snapshot_statuscode, responseBody = host_snapshotter.pre_snapshot(self.para_parser, self.taskId, True)
    if pre_snapshot_statuscode == 200 or pre_snapshot_statuscode == 201:
        run_result, run_status, blob_snapshot_info_array, all_failed, unable_to_sleep, is_inconsistent = self.takeSnapshotFromOnlyHost()
    else:
        run_result, run_status, blob_snapshot_info_array, all_failed, all_snapshots_failed, unable_to_sleep, is_inconsistent = self.takeSnapshotFromGuest()
        # Classify the guest-path failure for the extension error code.
        if all_snapshots_failed and run_result != CommonVariables.success:
            self.extensionErrorCode = ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedRetryableSnapshotFailedNoNetwork
        elif run_result != CommonVariables.success:
            self.extensionErrorCode = ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedRetryableSnapshotFailedRestrictedNetwork
    return run_result, run_status, blob_snapshot_info_array, all_failed, unable_to_sleep, is_inconsistent

def takeSnapshotFromOnlyHost(self):
    """Take all snapshots exclusively through the backup host service.

    Returns:
        (run_result, run_status, blob_snapshot_info_array, all_failed,
         unable_to_sleep, is_inconsistent)
    """
    run_result = CommonVariables.success
    run_status = 'success'
    all_failed = False
    is_inconsistent = False
    unable_to_sleep = False
    blob_snapshot_info_array = None
    self.logger.log('Taking Snapshot through Host')
    HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.snapshotCreator, CommonVariables.backupHostService)
    if self.g_fsfreeze_on:
        run_result, run_status = self.freeze()
    if self.para_parser is not None and self.is_command_timedout(self.para_parser) == True:
        # CRP timed out on this command: thaw and bail out.
        self.hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedGuestAgentInvokedCommandTooLate)
        run_result = CommonVariables.FailedGuestAgentInvokedCommandTooLate
        run_status = 'error'
        all_failed = True
        self.logger.log('T:S takeSnapshotFromOnlyHost : Thawing as failing due to CRP timeout', True, 'Error')
        self.freezer.thaw_safe()
    elif run_result == CommonVariables.success or self.takeCrashConsistentSnapshot == True:
        host_snapshotter = HostSnapshotter(self.logger, self.hostIp)
        self.logger.log('T:S doing snapshot now...')
        snap_start = datetime.datetime.now()
        blob_snapshot_info_array, all_failed, is_inconsistent, unable_to_sleep = host_snapshotter.snapshotall(self.para_parser, self.freezer, self.g_fsfreeze_on, self.taskId)
        snap_end = datetime.datetime.now()
        snapshotTimeTaken = snap_end - snap_start
        self.logger.log('T:S takeSnapshotFromHost, time_before_snapshot=' + str(snap_start) + ", time_after_snapshot=" + str(snap_end) + ", snapshotTimeTaken=" + str(snapshotTimeTaken))
        HandlerUtil.HandlerUtility.add_to_telemetery_data("snapshotTimeTaken", str(snapshotTimeTaken))
        self.logger.log('T:S snapshotall ends...', True)
    return run_result, run_status, blob_snapshot_info_array, all_failed, unable_to_sleep, is_inconsistent
def is_command_timedout(self, para_parser):
    """Return True when the CRP command start time is older than the
    140-minute budget, in which case the backup should abort.

    The current time is preferably taken from the host's pre-snapshot
    response; the guest clock is used only as a fallback.
    """
    result = False
    dateTimeNow = datetime.datetime.utcnow()
    try:
        try:
            # Prefer the host's clock over the (possibly skewed) guest clock.
            host_snapshotter = HostSnapshotter(self.logger, self.hostIp)
            pre_snapshot_statuscode, responseBody = host_snapshotter.pre_snapshot(self.para_parser, self.taskId)
            got_ok_status = int(pre_snapshot_statuscode) == 200 or int(pre_snapshot_statuscode) == 201
            if got_ok_status and (responseBody != None and responseBody != ""):
                response = json.loads(responseBody)
                rt = response['responseTime']
                dateTimeNow = datetime.datetime(rt['year'], rt['month'], rt['day'], rt['hour'], rt['minute'], rt['second'])
                self.logger.log('Date and time extracted from pre-snapshot request: ' + str(dateTimeNow))
        except Exception as e:
            self.logger.log('Error in getting Host time falling back to using system time. Exception %s, stack trace: %s' % (str(e), traceback.format_exc()))
        if para_parser is not None and para_parser.commandStartTimeUTCTicks is not None and para_parser.commandStartTimeUTCTicks != "":
            utcTicksLong = int(para_parser.commandStartTimeUTCTicks)
            self.logger.log('utcTicks in long format' + str(utcTicksLong))
            commandStartTime = self.convert_time(utcTicksLong)
            self.logger.log('command start time is ' + str(commandStartTime) + " and utcNow is " + str(dateTimeNow))
            timespan = dateTimeNow - commandStartTime
            MAX_TIMESPAN = 140 * 60  # in seconds
            total_span_in_seconds = self.timedelta_total_seconds(timespan)
            self.logger.log('timespan: ' + str(timespan) + ', total_span_in_seconds: ' + str(total_span_in_seconds) + ', MAX_TIMESPAN: ' + str(MAX_TIMESPAN))
            if total_span_in_seconds > MAX_TIMESPAN:
                self.logger.log('CRP timeout limit has reached, should abort.')
                result = True
    except Exception as e:
        self.logger.log('T:S is_command_timedout : Exception %s, stack trace: %s' % (str(e), traceback.format_exc()))
    return result

def convert_time(self, utcTicks):
    """Convert .NET-style ticks (100 ns units since 0001-01-01) to a datetime."""
    # One tick is 0.1 microsecond, hence the division by 10.
    return datetime.datetime(1, 1, 1) + datetime.timedelta(microseconds=utcTicks / 10)

def timedelta_total_seconds(self, delta):
    """timedelta.total_seconds() with a fallback for interpreters lacking it."""
    if hasattr(datetime.timedelta, 'total_seconds'):
        return delta.total_seconds()
    # Manual computation; microseconds are dropped, matching the original fallback.
    return delta.days * 86400 + delta.seconds

def fetchDiskBlobMetadata(self):
    """Fetch the disk-encryption metadata header for every blob and stash the
    (key, value) pairs on the parameter parser for later snapshot requests."""
    headers = self.generate_headers()
    http_util = HttpUtil(self.logger)
    for blob in self.para_parser.blobs:
        sasuri_obj = urlparser.urlparse(blob + '&comp=metadata')
        result, httpResp, errMsg = http_util.HttpCallGetResponse('GET', sasuri_obj, None, headers=headers)
        if result == CommonVariables.success and httpResp != None:
            wanted = CommonVariables.diskEncryptionKey
            value = ""
            for header_name, header_value in httpResp.getheaders():
                if header_name == wanted:
                    value = str(header_value)
                    break
            self.additional_headers.append((wanted, value))
    self.para_parser.disk_encryption_details = self.additional_headers

def generate_headers(self):
    """Generates headers for the request using SAS token, x-ms-date, and x-ms-version."""
    now_stamp = datetime.datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT')
    return {
        "x-ms-date": now_stamp,
        "x-ms-version": "2018-03-28",
    }

# ================================================
# FILE: VMBackup/main/fsfreezer.py
# ================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
from mounts import Mounts
import datetime
import threading
import os
import platform
import time
import sys
import signal
import traceback
import fcntl
from common import CommonVariables
from Utils.ResourceDiskUtil import ResourceDiskUtil


def thread_for_binary(self, args):
    """Worker-thread entry point: launch the safefreeze binary after a short delay
    and publish the Popen handle on the FreezeHandler (passed as *self*)."""
    self.logger.log("Thread for binary is called", True)
    time.sleep(3)
    self.logger.log("Waited in thread for 3 seconds", True)
    self.logger.log("****** 1. Starting Freeze Binary ", True)
    self.child = subprocess.Popen(args, stdout=subprocess.PIPE)
    self.logger.log("Binary subprocess Created", True)


class FreezeError(object):
    """One failed freeze attempt: error code plus the fs type and path it hit."""

    def __init__(self):
        self.errorcode = None
        self.fstype = None
        self.path = None

    def __str__(self):
        # NOTE: the original format omits a ':' after "path" — kept for log compatibility.
        return "errorcode:" + str(self.errorcode) + " fstype:" + str(self.fstype) + " path" + str(self.path)


class FreezeResult(object):
    """Aggregate of FreezeError entries produced by one freeze/thaw cycle."""

    def __init__(self):
        self.errors = []

    def __str__(self):
        return "".join(str(error) + "\n" for error in self.errors)


class FreezeHandler(object):
    """Drives the safefreeze binary and tracks its signals.

    sig_handle values: 0 = nothing happened yet, 1 = freeze succeeded
    (SIGUSR1 received), 2 = the binary process terminated (SIGCHLD received).
    """

    def __init__(self, logger, hutil):
        self.sig_handle = 0
        self.child = None
        self.logger = logger
        self.hutil = hutil

    def sigusr1_handler(self, signal, frame):
        """SIGUSR1 from the binary means all mounts are frozen."""
        self.logger.log('freezed', False)
        self.logger.log("****** 4. Freeze Completed (Signal=1 received)", False)
        self.sig_handle = 1

    def sigchld_handler(self, signal, frame):
        """SIGCHLD: some child terminated; record it only if it was our binary."""
        self.logger.log('some child process terminated')
        if self.child is not None and self.child.poll() is not None:
            self.logger.log("binary child terminated", True)
            self.logger.log("****** 9. Binary Process completed (Signal=2 received)", True)
            self.sig_handle = 2

    def reset_signals(self):
        """Forget any previous run before starting a new binary."""
        self.sig_handle = 0
        self.child = None

    def startproc(self, args):
        """Start the binary on a thread and poll up to SafeFreezeWaitInSeconds
        (config, default 66 s) for one of the signal handlers to fire.

        Returns the final sig_handle value (0 means timed out)."""
        worker = threading.Thread(target=thread_for_binary, args=[self, args])
        worker.start()
        SafeFreezeWaitInSecondsDefault = 66
        proc_sleep_time = self.hutil.get_intvalue_from_configfile('SafeFreezeWaitInSeconds', SafeFreezeWaitInSecondsDefault)
        # Poll in 2-second slices until either signal handler fires.
        for _ in range(int(proc_sleep_time / 2)):
            if self.sig_handle != 0:
                break
            self.logger.log("inside while with sig_handle " + str(self.sig_handle))
            time.sleep(2)
        self.logger.log("Binary output for signal handled: " + str(self.sig_handle))
        return self.sig_handle

    def signal_receiver(self):
        """Install the SIGUSR1/SIGCHLD handlers on the current (main) thread."""
        signal.signal(signal.SIGUSR1, self.sigusr1_handler)
        signal.signal(signal.SIGCHLD, self.sigchld_handler)
def check_if_file_exists(self, relative_path):
    """Record in self.file_exists whether *relative_path* exists next to this module."""
    full_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), relative_path)
    self.logger.log("path of the file" + str(full_path))
    self.file_exists = os.path.exists(full_path)
    self.logger.log("file path exists " + str(self.file_exists))

def should_skip(self, mount):
    """Return True when this mount must not be frozen.

    The resource disk is always skipped; only non-loop ext3/ext4/xfs/btrfs
    mounts are eligible for freezing."""
    if self.resource_disk_mount_point is not None and mount.mount_point == self.resource_disk_mount_point:
        return True
    if mount.fstype in ('ext3', 'ext4', 'xfs', 'btrfs') and mount.type != 'loop':
        return False
    return True

def freeze_safe(self, timeout):
    """Freeze all eligible mounts via the safefreeze binary.

    Serializes against other extension processes with an exclusive flock on
    a well-known lock file, then launches the binary and waits for its
    freeze-complete signal.

    Returns:
        (FreezeResult, timedout) — timedout is True when no signal arrived
        within the wait window.
    """
    self.root_seen = False
    error_msg = ''
    timedout = False
    self.skip_freeze = True
    mounts_to_skip = None
    try:
        mounts_to_skip = self.hutil.get_strvalue_from_configfile('MountsToSkip', '')
        self.logger.log("skipped mount :" + str(mounts_to_skip), True)
        mounts_list_to_skip = mounts_to_skip.split(',')
    except Exception as e:
        errMsg = 'Failed to read from config, Exception %s, stack trace: %s' % (str(e), traceback.format_exc())
        self.logger.log(errMsg, True, 'Warning')
    try:
        freeze_result = FreezeResult()
        freezebin = os.path.join(os.getcwd(), os.path.dirname(__file__), self.safeFreezeFolderPath)
        args = [freezebin, str(timeout)]
        no_mount_found = True  # kept from original; not read afterwards
        # Build the argument list of mount points to freeze; root goes last
        # so it is frozen after everything beneath it.
        for mount in self.mounts.mounts:
            self.logger.log("fsfreeze mount :" + str(mount.mount_point), True)
            if mount.mount_point == '/':
                self.root_seen = True
                self.root_mount = mount
            elif mount.mount_point not in mounts_list_to_skip and not self.should_skip(mount):
                if self.skip_freeze == True:
                    self.skip_freeze = False
                args.append(str(mount.mount_point))
        if self.root_seen and not self.should_skip(self.root_mount):
            if self.skip_freeze == True:
                self.skip_freeze = False
            args.append('/')
        self.logger.log("skip freeze is : " + str(self.skip_freeze), True)
        if self.skip_freeze == True:
            # Nothing freezable: report success with no errors.
            return freeze_result, timedout
        self.logger.log("arg : " + str(args), True)
        self.freeze_handler.reset_signals()
        self.freeze_handler.signal_receiver()
        self.logger.log("proceeded for accepting signals", True)
        if mounts_to_skip == '/':
            # for continue logging to avoid out of memory issue
            self.logger.enforce_local_flag(True)
        else:
            self.logger.enforce_local_flag(False)
        start_time = datetime.datetime.utcnow()
        # Acquire the cross-process safefreeze lock, retrying a few times.
        while self.getLockRetry < self.maxGetLockRetry:
            try:
                if not os.path.isdir('/etc/azure'):
                    os.mkdir('/etc/azure')
                if not os.path.isdir('/etc/azure/MicrosoftRecoverySvcsSafeFreezeLock'):
                    os.mkdir('/etc/azure/MicrosoftRecoverySvcsSafeFreezeLock')
                self.safeFreezelockFile = open("/etc/azure/MicrosoftRecoverySvcsSafeFreezeLock/SafeFreezeLockFile", "w")
                self.logger.log("/etc/azure/MicrosoftRecoverySvcsSafeFreezeLock/SafeFreezeLockFile file opened Sucessfully", True)
                try:
                    # isAquiredLockSucceeded lock will only be false if there is a issue in taking lock.
                    # For all other issue like faliure in creating file, not enough space in disk it will
                    # be true. so that we can proceed with the backup
                    self.isAquireLockSucceeded = False
                    fcntl.lockf(self.safeFreezelockFile, fcntl.LOCK_EX | fcntl.LOCK_NB)
                    self.logger.log("Aquiring lock succeeded", True)
                    self.isAquireLockSucceeded = True
                    break
                except Exception as ex:
                    self.safeFreezelockFile.close()
                    self.logger.log("Failed to aquire lock: %s, stack trace: %s" % (str(ex), traceback.format_exc()), True)
                    raise ex
            except Exception as e:
                self.logger.log("Failed to open file or aquire lock: %s, stack trace: %s" % (str(e), traceback.format_exc()), True)
                self.getLockRetry = self.getLockRetry + 1
                time.sleep(1)
                if self.getLockRetry == self.maxGetLockRetry - 1:
                    # Longer back-off before the final attempt.
                    time.sleep(30)
                self.logger.log("Retry to aquire lock count: " + str(self.getLockRetry), True)
        end_time = datetime.datetime.utcnow()
        self.logger.log("Wait time to aquire lock " + str(end_time - start_time), True)
        # sig_handle = None
        if self.isAquireLockSucceeded == True:
            self.logger.log("Aquired Lock Successful")
            sig_handle = self.freeze_handler.startproc(args)
            self.logger.log("freeze_safe after returning from startproc : sig_handle=" + str(sig_handle))
            if sig_handle != 1:
                if self.freeze_handler.child is not None:
                    self.log_binary_output()
                if sig_handle == 0:
                    timedout = True
                    error_msg = "freeze timed-out"
                    freeze_result.errors.append(error_msg)
                    self.logger.log(error_msg, True, 'Error')
                elif self.mount_open_failed == True:
                    error_msg = CommonVariables.unable_to_open_err_string
                    freeze_result.errors.append(error_msg)
                    self.logger.log(error_msg, True, 'Error')
                elif self.isAquireLockSucceeded == False:
                    error_msg = "Mount Points already freezed by some other processor"
                    freeze_result.errors.append(error_msg)
                    self.logger.log(error_msg, True, 'Error')
                else:
                    error_msg = "freeze failed for some mount"
                    freeze_result.errors.append(error_msg)
                    self.logger.log(error_msg, True, 'Error')
    except Exception as e:
        self.logger.enforce_local_flag(True)
        error_msg = 'freeze failed for some mount with exception, Exception %s, stack trace: %s' % (str(e), traceback.format_exc())
        freeze_result.errors.append(error_msg)
        self.logger.log(error_msg, True, 'Error')
    return freeze_result, timedout

def releaseFileLock(self):
    """Unlock and remove the cross-process safefreeze lock file (best effort)."""
    if self.isAquireLockSucceeded == True:
        try:
            fcntl.lockf(self.safeFreezelockFile, fcntl.LOCK_UN)
            self.safeFreezelockFile.close()
        except Exception as e:
            self.logger.log("Failed to unlock: %s, stack trace: %s" % (str(e), traceback.format_exc()), True)
        try:
            os.remove("/etc/azure/MicrosoftRecoverySvcsSafeFreezeLock/SafeFreezeLockFile")
        except Exception as e:
            self.logger.log("Failed to delete /etc/azure/MicrosoftRecoverySvcsSafeFreezeLock/SafeFreezeLockFile file: %s, stack trace: %s" % (str(e), traceback.format_exc()), True)

def thaw_safe(self):
    """Signal the safefreeze binary to thaw all mounts and interpret its exit.

    Returns:
        (FreezeResult, unable_to_sleep) — errors are appended when the binary
        already exited or exited non-zero (snapshot consistency is then suspect).
    """
    thaw_result = None
    unable_to_sleep = False
    try:
        thaw_result = FreezeResult()
        if self.skip_freeze == True:
            # Nothing was frozen, so there is nothing to thaw.
            return thaw_result, unable_to_sleep
        if self.freeze_handler.child is None:
            self.logger.log("child already completed", True)
            self.logger.log("****** 7. Error - Binary Process Already Completed", True)
            error_msg = 'snapshot result inconsistent'
            thaw_result.errors.append(error_msg)
        elif self.freeze_handler.child.poll() is None:
            # Binary is still alive: SIGUSR1 tells it to thaw; give it 30 s.
            self.logger.log("child process still running")
            self.logger.log("****** 7. Sending Thaw Signal to Binary")
            self.freeze_handler.child.send_signal(signal.SIGUSR1)
            for _ in range(0, 30):
                if self.freeze_handler.child.poll() is None:
                    self.logger.log("child still running sigusr1 sent")
                    time.sleep(1)
                else:
                    break
            self.logger.enforce_local_flag(True)
            self.log_binary_output()
            if self.freeze_handler.child.returncode != 0:
                error_msg = 'snapshot result inconsistent as child returns with failure'
                thaw_result.errors.append(error_msg)
                self.logger.log(error_msg, True, 'Error')
        else:
            # The binary exited before we asked it to thaw.
            self.logger.log("Binary output after process end when no thaw sent: ", True)
            if self.freeze_handler.child.returncode == 2:
                error_msg = 'Unable to execute sleep'
                thaw_result.errors.append(error_msg)
                unable_to_sleep = True
            else:
                error_msg = 'snapshot result inconsistent'
                thaw_result.errors.append(error_msg)
            self.logger.enforce_local_flag(True)
            self.log_binary_output()
            self.logger.log(error_msg, True, 'Error')
        self.logger.enforce_local_flag(True)
    finally:
        self.releaseFileLock()
    return thaw_result, unable_to_sleep

def log_binary_output(self):
    """Drain the binary's stdout into the extension log, flagging open failures."""
    self.logger.log("============== Binary output traces start ================= ", True)
    while True:
        raw = self.freeze_handler.child.stdout.readline()
        if sys.version_info > (3,):
            raw = str(raw, encoding='utf-8', errors="backslashreplace")
        else:
            raw = str(raw)
        if "Failed to open:" in raw:
            # freeze_safe later maps this onto FailedUnableToOpenMount.
            self.mount_open_failed = True
        if raw != '':
            self.logger.log(raw.rstrip(), True)
        else:
            break
    self.logger.log("============== Binary output traces end ================= ", True)

# ================================================
# FILE: VMBackup/main/guestsnapshotter.py
# ================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
try:
    import urlparse as urlparser
except ImportError:
    import urllib.parse as urlparser
import traceback
import datetime
try:
    import ConfigParser as ConfigParsers
except ImportError:
    import configparser as ConfigParsers
import multiprocessing as mp
from common import CommonVariables
from HttpUtil import HttpUtil
from Utils import Status
from Utils import HandlerUtil
from fsfreezer import FsFreezer
from Utils import HostSnapshotObjects


class SnapshotInfoIndexerObj():
    """Per-blob snapshot outcome, keyed by the blob's index in the request."""

    def __init__(self, index, isSuccessful, snapshotTs, errorMessage):
        self.index = index
        self.isSuccessful = isSuccessful
        self.snapshotTs = snapshotTs
        self.errorMessage = errorMessage
        # Default to HTTP 500 until a real status code is parsed from the response.
        self.statusCode = 500

    def __str__(self):
        parts = ('index: ' + str(self.index),
                 'isSuccessful: ' + str(self.isSuccessful),
                 'snapshotTs: ' + str(self.snapshotTs),
                 'errorMessage: ' + str(self.errorMessage),
                 'statusCode: ' + str(self.statusCode))
        return ' '.join(parts)


class SnapshotError(object):
    """Error record for a single snapshot call (errorcode + offending SAS URI)."""

    def __init__(self):
        self.errorcode = CommonVariables.success
        self.sasuri = None

    def __str__(self):
        return 'errorcode: ' + str(self.errorcode)


class SnapshotResult(object):
    """Collection of SnapshotError objects produced by one snapshot-all pass."""

    def __init__(self):
        self.errors = []

    def __str__(self):
        return "".join(str(error) + "\n" for error in self.errors)


class GuestSnapshotter(object):
    """Takes blob snapshots from inside the guest through the Storage REST API."""

    def __init__(self, logger, hutil):
        self.logger = logger
        self.configfile = '/etc/azure/vmbackup.conf'
        self.hutil = hutil
def snapshot(self, sasuri, sasuri_index, settings, meta_data, snapshot_result_error, snapshot_info_indexer_queue, global_logger, global_error_logger, disk_encryption_details=None):
    """Multiprocessing worker: snapshot a single blob via the REST API.

    All results are reported back through the supplied queues: accumulated
    log text, accumulated error text, a SnapshotError, and a
    SnapshotInfoIndexerObj for this blob index.
    """
    temp_logger = ''
    error_logger = ''
    snapshot_error = SnapshotError()
    snapshot_info_indexer = SnapshotInfoIndexerObj(sasuri_index, False, None, None)
    if sasuri is None:
        error_logger = error_logger + str(datetime.datetime.utcnow()) + " Failed to do the snapshot because sasuri is none "
        snapshot_error.errorcode = CommonVariables.error
        snapshot_error.sasuri = sasuri
    try:
        sasuri_obj = urlparser.urlparse(sasuri)
        if sasuri_obj is None or sasuri_obj.hostname is None:
            error_logger = error_logger + str(datetime.datetime.utcnow()) + " Failed to parse the sasuri "
            snapshot_error.errorcode = CommonVariables.error
            snapshot_error.sasuri = sasuri
        else:
            start_time = datetime.datetime.utcnow()
            body_content = ''
            headers = {}
            headers["Content-Length"] = '0'
            if meta_data is not None:
                # Backup metadata travels as x-ms-meta-* request headers.
                for meta in meta_data:
                    headers["x-ms-meta-" + meta['Key']] = meta['Value']
            temp_logger = temp_logger + str(headers)
            if disk_encryption_details is not None and len(disk_encryption_details) >= 2 and disk_encryption_details[0] and disk_encryption_details[1]:
                headers[disk_encryption_details[0]] = disk_encryption_details[1]
                self.logger.log("appending disk_encryption_details as part of headers while taking a snapshot")
            if CommonVariables.isSnapshotTtlEnabled in settings and settings[CommonVariables.isSnapshotTtlEnabled]:
                self.logger.log("Not passing the TTL header via Guest path though it is enabled")
            http_util = HttpUtil(self.logger)
            sasuri_obj = urlparser.urlparse(sasuri + '&comp=snapshot')
            temp_logger = temp_logger + str(datetime.datetime.utcnow()) + ' start calling the snapshot rest api. '
            # initiate http call for blob-snapshot and get http response
            result, httpResp, errMsg, responseBody = http_util.HttpCallGetResponse('PUT', sasuri_obj, body_content, headers=headers, responseBodyRequired=True)
            temp_logger = temp_logger + str("responseBody: " + responseBody)
            if result == CommonVariables.success and httpResp != None:
                # retrieve snapshot information from http response
                snapshot_info_indexer, snapshot_error, message = self.httpresponse_get_snapshot_info(httpResp, sasuri_index, sasuri, responseBody)
                temp_logger = temp_logger + str(datetime.datetime.utcnow()) + ' httpresponse_get_snapshot_info message: ' + str(message)
            else:
                # HttpCall failed
                error_logger = error_logger + str(datetime.datetime.utcnow()) + " snapshot HttpCallGetResponse failed "
                error_logger = error_logger + str(datetime.datetime.utcnow()) + str(errMsg)
                snapshot_error.errorcode = CommonVariables.error
                snapshot_error.sasuri = sasuri
            end_time = datetime.datetime.utcnow()
            time_taken = end_time - start_time
            temp_logger = temp_logger + str(datetime.datetime.utcnow()) + ' time taken for snapshot ' + str(time_taken)
    except Exception as e:
        errorMsg = " Failed to do the snapshot with error: %s, stack trace: %s" % (str(e), traceback.format_exc())
        error_logger = error_logger + str(datetime.datetime.utcnow()) + errorMsg
        snapshot_error.errorcode = CommonVariables.error
        snapshot_error.sasuri = sasuri
    temp_logger = temp_logger + str(datetime.datetime.utcnow()) + ' snapshot ends..'
    # Publish everything back to the parent process.
    global_logger.put(temp_logger)
    global_error_logger.put(error_logger)
    snapshot_result_error.put(snapshot_error)
    snapshot_info_indexer_queue.put(snapshot_info_indexer)

def snapshot_seq(self, sasuri, sasuri_index, settings, meta_data, disk_encryption_metadata=None):
    """Sequential (in-process) variant of snapshot().

    Returns:
        (SnapshotError, SnapshotInfoIndexerObj) for this blob index.
    """
    result = None
    snapshot_error = SnapshotError()
    snapshot_info_indexer = SnapshotInfoIndexerObj(sasuri_index, False, None, None)
    if sasuri is None:
        self.logger.log("Failed to do the snapshot because sasuri is none", False, 'Error')
        snapshot_error.errorcode = CommonVariables.error
        snapshot_error.sasuri = sasuri
    try:
        sasuri_obj = urlparser.urlparse(sasuri)
        if sasuri_obj is None or sasuri_obj.hostname is None:
            self.logger.log("Failed to parse the sasuri", False, 'Error')
            snapshot_error.errorcode = CommonVariables.error
            snapshot_error.sasuri = sasuri
        else:
            body_content = ''
            headers = {}
            headers["Content-Length"] = '0'
            if meta_data is not None:
                # Backup metadata travels as x-ms-meta-* request headers.
                for meta in meta_data:
                    headers["x-ms-meta-" + meta['Key']] = meta['Value']
            if disk_encryption_metadata is not None and len(disk_encryption_metadata) >= 2 and disk_encryption_metadata[0] and disk_encryption_metadata[1]:
                headers[disk_encryption_metadata[0]] = disk_encryption_metadata[1]
                self.logger.log("appending disk_encryption_details as part of headers while taking a snapshot")
            if CommonVariables.isSnapshotTtlEnabled in settings and settings[CommonVariables.isSnapshotTtlEnabled]:
                self.logger.log("Not passing the TTL header via Guest path though it is enabled")
            http_util = HttpUtil(self.logger)
            sasuri_obj = urlparser.urlparse(sasuri + '&comp=snapshot')
            self.logger.log("start calling the snapshot rest api")
            # initiate http call for blob-snapshot and get http response
            result, httpResp, errMsg, responseBody = http_util.HttpCallGetResponse('PUT', sasuri_obj, body_content, headers=headers, responseBodyRequired=True)
            self.logger.log("responseBody: " + responseBody)
            if result == CommonVariables.success and httpResp != None:
                # retrieve snapshot information from http response
                snapshot_info_indexer, snapshot_error, message = self.httpresponse_get_snapshot_info(httpResp, sasuri_index, sasuri, responseBody)
                self.logger.log(' httpresponse_get_snapshot_info message: ' + str(message))
            else:
                # HttpCall failed
                self.logger.log(" snapshot HttpCallGetResponse failed ")
                self.logger.log(str(errMsg))
                snapshot_error.errorcode = CommonVariables.error
                snapshot_error.sasuri = sasuri
    except Exception as e:
        errorMsg = "Failed to do the snapshot with error: %s, stack trace: %s" % (str(e), traceback.format_exc())
        self.logger.log(errorMsg, False, 'Error')
        snapshot_error.errorcode = CommonVariables.error
        snapshot_error.sasuri = sasuri
    return snapshot_error, snapshot_info_indexer
Snaphotting (Guest-parallel) Started') for blob in blobs: blobUri = blob.split("?")[0] self.logger.log("index: " + str(blob_index) + " blobUri: " + str(blobUri)) blob_snapshot_info_array.append(HostSnapshotObjects.BlobSnapshotInfo(False, blobUri, None, 500)) try: if(paras.isVMADEEnabled and len(paras.disk_encryption_details) > blob_index): mp_jobs.append(mp.Process(target=self.snapshot,args=(blob, blob_index, paras.wellKnownSettingFlags, paras.backup_metadata, snapshot_result_error, snapshot_info_indexer_queue, global_logger, global_error_logger, paras.disk_encryption_details[blob_index]))) else: mp_jobs.append(mp.Process(target=self.snapshot,args=(blob, blob_index, paras.wellKnownSettingFlags, paras.backup_metadata, snapshot_result_error, snapshot_info_indexer_queue, global_logger, global_error_logger))) except Exception as e: self.logger.log("multiprocess queue creation failed") all_snapshots_failed = True raise Exception("Exception while creating multiprocess queue") blob_index = blob_index + 1 counter = 0 for job in mp_jobs: job.start() if(counter == 0): queue_creation_endtime = datetime.datetime.now() timediff = queue_creation_endtime - queue_creation_starttime if(timediff.seconds >= 10): self.logger.log("mp queue creation took more than 10 secs. Setting next backup to sequential") set_next_backup_to_seq = True counter = counter + 1 for job in mp_jobs: job.join() self.logger.log('****** 6. 
Snaphotting (Guest-parallel) Completed') thaw_result = None if g_fsfreeze_on and thaw_done_local == False: time_before_thaw = datetime.datetime.now() thaw_result, unable_to_sleep = freezer.thaw_safe() time_after_thaw = datetime.datetime.now() HandlerUtil.HandlerUtility.add_to_telemetery_data("ThawTime", str(time_after_thaw-time_before_thaw)) thaw_done_local = True if(set_next_backup_to_seq == True): self.logger.log("Setting to sequential snapshot") self.hutil.set_value_to_configfile('seqsnapshot', '1') self.logger.log('T:S thaw result ' + str(thaw_result)) if(thaw_result is not None and len(thaw_result.errors) > 0 and (snapshot_result is None or len(snapshot_result.errors) == 0)): is_inconsistent = True snapshot_result.errors.append(thaw_result.errors) return snapshot_result, blob_snapshot_info_array, all_failed, exceptOccurred, is_inconsistent, thaw_done_local, unable_to_sleep, all_snapshots_failed self.logger.log('end of snapshot process') logging = [global_logger.get() for job in mp_jobs] self.logger.log(str(logging)) error_logging = [global_error_logger.get() for job in mp_jobs] self.logger.log(str(error_logging),False,'Error') if not snapshot_result_error.empty(): results = [snapshot_result_error.get() for job in mp_jobs] for result in results: if(result.errorcode != CommonVariables.success): snapshot_result.errors.append(result) if not snapshot_info_indexer_queue.empty(): snapshot_info_indexers = [snapshot_info_indexer_queue.get() for job in mp_jobs] for snapshot_info_indexer in snapshot_info_indexers: # update blob_snapshot_info_array element properties from snapshot_info_indexer object self.get_snapshot_info(snapshot_info_indexer, blob_snapshot_info_array[snapshot_info_indexer.index]) if (blob_snapshot_info_array[snapshot_info_indexer.index].isSuccessful == True): all_failed = False self.logger.log("index: " + str(snapshot_info_indexer.index) + " blobSnapshotUri: " + str(blob_snapshot_info_array[snapshot_info_indexer.index].snapshotUri)) 
                all_snapshots_failed = all_failed
                self.logger.log("Setting all_snapshots_failed to " + str(all_snapshots_failed))
                return snapshot_result, blob_snapshot_info_array, all_failed, exceptOccurred, is_inconsistent, thaw_done_local, unable_to_sleep, all_snapshots_failed
            else:
                self.logger.log("the blobs are None")
                # NOTE(review): this path returns 7 values, one fewer than the
                # success path above — verify callers unpack both shapes.
                return snapshot_result, blob_snapshot_info_array, all_failed, exceptOccurred, is_inconsistent, thaw_done_local, unable_to_sleep
        except Exception as e:
            errorMsg = " Unable to perform parallel snapshot with error: %s, stack trace: %s" % (str(e), traceback.format_exc())
            self.logger.log(errorMsg)
            exceptOccurred = True
            return snapshot_result, blob_snapshot_info_array, all_failed, exceptOccurred, is_inconsistent, thaw_done_local, unable_to_sleep, all_snapshots_failed

    def snapshotall_seq(self, paras, freezer, thaw_done, g_fsfreeze_on):
        """Snapshot every blob in paras.blobs one at a time via snapshot_seq.

        Used instead of snapshotall_parallel when the blob count is small or
        the 'seqsnapshot' config flag forces sequential mode. Thaws the
        filesystem (if frozen and not already thawed) after all snapshots.

        Returns the same tuple shape as snapshotall_parallel, including the
        shorter 7-tuple on the "blobs are None" path.
        """
        exceptOccurred = False
        self.logger.log("doing snapshotall now in sequence...")
        snapshot_result = SnapshotResult()
        blob_snapshot_info_array = []
        all_failed = True
        is_inconsistent = False
        thaw_done_local = thaw_done
        unable_to_sleep = False
        all_snapshots_failed = False
        try:
            blobs = paras.blobs
            if blobs is not None:
                blob_index = 0
                self.logger.log('****** 5. Snaphotting (Guest-seq) Started')
                for blob in blobs:
                    blobUri = blob.split("?")[0]
                    self.logger.log("index: " + str(blob_index) + " blobUri: " + str(blobUri))
                    # Pre-seed as failed (HTTP 500) until snapshot_seq reports back.
                    blob_snapshot_info_array.append(HostSnapshotObjects.BlobSnapshotInfo(False, blobUri, None, 500))
                    if(paras.isVMADEEnabled == True and len(paras.disk_encryption_details) > blob_index):
                        snapshotError, snapshot_info_indexer = self.snapshot_seq(blob, blob_index, paras.wellKnownSettingFlags, paras.backup_metadata, paras.disk_encryption_details[blob_index])
                    else:
                        snapshotError, snapshot_info_indexer = self.snapshot_seq(blob, blob_index, paras.wellKnownSettingFlags, paras.backup_metadata)
                    if(snapshotError.errorcode != CommonVariables.success):
                        snapshot_result.errors.append(snapshotError)
                    # update blob_snapshot_info_array element properties from snapshot_info_indexer object
                    self.get_snapshot_info(snapshot_info_indexer, blob_snapshot_info_array[blob_index])
                    if (blob_snapshot_info_array[blob_index].isSuccessful == True):
                        all_failed = False
                    blob_index = blob_index + 1
                self.logger.log('****** 6. Snaphotting (Guest-seq) Completed')
                all_snapshots_failed = all_failed
                self.logger.log("Setting all_snapshots_failed to " + str(all_snapshots_failed))
                thaw_result = None
                if g_fsfreeze_on and thaw_done_local == False:
                    time_before_thaw = datetime.datetime.now()
                    thaw_result, unable_to_sleep = freezer.thaw_safe()
                    time_after_thaw = datetime.datetime.now()
                    HandlerUtil.HandlerUtility.add_to_telemetery_data("ThawTime", str(time_after_thaw-time_before_thaw))
                    thaw_done_local = True
                self.logger.log('T:S thaw result ' + str(thaw_result))
                if(thaw_result is not None and len(thaw_result.errors) > 0 and (snapshot_result is None or len(snapshot_result.errors) == 0)):
                    # NOTE(review): append() nests the thaw error list as a
                    # single element — same pattern as the parallel path.
                    snapshot_result.errors.append(thaw_result.errors)
                    is_inconsistent = True
                return snapshot_result, blob_snapshot_info_array, all_failed, exceptOccurred, is_inconsistent, thaw_done_local, unable_to_sleep, all_snapshots_failed
            else:
                self.logger.log("the blobs are None")
                return snapshot_result, blob_snapshot_info_array, all_failed, exceptOccurred, is_inconsistent, thaw_done_local, unable_to_sleep
        except Exception as e:
            errorMsg = " Unable to perform sequential snapshot with error: %s, stack trace: %s" % (str(e), traceback.format_exc())
            self.logger.log(errorMsg)
            exceptOccurred = True
            return snapshot_result, blob_snapshot_info_array, all_failed, exceptOccurred, is_inconsistent, thaw_done_local, unable_to_sleep, all_snapshots_failed

    def snapshotall(self, paras, freezer, g_fsfreeze_on):
        """Dispatch to sequential or parallel snapshotting.

        Sequential mode is chosen when the 'seqsnapshot' config flag is 1 or
        2, or when there are at most 4 blobs; otherwise parallel. If the
        parallel attempt raised before the thaw happened and every snapshot
        failed, a sequential retry is made.
        """
        thaw_done = False
        if (self.hutil.get_intvalue_from_configfile('seqsnapshot',0) == 1 or self.hutil.get_intvalue_from_configfile('seqsnapshot',0) == 2 or (len(paras.blobs) <= 4)):
            snapshot_result, blob_snapshot_info_array, all_failed, exceptOccurred, is_inconsistent, thaw_done, unable_to_sleep, all_snapshots_failed = self.snapshotall_seq(paras, freezer, thaw_done, g_fsfreeze_on)
        else:
            snapshot_result, blob_snapshot_info_array, all_failed, exceptOccurred, is_inconsistent, thaw_done, unable_to_sleep, all_snapshots_failed = self.snapshotall_parallel(paras, freezer, thaw_done, g_fsfreeze_on)
        self.logger.log("exceptOccurred : " + str(exceptOccurred) + " thaw_done : " + str(thaw_done) + " all_snapshots_failed : " + str(all_snapshots_failed))
        # Fall back to sequential only when nothing succeeded and the
        # filesystem is still frozen (so a retry is still consistent).
        if exceptOccurred and thaw_done == False and all_snapshots_failed:
            self.logger.log("Trying sequential snapshotting as parallel snapshotting failed")
            snapshot_result, blob_snapshot_info_array, all_failed, exceptOccurred, is_inconsistent,thaw_done, unable_to_sleep, all_snapshots_failed = self.snapshotall_seq(paras, freezer, thaw_done, g_fsfreeze_on)
        return snapshot_result, blob_snapshot_info_array, all_failed, is_inconsistent, unable_to_sleep, all_snapshots_failed

    def httpresponse_get_snapshot_info(self, resp, sasuri_index, sasuri, responseBody):
        """Translate an HTTP snapshot response into result objects.

        200/201 => success: records the 'x-ms-snapshot' timestamp header in
        the indexer. Any other status (or a None response) => failure: the
        status code and response body are recorded, and snapshot_error is
        populated with the failing result code and SAS URI.

        Returns (snapshot_info_indexer, snapshot_error, message) where
        message is a timestamped diagnostic string for logging.
        """
        snapshot_error = SnapshotError()
        snapshot_info_indexer = SnapshotInfoIndexerObj(sasuri_index, False, None, None)
        result = CommonVariables.error_http_failure
        message = ""
        if(resp != None):
            message = message + str(datetime.datetime.utcnow()) + " snapshot resp status: " + str(resp.status) + " "
            resp_headers = resp.getheaders()
            message = message + str(datetime.datetime.utcnow()) + " snapshot resp-header: " + str(resp_headers) + " "
            if(resp.status == 200 or resp.status == 201):
                result = CommonVariables.success
                snapshot_info_indexer.isSuccessful = True
                snapshot_info_indexer.snapshotTs = resp.getheader('x-ms-snapshot')
            else:
                result = resp.status
                snapshot_info_indexer.errorMessage = responseBody
                snapshot_info_indexer.statusCode = resp.status
        else:
            message = message + str(datetime.datetime.utcnow()) + " snapshot Http connection response is None" + " "
        message = message + str(datetime.datetime.utcnow()) + ' snapshot api returned: {0} '.format(result) + " "
        if(result != CommonVariables.success):
            snapshot_error.errorcode = result
            snapshot_error.sasuri = sasuri
        return snapshot_info_indexer, snapshot_error, message

    def get_snapshot_info(self, snapshot_info_indexer, snapshot_info):
        """Copy a worker's indexer result onto the BlobSnapshotInfo entry.

        On success the blob URI is turned into a snapshot URI by appending
        '?snapshot=<ts>'; on failure (or a None indexer) the entry is marked
        unsuccessful and its URI cleared.
        """
        if (snapshot_info_indexer != None):
            self.logger.log("snapshot_info_indexer: " + str(snapshot_info_indexer))
            snapshot_info.isSuccessful = snapshot_info_indexer.isSuccessful
            if (snapshot_info.isSuccessful == True):
                snapshot_info.snapshotUri = snapshot_info.snapshotUri + "?snapshot=" + str(snapshot_info_indexer.snapshotTs)
            else:
                snapshot_info.snapshotUri = None
                snapshot_info.errorMessage = snapshot_info_indexer.errorMessage
                snapshot_info.statusCode = snapshot_info_indexer.statusCode
        else:
            snapshot_info.isSuccessful = False
            snapshot_info.snapshotUri = None


================================================
FILE: VMBackup/main/handle.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import array
import base64
import os
import os.path
import re
import json
import string
import subprocess
import sys
import time
import shlex
import traceback
import datetime
import random
try:
    # Python 2 name first, fall back to the Python 3 module.
    import ConfigParser as ConfigParsers
except ImportError:
    import configparser as ConfigParsers
from threading import Thread
from time import sleep
from os.path import join
from mounts import Mounts
from mounts import Mount
from patch import *
from fsfreezer import FsFreezer
from common import CommonVariables
from parameterparser import ParameterParser
from Utils import HandlerUtil
from Utils.EventLoggerUtil import EventLogger
from Utils import SizeCalculation
from Utils import Status
from freezesnapshotter import FreezeSnapshotter
from backuplogger import Backuplogger
from blobwriter import BlobWriter
from taskidentity import TaskIdentity
from MachineIdentity import MachineIdentity
import ExtensionErrorCodeHelper
from PluginHost import PluginHost
from PluginHost import PluginHostResult
import platform
from workloadPatch import WorkloadPatch
from signal import SIGTERM;

#Main function is the only entrence to this extension handler
def main():
    """Entry point: initialize handler utilities/loggers, then dispatch
    each command-line argument (install/enable/disable/uninstall/update/
    daemon, plus an optional 'seqNo:<n>' argument) to its handler.

    All state is shared through module-level globals by the per-command
    handlers below.
    """
    global MyPatching,backup_logger,hutil,run_result,run_status,error_msg,freezer,freeze_result,snapshot_info_array,total_used_size,size_calculation_failed, patch_class_name, orig_distro, configSeqNo, eventlogger, disable_event_logging
    try:
        run_result = CommonVariables.success
        run_status = 'success'
        error_msg = ''
        freeze_result = None
        snapshot_info_array = None
        total_used_size = 0
        size_calculation_failed = False
        eventlogger = None
        HandlerUtil.waagent.LoggerInit('/dev/console','/dev/stdout')
        hutil = HandlerUtil.HandlerUtility(HandlerUtil.waagent.Log, HandlerUtil.waagent.Error, CommonVariables.extension_name)
        backup_logger = Backuplogger(hutil)
        MyPatching, patch_class_name, orig_distro = GetMyPatching(backup_logger)
        hutil.patching = MyPatching
        configSeqNo = -1
        hutil.try_parse_context(configSeqNo)
        disable_event_logging = hutil.get_intvalue_from_configfile("disable_logging", 0)
        # NOTE(review): the key below has a trailing space
        # ("async_event_logging ") — confirm it matches the config file.
        use_async_event_logging = hutil.get_intvalue_from_configfile("async_event_logging ", 0)
        # NOTE(review): 'or' means event logging is created even when
        # disabled, as long as event_dir is set — confirm this is intended.
        if disable_event_logging == 0 or hutil.event_dir is not None :
            eventlogger = EventLogger.GetInstance(backup_logger, hutil.event_dir, hutil.severity_level, use_async_event_logging)
        else:
            eventlogger = None
        hutil.set_event_logger(eventlogger)
        # Arguments are processed in order; 'seqNo:<n>' only takes effect
        # for commands that appear after it on the command line.
        for a in sys.argv[1:]:
            if re.match("^([-/]*)(disable)", a):
                disable()
            elif re.match("^([-/]*)(uninstall)", a):
                uninstall()
            elif re.match("^([-/]*)(install)", a):
                install()
            elif re.match("^([-/]*)(enable)", a):
                enable()
            elif re.match("^([-/]*)(update)", a):
                update()
            elif re.match("^([-/]*)(daemon)", a):
                daemon()
            elif re.match("^([-/]*)(seqNo:)", a):
                try:
                    configSeqNo = int(a.split(':')[1])
                except:
                    configSeqNo = -1
    except Exception as e:
        # NOTE(review): all startup failures are swallowed and the process
        # exits 0 — the guest agent will see success; confirm deliberate.
        if(eventlogger != None):
            eventlogger.dispose()
        sys.exit(0)

def install():
    """Handle the 'install' command: parse context and report success."""
    global hutil,configSeqNo
    hutil.do_parse_context('Install', configSeqNo)
    hutil.do_exit(0, 'Install','success','0', 'Install Succeeded')

def status_report_to_file(file_report_msg):
    """Write the status message to the extension's .status file and log it."""
    global backup_logger,hutil
    hutil.write_to_status_file(file_report_msg)
    backup_logger.log("file status report message:",True)
    backup_logger.log(file_report_msg,True)

def status_report_to_blob(blob_report_msg):
    """Upload the status message to the CRP-provided status blob.

    Skipped when the 'UploadStatusAndLog' config value is explicitly set to
    something other than 'True'. Upload failures are logged as warnings and
    otherwise ignored (best effort).
    """
    global backup_logger,hutil,para_parser
    UploadStatusAndLog = hutil.get_strvalue_from_configfile('UploadStatusAndLog','True')
    if(UploadStatusAndLog == None or UploadStatusAndLog == 'True'):
        try:
            if(para_parser is not None and para_parser.statusBlobUri is not None and para_parser.statusBlobUri != ""):
                blobWriter = BlobWriter(hutil)
                if(blob_report_msg is not None):
                    blobWriter.WriteBlob(blob_report_msg,para_parser.statusBlobUri)
                    backup_logger.log("blob status report message:",True)
                    backup_logger.log(blob_report_msg,True)
                else:
                    backup_logger.log("blob_report_msg is none",True)
        except Exception as e:
            err_msg='cannot write status to the status blob'+traceback.format_exc()
            backup_logger.log(err_msg, True, 'Warning')

def get_status_to_report(status, status_code, message, snapshot_info = None):
    """Build the (blob, file) status report pair for the given outcome.

    Lazily computes total used disk size when it is still the sentinel -1,
    capping it at 1 TiB (1099511627776 bytes) per blob. Any failure while
    assembling the report is logged as a warning and (None, None) members
    may be returned.
    """
    global MyPatching,backup_logger,hutil,para_parser,total_used_size,size_calculation_failed
    blob_report_msg = None
    file_report_msg = None
    try:
        if total_used_size == -1 :
            sizeCalculation = SizeCalculation.SizeCalculation(patching = MyPatching , hutil = hutil, logger = backup_logger , para_parser = para_parser)
            total_used_size,size_calculation_failed = sizeCalculation.get_total_used_size()
            number_of_blobs = len(para_parser.includeLunList)
            # Cap the reported size at 1 TiB per disk.
            maximum_possible_size = number_of_blobs * 1099511627776
            if(total_used_size>maximum_possible_size and number_of_blobs != 0):
                total_used_size = maximum_possible_size
            backup_logger.log("Assertion Check, total size : {0} ,maximum_possible_size : {1}".format(total_used_size,maximum_possible_size),True)
        if(para_parser is not None):
            blob_report_msg, file_report_msg = hutil.do_status_report(operation='Enable',status=status,\
                                    status_code=str(status_code),\
                                    message=message,\
                                    taskId=para_parser.taskId,\
                                    commandStartTimeUTCTicks=para_parser.commandStartTimeUTCTicks,\
                                    snapshot_info=snapshot_info,\
                                    total_size = total_used_size,\
                                    failure_flag = size_calculation_failed)
    except Exception as e:
        err_msg='cannot get status report parameters , Exception %s, stack trace: %s' % (str(e), traceback.format_exc())
        backup_logger.log(err_msg, True, 'Warning')
    return blob_report_msg, file_report_msg

def exit_with_commit_log(status,result,error_msg, para_parser):
    """Log the error, commit logs to the logs blob, report status to both
    the status file and status blob, then terminate the process (exit 0)."""
    global backup_logger
    backup_logger.log(error_msg, True, 'Error')
    if(para_parser is not None and para_parser.logsBlobUri is not None and para_parser.logsBlobUri != ""):
        backup_logger.commit(para_parser.logsBlobUri)
    blob_report_msg, file_report_msg = get_status_to_report(status, result, error_msg, None)
    status_report_to_file(file_report_msg)
    status_report_to_blob(blob_report_msg)
    if(eventlogger is not None):
        eventlogger.dispose()
    sys.exit(0)

def exit_if_same_taskId(taskId):
    """Exit early (success, already-processed) when this taskId was the
    last one handled, to make the daemon idempotent across re-invocations."""
    global backup_logger,hutil,para_parser
    # (continuation of exit_if_same_taskId)
    trans_report_msg = None
    taskIdentity = TaskIdentity()
    last_taskId = taskIdentity.stored_identity()
    if(taskId == last_taskId):
        backup_logger.log("TaskId is same as last, so skip with Processed Status, current:" + str(taskId) + "== last:" + str(last_taskId), True)
        # NOTE(review): status/status_code/message are assigned but no
        # status report is emitted before exiting — confirm whether a
        # report call was intended here.
        status=CommonVariables.status_success
        hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.SuccessAlreadyProcessedInput)
        status_code=CommonVariables.SuccessAlreadyProcessedInput
        message='TaskId AlreadyProcessed nothing to do'
        backup_logger.log(message, True)
        if(eventlogger is not None):
            eventlogger.dispose()
        sys.exit(0)

def freeze_snapshot(timeout):
    """Freeze the filesystem (if enabled) and take the snapshots.

    Delegates to FreezeSnapshotter.doFreezeSnapshot(), updating the global
    run_result/run_status/snapshot_info_array. If crash-consistent
    snapshots are allowed and the app/file-consistent attempt failed but
    the single disk's snapshot itself succeeded, the result is downgraded
    to a successful crash-consistent snapshot. Exits the process via
    exit_with_commit_log on an invalid LUN-list input.
    """
    try:
        global hutil,backup_logger,run_result,run_status,error_msg,freezer,freeze_result,para_parser,snapshot_info_array,g_fsfreeze_on, workload_patch
        canTakeCrashConsistentSnapshot = can_take_crash_consistent_snapshot(para_parser)
        freeze_snap_shotter = FreezeSnapshotter(backup_logger, hutil, freezer, g_fsfreeze_on, para_parser, canTakeCrashConsistentSnapshot)
        if (hutil.ExtErrorCode == ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedInvalidDataDiskLunList):
            temp_result = CommonVariables.FailedInvalidDataDiskLunList
            temp_status = 'error'
            error_msg = 'Invalid Input. IsAnyDiskExcluded is marked as true but input LUN list received from CRP is empty. '\
                        'which is not allowed if VM has Direct Drives or if VM has Write Accelerated disks or if VM is a TVM/CVM.'
            exit_with_commit_log(temp_status, temp_result,error_msg, para_parser)
        backup_logger.log("Calling do snapshot method", True, 'Info')
        run_result, run_status, snapshot_info_array = freeze_snap_shotter.doFreezeSnapshot()
        if (canTakeCrashConsistentSnapshot == True and run_result != CommonVariables.success and run_result != CommonVariables.success_appconsistent):
            # Single-disk VM whose only snapshot actually succeeded: accept
            # it as crash-consistent instead of failing the backup.
            if (snapshot_info_array is not None and snapshot_info_array !=[] and check_snapshot_array_fail() == False and len(snapshot_info_array) == 1):
                run_status = CommonVariables.status_success
                run_result = CommonVariables.success
                hutil.SetSnapshotConsistencyType(Status.SnapshotConsistencyType.crashConsistent)
    except Exception as e:
        errMsg = 'Failed to do the snapshot with error: %s, stack trace: %s' % (str(e), traceback.format_exc())
        backup_logger.log(errMsg, True, 'Error')
        run_result = CommonVariables.error
        run_status = 'error'
        error_msg = 'Enable failed with exception in safe freeze or snapshot '
        hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.error)
    #snapshot_done = True

def check_snapshot_array_fail():
    """Return True if any entry in the global snapshot_info_array failed."""
    global snapshot_info_array, backup_logger
    snapshot_array_fail = False
    if snapshot_info_array is not None and snapshot_info_array !=[]:
        for snapshot_index in range(len(snapshot_info_array)):
            if(snapshot_info_array[snapshot_index].isSuccessful == False):
                backup_logger.log('T:S snapshot failed at index ' + str(snapshot_index), True)
                snapshot_array_fail = True
                break
    return snapshot_array_fail

def get_key_value(jsonObj, key):
    """Return jsonObj[key], or None when the key is absent."""
    value = None
    if(key in jsonObj.keys()):
        value = jsonObj[key]
    return value

def can_take_crash_consistent_snapshot(para_parser):
    """Decide whether a crash-consistent snapshot is an acceptable fallback.

    True only when the custom settings mark the VM as managed, explicitly
    allow crash-consistent snapshots, indicate this is a retry
    (backupRetryCount > 0), and the VM has exactly one disk.
    """
    global backup_logger
    takeCrashConsistentSnapshot = False
    if(para_parser != None and para_parser.customSettings != None and para_parser.customSettings != ''):
        customSettings = json.loads(para_parser.customSettings)
        isManagedVm = get_key_value(customSettings, 'isManagedVm')
        canTakeCrashConsistentSnapshot = get_key_value(customSettings, 'canTakeCrashConsistentSnapshot')
        backupRetryCount = get_key_value(customSettings, 'backupRetryCount')
        numberOfDisks = 0
        if (para_parser.includeLunList is not None):
            numberOfDisks = len(para_parser.includeLunList)
        isAnyNone = (isManagedVm is None or canTakeCrashConsistentSnapshot is None or backupRetryCount is None)
        if (isAnyNone == False and isManagedVm == True and canTakeCrashConsistentSnapshot == True and backupRetryCount > 0 and numberOfDisks == 1):
            takeCrashConsistentSnapshot = True
        backup_logger.log("isManagedVm=" + str(isManagedVm) + ", canTakeCrashConsistentSnapshot=" + str(canTakeCrashConsistentSnapshot) + ", backupRetryCount=" + str(backupRetryCount) + ", numberOfDisks=" + str(numberOfDisks) + ", takeCrashConsistentSnapshot=" + str(takeCrashConsistentSnapshot), True, 'Info')
    return takeCrashConsistentSnapshot

def spawn_monitor(location = "", strace_pid = 0):
    """Launch the msft_snap_monit debug-helper as a child process.

    When no location is given, defaults to the extension's 'debughelper'
    directory. With strace_pid > 0 the monitor is asked to strace that
    pid. Returns the Popen handle, or None if the launch failed.
    """
    d = location
    if d == "":
        d = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        d = os.path.join(d, "debughelper")
    bd = os.path.join(d, "msft_snap_monit")
    try:
        args = [bd, "--wd", d]
        if (strace_pid > 0):
            args = [bd, "--wd", d, "--strace", "--tracepid", str(strace_pid)]
        backup_logger.log("[spawn_monitor] -> command: %s" % (" ".join(args)))
        p = subprocess.Popen(args)
        backup_logger.log("[spawn_monitor] -> monitoring started")
        return p
    except Exception as e:
        backup_logger.log("[spawn_monitor] -> subprocess Popen failed: %s" % (e));
        return None

def daemon():
    """Daemon command: the long-running backup worker (continues below)."""
    global MyPatching, backup_logger, hutil, run_result, run_status, error_msg, freezer, para_parser, snapshot_done, snapshot_info_array, g_fsfreeze_on, total_used_size, patch_class_name, orig_distro, workload_patch, configSeqNo, eventlogger
    try:
        #this is using the most recent file timestamp.
        hutil.do_parse_context('Executing', configSeqNo)
        try:
            backup_logger.log('starting daemon initially', True, "Warning")
            backup_logger.log("patch_class_name: "+str(patch_class_name)+" and orig_distro: "+str(orig_distro),True)
            # handle the restoring scenario.
mi = MachineIdentity() stored_identity = mi.stored_identity() if(stored_identity is None): mi.save_identity() else: current_identity = mi.current_identity() if(current_identity != stored_identity): current_seq_no = -1 backup_logger.log("machine identity not same, set current_seq_no to " + str(current_seq_no) + " " + str(stored_identity) + " " + str(current_identity), True) hutil.set_last_seq(current_seq_no) mi.save_identity() except Exception as e: errMsg = 'Failed to validate sequence number with error: %s, stack trace: %s' % (str(e), traceback.format_exc()) backup_logger.log(errMsg, True, 'Error') freezer = FsFreezer(patching= MyPatching, logger = backup_logger, hutil = hutil) backup_logger.log("safeFreezeBinary exists " + str(freezer.file_exists), True, 'Info') global_error_result = None # precheck freeze_called = False configfile='/etc/azure/vmbackup.conf' thread_timeout=str(60) OnAppFailureDoFsFreeze = True OnAppSuccessDoFsFreeze = True MonitorRun = False MonitorEnableStrace = False MonitorLocation = "" #Adding python version to the telemetry try: python_version_info = sys.version_info python_version = str(sys.version_info[0])+ '.' + str(sys.version_info[1]) + '.' 
+ str(sys.version_info[2]) HandlerUtil.HandlerUtility.add_to_telemetery_data("pythonVersion", python_version) except Exception as e: errMsg = 'Failed to do retrieve python version with error: %s, stack trace: %s' % (str(e), traceback.format_exc()) backup_logger.log(errMsg, True, 'Error') #fetching platform architecture try: architecture = platform.architecture()[0] HandlerUtil.HandlerUtility.add_to_telemetery_data("platformArchitecture", architecture) except Exception as e: errMsg = 'Failed to do retrieve "platform architecture" with error: %s, stack trace: %s' % (str(e), traceback.format_exc()) backup_logger.log(errMsg, True, 'Error') try: if(freezer.mounts is not None): hutil.partitioncount = len(freezer.mounts.mounts) backup_logger.log(" configfile " + str(configfile), True) config = ConfigParsers.ConfigParser() config.read(configfile) if config.has_option('SnapshotThread','timeout'): thread_timeout= config.get('SnapshotThread','timeout') if config.has_option('SnapshotThread','OnAppFailureDoFsFreeze'): OnAppFailureDoFsFreeze= config.get('SnapshotThread','OnAppFailureDoFsFreeze') if config.has_option('SnapshotThread','OnAppSuccessDoFsFreeze'): OnAppSuccessDoFsFreeze= config.get('SnapshotThread','OnAppSuccessDoFsFreeze') if config.has_option("Monitor", "Run"): MonitorRun = config.getboolean("Monitor", "Run") if config.has_option("Monitor", "Strace"): MonitorEnableStrace = config.getboolean("Monitor", "Strace") if config.has_option("Monitor", "Location"): MonitorLocation = config.get("Monitor", "Location") except Exception as e: errMsg='cannot read config file or file not present' backup_logger.log(errMsg, True, 'Warning') backup_logger.log("final thread timeout" + thread_timeout, True) # Start the monitor process if enabled monitor_process = None if MonitorRun: if MonitorEnableStrace: monitor_process = spawn_monitor(location = MonitorLocation, strace_pid=os.getpid()) else: monitor_process = spawn_monitor(location = MonitorLocation) snapshot_info_array = None try: 
# we need to freeze the file system first backup_logger.log('starting daemon for freezing the file system', True) """ protectedSettings is the privateConfig passed from Powershell. WATCHOUT that, the _context_config are using the most freshest timestamp. if the time sync is alive, this should be right. """ protected_settings = hutil._context._config['runtimeSettings'][0]['handlerSettings'].get('protectedSettings', {}) public_settings = hutil._context._config['runtimeSettings'][0]['handlerSettings'].get('publicSettings') para_parser = ParameterParser(protected_settings, public_settings, backup_logger) hutil.update_settings_file() if(para_parser.taskId is not None and para_parser.taskId != "" and eventlogger is not None): eventlogger.update_properties(para_parser.taskId) if(bool(public_settings) == False and not protected_settings): error_msg = "unable to load certificate" hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedHandlerGuestAgentCertificateNotFound) temp_result=CommonVariables.FailedHandlerGuestAgentCertificateNotFound temp_status= 'error' exit_with_commit_log(temp_status, temp_result,error_msg, para_parser) if(freezer.file_exists == False): file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), freezer.safeFreezeFolderPath) error_msg = "safefreeze binary is missing in the following path " + str(file_path) hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedSafeFreezeBinaryNotFound) temp_result=CommonVariables.FailedSafeFreezeBinaryNotFound temp_status= 'error' backup_logger.log("exiting with commit",True,"Info") exit_with_commit_log(temp_status, temp_result,error_msg, para_parser) if(para_parser.commandStartTimeUTCTicks is not None and para_parser.commandStartTimeUTCTicks != ""): canTakeCrashConsistentSnapshot = can_take_crash_consistent_snapshot(para_parser) temp_g_fsfreeze_on = True freeze_snap_shotter = FreezeSnapshotter(backup_logger, hutil, freezer, temp_g_fsfreeze_on, para_parser, 
canTakeCrashConsistentSnapshot) if (hutil.ExtErrorCode == ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedInvalidDataDiskLunList): temp_result = CommonVariables.FailedInvalidDataDiskLunList temp_status = 'error' error_msg = 'Invalid Input. IsAnyDiskExcluded is marked as true but input LUN list received from CRP is empty. '\ 'which is not allowed if VM has Direct Drives or if VM has Write Accelerated disks or if VM is a TVM/CVM.' exit_with_commit_log(temp_status, temp_result,error_msg, para_parser) if freeze_snap_shotter.is_command_timedout(para_parser) : error_msg = "CRP timeout limit has reached, will not take snapshot." errMsg = error_msg hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.FailedGuestAgentInvokedCommandTooLate) temp_result=CommonVariables.FailedGuestAgentInvokedCommandTooLate temp_status= 'error' exit_with_commit_log(temp_status, temp_result,error_msg, para_parser) hutil.save_seq() commandToExecute = para_parser.commandToExecute #validate all the required parameter here backup_logger.log('The command '+ commandToExecute+ ' is being validated',True) if(CommonVariables.iaas_install_command in commandToExecute.lower()): backup_logger.log('install succeed.',True) run_status = 'success' error_msg = 'Install Succeeded' run_result = CommonVariables.success backup_logger.log(error_msg) elif(CommonVariables.iaas_vmbackup_command in commandToExecute.lower()): if(para_parser.backup_metadata is None or para_parser.public_config_obj is None): run_result = CommonVariables.error_parameter hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.error_parameter) run_status = 'error' error_msg = 'required field empty or not correct' backup_logger.log(error_msg, True, 'Error') else: backup_logger.log('commandToExecute for backup is ' + commandToExecute, True) """ make sure the log is not doing when the file system is freezed. 
""" temp_status= 'success' temp_result=CommonVariables.ExtensionTempTerminalState temp_msg='Transitioning state in extension' blob_report_msg, file_report_msg = get_status_to_report(temp_status, temp_result, temp_msg, None) status_report_to_file(file_report_msg) status_report_to_blob(blob_report_msg) #partial logging before freeze if(para_parser is not None and para_parser.logsBlobUri is not None and para_parser.logsBlobUri != ""): backup_logger.commit_to_blob(para_parser.logsBlobUri) else: backup_logger.log("the logs blob uri is not there, so do not upload log.") backup_logger.log('commandToExecute after commiting the blob is ' + commandToExecute, True) workload_patch = WorkloadPatch.WorkloadPatch(backup_logger) #new flow only if workload name is present in workload.conf if workload_patch.name != None and workload_patch.name != "": backup_logger.log("workload backup enabled for workload: " + workload_patch.name, True) hutil.set_pre_post_enabled() pre_skipped = False if len(workload_patch.error_details) > 0: backup_logger.log("skip pre and post") pre_skipped = True else: workload_patch.pre() if len(workload_patch.error_details) > 0: backup_logger.log("file system consistent backup only") #todo error handling if len(workload_patch.error_details) > 0 and OnAppFailureDoFsFreeze == True: #App&FS consistency g_fsfreeze_on = True elif len(workload_patch.error_details) > 0 and OnAppFailureDoFsFreeze == False: # Do Fs freeze only if App success hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.error) error_msg= 'Failing backup as OnAppFailureDoFsFreeze is set to false' temp_result=CommonVariables.error temp_status= 'error' exit_with_commit_log(temp_status, temp_result,error_msg, para_parser) elif len(workload_patch.error_details) == 0 and OnAppSuccessDoFsFreeze == False: # App only g_fsfreeze_on = False elif len(workload_patch.error_details) == 0 and OnAppSuccessDoFsFreeze == True: #App&FS consistency g_fsfreeze_on = True else: g_fsfreeze_on = True 
freeze_snapshot(thread_timeout) if pre_skipped == False: workload_patch.post() workload_error = workload_patch.populateErrors() if workload_error != None and g_fsfreeze_on == False: run_status = 'error' run_result = workload_error.errorCode hutil.SetExtErrorCode(workload_error.errorCode) error_msg = 'Workload Patch failed with error message: ' + workload_error.errorMsg error_msg = error_msg + ExtensionErrorCodeHelper.ExtensionErrorCodeHelper.StatusCodeStringBuilder(hutil.ExtErrorCode) backup_logger.log(error_msg, True) elif workload_error != None and g_fsfreeze_on == True: hutil.SetExtErrorCode(workload_error.errorCode) error_msg = 'Workload Patch failed with warning message: ' + workload_error.errorMsg error_msg = error_msg + ExtensionErrorCodeHelper.ExtensionErrorCodeHelper.StatusCodeStringBuilder(hutil.ExtErrorCode) backup_logger.log(error_msg, True) else: if(run_status == CommonVariables.status_success): run_status = 'success' run_result = CommonVariables.success_appconsistent hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.success_appconsistent) error_msg = 'Enable Succeeded with App Consistent Snapshot' backup_logger.log(error_msg, True) else: error_msg = 'Enable failed in fsfreeze snapshot flow' backup_logger.log(error_msg, True) else: PluginHostObj = PluginHost(logger=backup_logger) PluginHostErrorCode,dobackup,g_fsfreeze_on = PluginHostObj.pre_check() doFsConsistentbackup = False appconsistentBackup = False if not (PluginHostErrorCode == CommonVariables.FailedPrepostPluginhostConfigParsing or PluginHostErrorCode == CommonVariables.FailedPrepostPluginConfigParsing or PluginHostErrorCode == CommonVariables.FailedPrepostPluginhostConfigNotFound or PluginHostErrorCode == CommonVariables.FailedPrepostPluginhostConfigPermissionError or PluginHostErrorCode == CommonVariables.FailedPrepostPluginConfigNotFound): backup_logger.log('App Consistent Consistent Backup Enabled', True) 
HandlerUtil.HandlerUtility.add_to_telemetery_data("isPrePostEnabled", "true") appconsistentBackup = True if(PluginHostErrorCode != CommonVariables.PrePost_PluginStatus_Success): backup_logger.log('Triggering File System Consistent Backup because of error code' + ExtensionErrorCodeHelper.ExtensionErrorCodeHelper.StatusCodeStringBuilder(PluginHostErrorCode), True) doFsConsistentbackup = True preResult = PluginHostResult() postResult = PluginHostResult() if not doFsConsistentbackup: preResult = PluginHostObj.pre_script() dobackup = preResult.continueBackup if(g_fsfreeze_on == False and preResult.anyScriptFailed): dobackup = False if dobackup: freeze_snapshot(thread_timeout) if not doFsConsistentbackup: postResult = PluginHostObj.post_script() if not postResult.continueBackup: dobackup = False if(g_fsfreeze_on == False and postResult.anyScriptFailed): dobackup = False if not dobackup: if run_result == CommonVariables.success and PluginHostErrorCode != CommonVariables.PrePost_PluginStatus_Success: run_status = 'error' run_result = PluginHostErrorCode hutil.SetExtErrorCode(PluginHostErrorCode) error_msg = 'Plugin Host Precheck Failed' error_msg = error_msg + ExtensionErrorCodeHelper.ExtensionErrorCodeHelper.StatusCodeStringBuilder(hutil.ExtErrorCode) backup_logger.log(error_msg, True) if run_result == CommonVariables.success: pre_plugin_errors = preResult.errors for error in pre_plugin_errors: if error.errorCode != CommonVariables.PrePost_PluginStatus_Success: run_status = 'error' run_result = error.errorCode hutil.SetExtErrorCode(error.errorCode) error_msg = 'PreScript failed for the plugin ' + error.pluginName error_msg = error_msg + ExtensionErrorCodeHelper.ExtensionErrorCodeHelper.StatusCodeStringBuilder(hutil.ExtErrorCode) backup_logger.log(error_msg, True) break if run_result == CommonVariables.success: post_plugin_errors = postResult.errors for error in post_plugin_errors: if error.errorCode != CommonVariables.PrePost_PluginStatus_Success: run_status = 'error' 
run_result = error.errorCode hutil.SetExtErrorCode(error.errorCode) error_msg = 'PostScript failed for the plugin ' + error.pluginName error_msg = error_msg + ExtensionErrorCodeHelper.ExtensionErrorCodeHelper.StatusCodeStringBuilder(hutil.ExtErrorCode) backup_logger.log(error_msg, True) break if appconsistentBackup: if(PluginHostErrorCode != CommonVariables.PrePost_PluginStatus_Success): hutil.SetExtErrorCode(PluginHostErrorCode) pre_plugin_errors = preResult.errors for error in pre_plugin_errors: if error.errorCode != CommonVariables.PrePost_PluginStatus_Success: hutil.SetExtErrorCode(error.errorCode) post_plugin_errors = postResult.errors for error in post_plugin_errors: if error.errorCode != CommonVariables.PrePost_PluginStatus_Success: hutil.SetExtErrorCode(error.errorCode) if run_result == CommonVariables.success and not doFsConsistentbackup and not (preResult.anyScriptFailed or postResult.anyScriptFailed): run_status = 'success' run_result = CommonVariables.success_appconsistent hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.success_appconsistent) error_msg = 'Enable Succeeded with App Consistent Snapshot' backup_logger.log(error_msg, True) else: run_status = 'error' run_result = CommonVariables.error_parameter hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.error_parameter) error_msg = 'command is not correct' backup_logger.log(error_msg, True, 'Error') except Exception as e: hutil.update_settings_file() errMsg = 'Failed to enable the extension with error: %s, stack trace: %s' % (str(e), traceback.format_exc()) backup_logger.log(errMsg, True, 'Error') global_error_result = e if monitor_process is not None: monitor_process.terminate() """ we do the final report here to get rid of the complex logic to handle the logging when file system be freezed issue. 
""" try: if(global_error_result is not None): if(hasattr(global_error_result,'errno') and global_error_result.errno == 2): run_result = CommonVariables.error_12 hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.error_12) elif(para_parser is None): run_result = CommonVariables.error_parameter hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.error_parameter) else: run_result = CommonVariables.error hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.error) run_status = 'error' error_msg += ('Enable failed.' + str(global_error_result)) status_report_msg = None hutil.SetExtErrorCode(run_result) #setting extension errorcode at the end if missed somewhere HandlerUtil.HandlerUtility.add_to_telemetery_data("extErrorCode", str(ExtensionErrorCodeHelper.ExtensionErrorCodeHelper.ExtensionErrorCodeNameDict[hutil.ExtErrorCode])) total_used_size = -1 blob_report_msg, file_report_msg = get_status_to_report(run_status,run_result,error_msg, snapshot_info_array) if(hutil.is_status_file_exists()): status_report_to_file(file_report_msg) status_report_to_blob(blob_report_msg) except Exception as e: errMsg = 'Failed to log status in extension' errMsg += str(e) backup_logger.log(errMsg, True, 'Error') if(para_parser is not None and para_parser.logsBlobUri is not None and para_parser.logsBlobUri != ""): backup_logger.commit(para_parser.logsBlobUri) else: backup_logger.log("the logs blob uri is not there, so do not upload log.") backup_logger.commit_to_local() if(eventlogger is not None): eventlogger.dispose() except Exception as e: backup_logger.log(str(e), True, 'Error') if(eventlogger is not None): eventlogger.dispose() if monitor_process is not None: monitor_process.terminate() sys.exit(0) def uninstall(): global configSeqNo hutil.do_parse_context('Uninstall', configSeqNo) hutil.do_exit(0,'Uninstall','success','0', 'Uninstall succeeded') def disable(): global configSeqNo hutil.do_parse_context('Disable', configSeqNo) 
    hutil.do_exit(0,'Disable','success','0', 'Disable Succeeded')

def update():
    """Handle the 'Update' command: nothing to migrate, just report success."""
    global configSeqNo
    hutil.do_parse_context('Update', configSeqNo)
    hutil.do_exit(0,'Update','success','0', 'Update Succeeded')

def enable():
    """Handle the 'Enable' command.

    Parses handler settings, records the task identity so the same snapshot
    task is not processed twice, reports a transitional status and then hands
    the actual backup work off to the background daemon before exiting.
    """
    global backup_logger,hutil,error_msg,para_parser,patch_class_name,orig_distro,configSeqNo,eventlogger,disable_event_logging
    try:
        hutil.do_parse_context('Enable', configSeqNo)
        backup_logger.log('starting enable', True)
        backup_logger.log("patch_class_name: "+str(patch_class_name)+" and orig_distro: "+str(orig_distro),True)
        if(disable_event_logging != 0):
            backup_logger.log("logging via guest agent is turned off")
        # Skip duplicate processing of the same sequence number, then persist it.
        hutil.exit_if_same_seq()
        hutil.save_seq()
        protected_settings = hutil._context._config['runtimeSettings'][0]['handlerSettings'].get('protectedSettings', {})
        public_settings = hutil._context._config['runtimeSettings'][0]['handlerSettings'].get('publicSettings')
        para_parser = ParameterParser(protected_settings, public_settings, backup_logger)
        try:
            # Optionally start the host-based snapshot polling service when the
            # well-known setting flag enables it.
            if CommonVariables.enableSnapshotExtensionPolling in para_parser.wellKnownSettingFlags and para_parser.wellKnownSettingFlags[CommonVariables.enableSnapshotExtensionPolling]:
                create_host_based_service()
        except Exception as e:
            backup_logger.log("error starting new host based daemon: {}".format(e), True, "Error")
        if(para_parser.taskId is not None and para_parser.taskId != ""):
            if(eventlogger is not None):
                eventlogger.update_properties(para_parser.taskId)
            backup_logger.log('taskId: ' + str(para_parser.taskId), True)
            # Random jitter (0.5s-5s) before the duplicate-task check, to spread
            # out simultaneous enables across VMs.
            randomSleepTime = random.randint(500, 5000)
            backup_logger.log('Sleeping for milliseconds: ' + str(randomSleepTime), True)
            time.sleep(randomSleepTime / 1000)
            # Bail out if this taskId was already handled, otherwise record it.
            exit_if_same_taskId(para_parser.taskId)
            taskIdentity = TaskIdentity()
            taskIdentity.save_identity(para_parser.taskId)
        # Report a transitional (non-terminal) status before daemonizing.
        temp_status= 'success'
        temp_result=CommonVariables.ExtensionTempTerminalState
        temp_msg='Transitioning state in extension'
        blob_report_msg, file_report_msg = get_status_to_report(temp_status, temp_result, temp_msg, None)
        status_report_to_file(file_report_msg)
        if(eventlogger is not None):
            eventlogger.dispose()
        start_daemon()
        sys.exit(0)
    except Exception as e:
        hutil.update_settings_file()
        errMsg = 'Failed to call the daemon with error: %s, stack trace: %s' % (str(e), traceback.format_exc())
        backup_logger.log(errMsg, True, 'Error')
        global_error_result = e
        temp_status= 'error'
        temp_result=CommonVariables.error
        hutil.SetExtErrorCode(ExtensionErrorCodeHelper.ExtensionErrorCodeEnum.error)
        error_msg = 'Failed to call the daemon'
        exit_with_commit_log(temp_status, temp_result,error_msg, para_parser)

def thread_for_log_upload():
    """Upload the accumulated extension log to the logs blob (run on a thread)."""
    global para_parser,backup_logger
    backup_logger.commit(para_parser.logsBlobUri)

def start_daemon():
    """Spawn `main/handle.sh daemon` as a detached background process."""
    args = [os.path.join(os.getcwd(), "main/handle.sh"), "daemon"]
    #This process will start a new background process by calling
    #     handle.py -daemon
    #to run the script and will exit itself immediately.
    #Redirect stdout and stderr to /dev/null. Otherwise daemon process will
    #throw a broken-pipe exception when the parent process exits.
devnull = open(os.devnull, 'w') child = subprocess.Popen(args, stdout=devnull, stderr=devnull) def can_use_systemd(): try: pso = subprocess.check_output(["systemctl", "is-system-running"]) return pso[0:7].decode("utf-8") == "running" except Exception as e: backup_logger.log("error running `systemctl is-system-running`: {}".format(e), True, 'Warning') try: pso = subprocess.check_output(["ps", "--no-headers", "-o", "comm", "1"]) return pso[0:7].decode("utf-8") == "systemd" except Exception as e: backup_logger.log("error running `ps --no-headers -o comm 1`: {}".format(e), True, "Warning") return False def create_host_based_systemd_service(): ## Create the file `/etc/systemd/system/directsnapshot.service` ## [Unit] ## Description=My test service ## After=multi-user.target ## [Service] ## Type=simple ## Restart=always ## ExecStart=/usr/bin/python3 /home//test.py ## [Install] ## WantedBy=multi-user.target systemd_service_file = "/etc/systemd/system/directsnapshot.service" script_dir = os.path.dirname(os.path.realpath(__file__)) work_dir = os.path.dirname(script_dir) script_path = os.path.join(script_dir, "handle_host_daemon.py") sys_script_path = os.path.join("main", "handle_host_daemon.py") exec_path = "" try: exec_path = sys.executable except Exception as e: backup_logger.log("error fetching python executable path: {}".format(e), True, "Error") return if exec_path == "" or exec_path is None: backup_logger.log("empty python executable path", True, "Error") return if os.path.isfile(systemd_service_file): try: subprocess.check_output(["systemctl", "stop", "directsnapshot.service"]) os.remove(systemd_service_file) except Exception as e: backup_logger.log("error removing existing systemd service: {}".format(e), True, "Error") return with open(systemd_service_file, "w", encoding="utf-8") as f: f.write("[Unit]\n") f.write("\tDescription=Snapshot service for Microsoft Azure Restore Points\n") f.write("\tAfter=multi-user.target\n") f.write("[Service]\n") 
f.write("\tType=simple\n") f.write("\tRestart=always\n") f.write("\tWorkingDirectory={}\n".format(work_dir)) f.write("\tExecStart={} {}\n".format(exec_path, sys_script_path)) f.write("[Install]\n") f.write("\tWantedBy=multi-user.target\n") # Check if pid file exists pidfile=os.path.join(work_dir, "directsnapshot.pid") if os.path.isfile(pidfile): try: opid = None with open(pidfile, "r", encoding="utf-8") as f: opid = f.read() if opid is not None and os.path.isdir(os.path.join("/proc", opid)): backup_logger.log("process exists. killing", True, "Warning") subprocess.check_output(["kill", "-9", opid]) backup_logger.log("process killed") except Exception as e: backup_logger.log("error checking for and killing daemon process: {}".format(e), True, "Error") # Daemon reload, enable and run try: subprocess.check_output(["systemctl", "daemon-reload"]) subprocess.check_output(["systemctl", "enable", "--now", "directsnapshot.service"]) except Exception as e: backup_logger.log("error running systemd service: {}".format(e), True, "Error") def create_host_based_process(): script_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) subprocess.Popen( ["./main/handle_host_daemon.sh"], cwd = script_dir, shell = True ) def create_host_based_service(): try: if can_use_systemd(): create_host_based_systemd_service() else: create_host_based_process() except Exception as e: backup_logger.log("error creating service for host based snapshots: {}".format(e), True, "Error") if __name__ == '__main__' : main() ================================================ FILE: VMBackup/main/handle.sh ================================================ #!/usr/bin/env sh pwdcommand=`pwd` pwdstr="$pwdcommand" output=`cat $pwdstr'/HandlerEnvironment.json'` outputstr="$output" poststr=${outputstr#*logFolder\"} postsubstr=${poststr#*\"} postsubstr1=${postsubstr#*\"} resultstrlen=`expr ${#postsubstr} - 1 - ${#postsubstr1}` logfolder=$(echo $postsubstr | cut -b 1-$resultstrlen) 
logfile=$logfolder'/shell.log' rc=3 arc=0 if [ "$1" = "install" ] then if [ -f "/etc/azure/workload.conf" ] then WorkloadConfEdited=`awk '/(workload_name)([ ]*[=])([ ]*[(^|\")a-zA-Z(^|\")])/' /etc/azure/workload.conf` if [ "$WorkloadConfEdited" != "" ] then #Workload.conf is edited echo "`date -u`- The command is $1, exiting without conf file copy" >> $logfile else #workload.conf is not edited cp main/workloadPatch/WorkloadUtils/workload.conf /etc/azure/workload.conf echo "`date -u`- The command is $1, exiting with conf file copy" >> $logfile fi exit $arc else mkdir -p /etc/azure cp main/workloadPatch/WorkloadUtils/workload.conf /etc/azure/workload.conf echo "`date -u`- The command is $1, exiting with conf file copy" >> $logfile exit $arc fi elif [ "$1" != "enable" ] && [ "$1" != "daemon" ] then echo "`date -u`- The command is $1, exiting" >> $logfile exit $arc fi configSeqNo="$(echo `printenv ConfigSequenceNumber`)" if [ -z ${configSeqNo} ] then configSeqNo='seqNo:-1' echo "`date -u`- ConfigSequenceNumber not found in environment variable ${configSeqNo}" >> $logfile else configSeqNo='seqNo:'$configSeqNo echo "`date -u`- ConfigSequenceNumber from environment variable ${configSeqNo}" >> $logfile fi pythonVersionList="python3.8 python3.7 python3.6 python3.5 python3.4 python3.3 python3 python2.7 python2.6 python2 python" for pythonVersion in ${pythonVersionList}; do cmnd="/usr/bin/${pythonVersion}" if [ ! -f "${cmnd}" ] then cmnd="/usr/local/bin/${pythonVersion}" fi if [ -f "${cmnd}" ] then echo "`date -u`- ${pythonVersion} path exists" >> $logfile $cmnd main/handle.py -$configSeqNo -$1 rc=$? fi if [ $rc -eq 0 ] then break fi done pythonProcess=$(ps -ef | grep waagent | grep python) pythonPath=$(echo "${pythonProcess}" | head -n1 | awk '{print $8;}') if [ $rc -ne 0 ] && [ -f "`which python`" ] then echo "`date -u`- python path exists" >> $logfile /usr/bin/env python main/handle.py -$configSeqNo -$1 rc=$? 
fi if [ $rc -ne 0 ] && [ -f "${pythonPath}" ] then echo "`date -u`- python path exists" >> $logfile $pythonPath main/handle.py -$configSeqNo -$1 rc=$? fi if [ $rc -eq 3 ] then echo "`date -u`- python version unknown" >> $logfile fi echo "`date -u`- $rc returned from handle.py" >> $logfile exit $rc ================================================ FILE: VMBackup/main/handle_host_daemon.py ================================================ #!/usr/bin/env python import time import os import threading import signal import sys import json from Utils.WAAgentUtil import waagent from Utils import HandlerUtil import datetime from common import CommonVariables import subprocess import traceback from datetime import datetime IS_PYTHON3 = sys.version_info[0] == 3 if IS_PYTHON3: import configparser as ConfigParsers else: import ConfigParser as ConfigParsers if IS_PYTHON3: from urllib import request else: import urllib2 as request if IS_PYTHON3: from urllib.error import HTTPError else: from urllib2 import HTTPError if IS_PYTHON3: import urllib.parse as urllib else: import urllib SCRIPT_DIR=os.path.dirname(os.path.realpath(__file__)) BASE_URI="http://168.63.129.16" STORAGE_DEVICE_PATH = '/sys/bus/vmbus/devices/' GEN2_DEVICE_ID = 'f8b3781a-1e82-4818-a1c3-63d806ec15bb' # LOCK_FILE_DIR="/etc/azure/MicrosoftRecoverySvcsSafeFreezeLock" # LOCK_FILE="/etc/azure/MicrosoftRecoverySvcsSafeFreezeLock/SafeFreezeLockFile" # LOCK_FILE_NAME="SafeFreezeLockFile" SNAPSHOT_INPROGRESS = False class HandlerContext: def __init__(self,name): self._name = name self._version = '0.0' return def read_file(filepath): """ Read and return contents of 'filepath'. 
""" mode = 'rb' with open(filepath, mode) as in_file: data = in_file.read().decode('utf-8') return data class Handler: _log = None _error = None def __init__(self, log, error, short_name): self._context = HandlerContext(short_name) self._log = log self._error = error self.eventlogger = None self.log_message = "" handler_env_file = './HandlerEnvironment.json' if not os.path.isfile(handler_env_file): self.error("[handle_host_daemon.py] -> Unable to locate " + handler_env_file) return None ctxt = waagent.GetFileContents(handler_env_file) if ctxt == None : self.error("[handle_host_daemon] -> Unable to read " + handler_env_file) try: handler_env = json.loads(ctxt) except: pass if handler_env == None : self.log("JSON error processing " + handler_env_file) return None if type(handler_env) == list: handler_env = handler_env[0] self._context._name = handler_env['name'] self._context._version = str(handler_env['version']) self._context._config_dir = handler_env['handlerEnvironment']['configFolder'] self._context.log_dir = handler_env['handlerEnvironment']['logFolder'] self._context.log_file = os.path.join(self._context.log_dir,'host_based_extension.log') self.logging_file=self._context.log_file def _get_log_prefix(self): return '[%s-%s]' % (self._context._name, self._context._version) def get_value_from_configfile(self, key): value = None configfile = '/etc/azure/vmbackup.conf' try : if os.path.exists(configfile): config = ConfigParsers.ConfigParser() config.read(configfile) if config.has_option('SnapshotThread',key): value = config.get('SnapshotThread',key) except Exception as e: pass return value def get_strvalue_from_configfile(self, key, default): value = self.get_value_from_configfile(key) if value == None or value == '': value = default try : value_str = str(value) except ValueError : self.log('Not able to parse the read value as string, falling back to default value', 'Warning') value = default return value def get_intvalue_from_configfile(self, key, default): value = 
default value = self.get_value_from_configfile(key) if value == None or value == '': value = default try : value_int = int(value) except ValueError : self.log('Not able to parse the read value as int, falling back to default value', 'Warning') value = default return int(value) def log(self, message, level='Info'): print("[Handler.log] -> Level: {} -> {}".format(level, message)) try: self.log_with_no_try_except(message, level) except IOError: pass except Exception as e: try: errMsg = str(e) + 'Exception in hutil.log' self.log_with_no_try_except(errMsg, 'Warning') except Exception as e: pass def log_with_no_try_except(self, message, level='Info'): WriteLog = self.get_strvalue_from_configfile('WriteLog','True') if (WriteLog == None or WriteLog == 'True'): if sys.version_info > (3,): if self.logging_file is not None: self.log_py3(message) if self.eventlogger != None: self.eventlogger.trace_message(level, message) else: pass else: self._log(self._get_log_prefix() + message) if self.eventlogger != None: self.eventlogger.trace_message(level, message) message = "{0} {1} {2} \n".format(str(datetime.datetime.utcnow()) , level , message) self.log_message = self.log_message + message def log_py3(self, msg): if type(msg) is not str: msg = str(msg, errors="backslashreplace") msg = str(datetime.datetime.utcnow()) + " " + str(self._get_log_prefix()) + msg + "\n" try: with open(self.logging_file, "a+") as C : C.write(msg) except IOError: pass def error(self, message): self._error(self._get_log_prefix() + message) class InvalidSnapshotRequestInitError(Exception): def __init__(): super().__init__("Snapshot request object intialized incorrectly") # class AcquireSnapshotLockError(Exception): # def __init__(): # super().__init__("Failed to acquire snapshot lock") class GetMountsError(Exception): def __init__(message = ""): super().__init__("[SnapshotRequest.get_mounts] -> failed: {}".format(message)) def print_from_thread(msg): os.write(sys.stdout.fileno(), msg.encode("utf-8")) def 
thread_for_binary(self,args): print_from_thread("[FreezeHandler.thread_for_binary] -> Thread for binary is called: {}\n".format(args)) time.sleep(3) print_from_thread("[FreezeHandler.thread_for_binary] -> Waited in thread for 3 seconds\n") print_from_thread("[FreezeHandler.thread_for_binary] -> ****** 1. Starting Freeze Binary \n") self.child = subprocess.Popen(args,stdout=subprocess.PIPE) print_from_thread("Binary subprocess Created\n") class FreezeHandler(object): def __init__(self,handler): # sig_handle valid values(0:nothing done,1: freezed successfully, 2:freeze failed) self.sig_handle = 0 self.child = None self.handler = handler def sigusr1_handler(self, signal, frame): print_from_thread('[FreezeHandler.sigusr1_handler] -> freezed\n') print_from_thread("[FreezeHandler.sigusr1_handler] -> ****** 4. Freeze Completed (Signal=1 received)\n") self.sig_handle=1 def sigchld_handler(self, signal, frame): print_from_thread('[FreezeHandler.sigchld_handler] -> some child process terminated\n') if(self.child is not None and self.child.poll() is not None): print_from_thread("[FreezeHandler.sigchld_handler] -> binary child terminated\n") print_from_thread("[FreezeHandler.sigchld_handler] -> ****** 9. 
Binary Process completed (Signal=2 received)\n") self.sig_handle=2 def reset_signals(self): self.sig_handle = 0 self.child = None def startproc(self,args): binary_thread = threading.Thread(target=thread_for_binary, args=[self, args]) binary_thread.start() SafeFreezeWaitInSecondsDefault = 66 proc_sleep_time = self.handler.get_intvalue_from_configfile('SafeFreezeWaitInSeconds',SafeFreezeWaitInSecondsDefault) for i in range(0,(int(proc_sleep_time/2))): if(self.sig_handle==0): print("[FreezeHandler.startproc] -> inside loop with sig_handle "+str(self.sig_handle)) time.sleep(2) else: break print("[FreezeHandler.startproc] -> Binary output for signal handled: "+str(self.sig_handle)) return self.sig_handle def signal_receiver(self): signal.signal(signal.SIGUSR1,self.sigusr1_handler) signal.signal(signal.SIGCHLD,self.sigchld_handler) class SnapshotRequest: def __init__(self, handler, data): global SNAPSHOT_INPROGRESS, BASE_URI, GEN2_DEVICE_ID self.freeze_handler = FreezeHandler(handler) self.freeze_start = datetime.utcnow() self.freeze_safe_active = False if isinstance(handler, Handler): self.handler = handler # MY_PATCHING, PATCH_CLASS_NAME, ORIG_DISTRO = GetMyPatching(handler) else: raise InvalidSnapshotRequestInitError if "snapshotId" in data and isinstance(data["snapshotId"], str): self.snapshotId = data["snapshotId"] else: raise InvalidSnapshotRequestInitError if "luns" in data and isinstance(data["luns"], list): self.luns = data["luns"] # else: # raise InvalidSnapshotRequestInitError if "extensionSettings" in data and isinstance(data["extensionSettigns"], dict): self.extensionSettings = {} es = data["extensionSettings"] if "public" in es and isinstance(es["public"], dict): self.extensionSettings["public"] = es["public"] else: self.extensionSettings["public"] = {} if "protected" in es and isinstance(es["protected"], dict): self.extensionSettings["protected"] = {} pro = es["protected"] if "loggingBlobSasUri" in pro and isinstance(pro["loggingBlobSasUri"], str): 
self.extensionSettings.protected["loggingBlobSasUri"] = pro["loggingBlobSasUri"] # else: # raise InvalidSnapshotRequestInitError if "statusBlobSasUri" in pro and isinstance(pro["statusBlobSasUri"], str): self.extensionSettings.protected["statusBlobSasUri"] = pro["statusBlobSasUri"] # else: # raise InvalidSnapshotRequestInitError # else: # raise InvalidSnapshotRequestInitError if "ProtectedSettingsCertThumbprint" in data and isinstance(data["ProtectedSettingsCertThumbprint"], str): self.ProtectedSettingsCertThumbprint = data["ProtectedSettingsCertThumbprint"] # else: # raise InvalidSnapshotRequestInitError self.__data = data # def acquire_snapshot_lock(self): # try: # if not os.path.isdir('/etc/azure'): # os.mkdir('/etc/azure') # if not os.path.isdir(LOCK_FILE_DIR): # if not os.path.isfile(LOCK_FILE_DIR): # os.mkdir(LOCK_FILE_DIR) # else: # os.remove(LOCK_FILE_DIR) # os.mkdir(LOCK_FILE_DIR) # self.safeFreezelockFile = open(LOCK_FILE,"w") # try: # fcntl.lockf(self.safeFreezelockFile, fcntl.LOCK_EX | fcntl.LOCK_NB) # self.isAcquiredLock = True # return True # except Exception as e: # self.handler.log("[lock_snapshot_file] -> fcntl.lockf has failed: ", e) # self.safeFreezelockFile.close() # except Exception as e: # self.handler.log("[lock_snapshot_file] -> Unexpected error occured: ", e) # return False # def release_snapshot_lock(self): # try: # if (self.isAquireLock == True): # try: # fcntl.lockf(self.safeFreezelockFile, fcntl.LOCK_UN) # self.safeFreezelockFile.close() # except Exception as e: # self.handler.log("Failed to unlock: %s, stack trace: %s" % (str(e), traceback.format_exc()),True) # try: # os.remove(LOCK_FILE) # except Exception as e: # self.handler.log( # "Failed to delete %s file:\nException:\n%s\nStack Trace:\n%s" % # LOCK_FILE, str(e), traceback.format_exc()) # except Exception as e: # self.handler.log("[release_snapshot_lock] -> unexpected error occurred: ", e) # return False # Ignores usb devices # TODO: suppport lvm setup def get_block_devices(self): 
p1 = subprocess.Popen(["lsblk", "-dnl", "-o", "NAME"], stdout=subprocess.PIPE) p2 = subprocess.check_output(["grep", "-E", "(sd|nvme)"], stdin=p1.stdout).decode("utf-8") p1.stdout.close() disks = [] for x in p2.split("\n"): # print("device: {}".format(x)) if not x.strip(): continue if not self.is_usb("/dev/{}".format(x)): disks.append(x) return disks def is_usb(self, device): # lsblk -dnl -o NAME | grep 'sd' # udevadm info /dev/sda --query=property | grep ID_BUS p1 = subprocess.Popen(["udevadm", "info", device, "--query=property"], stdout=subprocess.PIPE) p2 = subprocess.check_output(["grep", 'ID_BUS'], stdin=p1.stdout).decode("utf-8") p1.stdout.close() return p2.endswith("=usb") @staticmethod def _enumerate_device_id(): """ Enumerate all storage device IDs. Args: None Returns: Iterator[Tuple[str, str]]: VmBus and storage devices. """ if os.path.exists(STORAGE_DEVICE_PATH): for vmbus in os.listdir(STORAGE_DEVICE_PATH): deviceid = read_file(filepath=os.path.join(STORAGE_DEVICE_PATH, vmbus, "device_id")) guid = deviceid.strip('{}\n') yield vmbus, guid @staticmethod def search_for_resource_disk(gen1_device_prefix, gen2_device_id): """ Search the filesystem for a device by ID or prefix. Args: gen1_device_prefix (str): Gen1 resource disk prefix. gen2_device_id (str): Gen2 resource device ID. Returns: str: The found device. """ device = None # We have to try device IDs for both Gen1 and Gen2 VMs. #ResourceDiskUtil.logger.log('Searching gen1 prefix {0} or gen2 {1}'.format(gen1_device_prefix, gen2_device_id),True) try: # pylint: disable=R1702 for vmbus, guid in SnapshotRequest._enumerate_device_id(): if guid.startswith(gen1_device_prefix) or guid == gen2_device_id: for root, dirs, files in os.walk(STORAGE_DEVICE_PATH + vmbus): # pylint: disable=W0612 root_path_parts = root.split('/') # For Gen1 VMs we only have to check for the block dir in the # current device. But for Gen2 VMs all of the disks (sda, sdb, # sr0) are presented in this device on the same SCSI controller. 
# Because of that we need to also read the LUN. It will be: # 0 - OS disk # 1 - Resource disk # 2 - CDROM if root_path_parts[-1] == 'block' and ( # pylint: disable=R1705 guid != gen2_device_id or root_path_parts[-2].split(':')[-1] == '1'): device = dirs[0] return device else: # older distros for d in dirs: # pylint: disable=C0103 if ':' in d and "block" == d.split(':')[0]: device = d.split(':')[1] return device except (OSError, IOError) as exc: err_msg='Error getting device for %s or %s: %s , Stack Trace: %s' % (gen1_device_prefix, gen2_device_id, str(exc),traceback.format_exc()) return None def device_for_ide_port(self): """ Return device name attached to ide port 'n'. gen1 device prefix is the prefix of the file name in which the resource disk partition is stored eg sdb gen1 is for new distros In old distros the directory name which contains resource disk partition is assigned to gen2 device id """ g0 = "00000000" gen1_device_prefix = '{0}-0001'.format(g0) self.handler.log( '[SnapshostRequest.device_for_ide_port] -> Searching gen1 prefix {0} or gen2 {1}'.format( gen1_device_prefix, GEN2_DEVICE_ID )) device = self.search_for_resource_disk( gen1_device_prefix=gen1_device_prefix, gen2_device_id=GEN2_DEVICE_ID ) self.handler.log('[SnapshotRequest.device_for_ide_port] -> Found device: {0}'.format(device)) return device def get_resource_disk_mount_point(self,option=1): # pylint: disable=R0912,R0914 try: """ if option = 0 then partition will be returned eg sdb1 if option = 1 then mount point will be returned eg /mnt/resource """ device = self.device_for_ide_port() if device is None: self.handler.log('unable to detect disk topology',True,'Error') partition = None if device is not None: partition = "{0}{1}".format(device,"1") #assuming only one resourde disk partition self.handler.log("Resource disk partition: {0} ".format(partition),True) if(option==0): return partition # find name of mount using: # grep -E "^/dev/sdb1" /proc/mounts | awk '{print $2}' # print("Found 
partition: {}".format(partition)) if partition is not None: p1 = subprocess.Popen(["grep", "-E", "^/dev/{}".format(partition), "/proc/mounts"], stdout=subprocess.PIPE) p2 = subprocess.check_output(["awk", '{print $2}'], stdin=p1.stdout).decode("utf-8") p1.stdout.close() v = [x for x in p2.split("\n") if x.strip()] if len(v) > 0: # print("Returning v[0]: {}".format(v[0])) return v[0] except Exception as e: self.handler.log( "[SnapshotRequest.get_resource_disk_mountpoint] -> unexpected error occured: {}\n{}".format(e, traceback.format_exc()) ) return None def get_mounts(self): try: resource_mount = self.get_resource_disk_mount_point() p1 = subprocess.Popen(["mount", "-l"], stdout=subprocess.PIPE) p2 = subprocess.Popen(["grep", "-E", "(ext4|ext3|btrfs|xfs)"], stdin=p1.stdout, stdout=subprocess.PIPE) p3 = subprocess.check_output(["awk", '{print $1" "$3}'], stdin=p2.stdout).decode("utf-8") p1.stdout.close() p2.stdout.close() # print("p3: {}".format(p3)) disks = self.get_block_devices() # print("disks: {}".format(disks)) def is_valid_mount(partition,mount_point): if resource_mount is not None and mount_point.strip() == resource_mount: return False # lsblk -ndo pkname /dev/sda1 disk = subprocess.check_output(["lsblk", "-ndo", "pkname", partition]).decode("utf-8") disk = " ".join(disk.split()) # removing any trailing or preceding newlines # print("[is_valid_disk] -> if disk: {} exists in list: {}".format(disk, disks)) if disk not in disks: return False return True mounts = [] for m in p3.split("\n"): if not m.strip(): continue m = " ".join(m.split()) # removing any preceding or trailing new lines v = m.split() # print("Post split: {}".format(v)) if len(v) != 2: continue partition = v[0] mount_point = v[1] # print("[get_mounts] -> Checking mount: {}".format(mount_point)) # print("[get_mounts] -> Checking partition: {}".format(partition)) if not is_valid_mount(partition, mount_point): continue mounts.append(mount_point) print("Mounts: {}".format(mounts)) return mounts except 
Exception as e: self.handler.log("[SnapshotRequest.get_mounts] -> Unexpected error: {}".format(e)) raise GetMountsError(traceback.format_exc()) def safefreeze_path(self): p = os.path.join(os.getcwd(),os.path.dirname(__file__),"safefreeze/bin/safefreeze") machine = os.uname()[-1] if IS_PYTHON3: machine = os.uname().machine if machine is not None and (machine.startswith("arm64") or machine.startswith("aarch64")): p = "safefreezeArm64/bin/safefreeze" return p def log_binary_output(self): print( "[SnapshotRequest.log_binary_output] -> ============== Binary output traces start ================= " ) while True: line=self.freeze_handler.child.stdout.readline() if IS_PYTHON3: line = str(line, encoding='utf-8', errors="backslashreplace") else: line = str(line) if("[SnapshotRequest.log_binary_output] -> Failed to open:" in line): self.mount_open_failed = True if(line != ''): self.handler.log(line.rstrip(), True) else: break print( "[SnapshotRequest.log_binary_output] -> ============== Binary output traces end ================= " ) def freeze_safe(self, args): errors = [] error_codes = [] timedout = False try: self.freeze_handler.reset_signals() self.freeze_handler.signal_receiver() sig_handle = self.freeze_handler.startproc(args) # self.handler.log( # "[SnapshotRequest.freeze_safe] -> freeze_safe after returning from startproc : sig_handle={}".format(str(sig_handle)) # ) print("[SnapshotRequest.freeze_safe] -> freeze_safe after returning from startproc : sig_handle={}".format(str(sig_handle))) if(sig_handle != 1): if (self.freeze_handler.child is not None): print("[SnapshotRequest.freeze_safe] -> calling log_binary_output") self.log_binary_output() if (sig_handle == 0): timedout = True error_msg="freeze timed-out" errors.append(error_msg) error_codes.append("FREEZE_TIMED_OUT") self.handler.log(error_msg) # elif (self.mount_open_failed == True): # error_msg=CommonVariables.unable_to_open_err_string # errors.append(error_msg) # self.handler.log(error_msg) # elif 
    def freeze_safe(self, args):
        """Launch the safefreeze binary and wait for its freeze handshake.

        args: argv list — [binary, timeout, mount1, mount2, ...].
        Returns (errors, error_codes, timedout); all empty/False on success.
        sig_handle semantics (from FreezeHandler.startproc): 1 = frozen OK,
        0 = timed out, anything else = partial/failed freeze.
        Side effect: always stamps self.freeze_start_time on exit.
        """
        errors = []
        error_codes = []
        timedout = False
        try:
            # Re-arm signal handling before spawning the child.
            self.freeze_handler.reset_signals()
            self.freeze_handler.signal_receiver()
            sig_handle = self.freeze_handler.startproc(args)
            # self.handler.log(
            #     "[SnapshotRequest.freeze_safe] -> freeze_safe after returning from startproc : sig_handle={}".format(str(sig_handle))
            # )
            print("[SnapshotRequest.freeze_safe] -> freeze_safe after returning from startproc : sig_handle={}".format(str(sig_handle)))
            if(sig_handle != 1):
                # Freeze did not fully succeed: drain the child's stdout for diagnostics.
                if (self.freeze_handler.child is not None):
                    print("[SnapshotRequest.freeze_safe] -> calling log_binary_output")
                    self.log_binary_output()
                if (sig_handle == 0):
                    timedout = True
                    error_msg="freeze timed-out"
                    errors.append(error_msg)
                    error_codes.append("FREEZE_TIMED_OUT")
                    self.handler.log(error_msg)
                # elif (self.mount_open_failed == True):
                #     error_msg=CommonVariables.unable_to_open_err_string
                #     errors.append(error_msg)
                #     self.handler.log(error_msg)
                # elif (self.isAquireLockSucceeded == False):
                #     error_msg="Mount Points already freezed by some other processor"
                #     errors.append(error_msg)
                #     self.handler.log(error_msg)
                else:
                    error_msg="freeze failed for some mount"
                    errors.append(error_msg)
                    error_codes.append("INCOMPLETE_FREEZE")
                    self.handler.log(error_msg)
        except Exception as e:
            # self.logger.enforce_local_flag(True)
            error_msg='freeze failed for some mount with exception, Exception %s, stack trace: %s' % (str(e), traceback.format_exc())
            errors.append(error_msg)
            error_codes.append("UNEXPECTED_FREEZE_EXC")
            self.handler.log(error_msg)
        finally:
            # Recorded even on failure so thaw-side bookkeeping has a timestamp.
            self.freeze_start_time = datetime.utcnow()
        return errors, error_codes, timedout
Sending Thaw Signal to Binary") self.freeze_handler.child.send_signal(signal.SIGUSR1) # Will try for 30 seconds to see if freeze process has stopped for i in range(0,30): if(self.freeze_handler.child.poll() is None): print("child still running sigusr1 sent") time.sleep(1) else: break print("[SnapshotRequest.thaw_safe] -> calling log_binary_output: 1") self.log_binary_output() if(self.freeze_handler.child.returncode != 0): error_msg = '[SnapshotRequest.thaw_safe] -> snapshot result inconsistent as child returns with failure' errors.append(error_msg) print(error_msg, True, 'Error') else: self.handler.log("[SnapshotRequest.thaw_safe] -> Binary output after process end when no thaw sent: ", True) if(self.freeze_handler.child.returncode == 2): error_msg = '[SnapshotRequest.thaw_safe] -> Unable to execute sleep' errors.append(error_msg) unable_to_sleep = True else: error_msg = '[SnapshotRequest.thaw_safe] -> snapshot result inconsistent' errors.append(error_msg) print("[SnapshotRequest.thaw_safe] -> calling log_binary_output: 2") self.log_binary_output() print(error_msg, True, 'Error') finally: self.freeze_end_time = datetime.utcnow() return errors, unable_to_sleep # Uses safe_freeze binary which depends on fsfreeze # TODO: support LVM when present def freeze_mounts(self): errors = [] error_codes = [] try: timeout = self.handler.get_intvalue_from_configfile('timeout','60') args = [self.safefreeze_path(), str(timeout)] mounts = self.get_mounts() if len(mounts) == 0: self.handler.log("[SnapshotRequest.freeze_mounts] -> nothing to freeze") return False for mount in mounts: args.append(mount) errors, error_codes, timedout = self.freeze_safe(args) if len(errors) == 0 and not timedout: self.freeze_start_time = datetime.utcnow() self.freeze_safe_active = True except GetMountsError as gme: self.handler.log("[SnapshotRequest.freeze_mounts] -> get_mounts failed: {}\n{}".format(gme, traceback.format_exc())) except Exception as e: self.handler.log("[SnapshotRequest.freeze_mounts] -> 
unexpected error occured: {}\n{}".format(e, traceback.format_exc())) return errors, error_codes def start_snapshot(self, error_code = None, error_message = None): print("[SnapshotRequest.start_snapshot] -> Fired") errors = [] try: payload = { "snapshotId": self.snapshotId, "errMsg": "" } if error_code is not None: payload["error"] = { "code": error_code if isinstance(error_code, str) else "", "message": error_message if isinstance(error_message, str) else "", } payload["errMsg"] = error_message if isinstance(error_message, str) else "" # if IS_PYTHON3: # data = urllib.urlencode(payload).encode("utf-8") # else: # data = urllib.urlencode(payload) # print("[SnapshotRequest.start_snapshot] -> Data:{}".format(data)) if IS_PYTHON3: data = json.dumps(payload).encode("utf-8") print("[SnapshotRequest.start_snapshot] -> Data: {}".format(data)) r = request.Request( url = "{}/machine/plugins?comp=xdisksvc&type=startsnapshot".format(BASE_URI), headers = { "Content-Type": "application/json; charset=utf-8", "Content-Length": len(data), } ) else: data = json.dumps(payload) print("[SnapshotRequest.start_snapshot] -> Data: {}".format(data)) r = request.Request( url = "{}/machine/plugins?comp=xdisksvc&type=startsnapshot".format(BASE_URI), headers = { "Content-Type": "application/json; charset=utf-8" } ) conn = request.urlopen(r, timeout = 10, data = data) print("[SnapshotRequest.start_snapshot] -> Request: {}".format(r)) # if IS_PYTHON3: # conn = request.urlopen(r, timeout = 10, data = data) # else: # conn = request.urlopen(r, timeout = 10) if conn.status != 200: resp = conn.read() print("[SnapshotRequest.start_snapshot] -> unexpected status code:{}, Body: {}".format(conn.status, resp)) errors.append("STARTSNAP_UNEXPECTED_STATUS_{}".format(conn.status)) except HTTPError as herr: print("[SnapshotRequest.start_snapshot] -> startsnapshot request failed with status: {}, reason: {}".format(herr.code, herr.reason)) errors.append("STARTSNAP_HTTP_ERR") except Exception as e: 
    def end_snapshot(self, payload):
        """POST the snapshot completion payload to the xdisksvc publishsnapshot endpoint.

        payload: JSON-serializable dict (snapshotId, errMsg, optional error).
        Returns a list of error-code strings; empty on HTTP 200.
        """
        errors = []
        try:
            # if IS_PYTHON3:
            #     data = urllib.urlencode(payload).encode("utf-8")
            # else:
            #     data = urllib.urlencode(payload)
            # print("[SanpshotRequest.end_snapshot] -> Data:{}".format(data))
            if IS_PYTHON3:
                # Python 3 urlopen wants bytes for POST data.
                data = json.dumps(payload).encode("utf-8")
                print("[SnapshotRequest.end_snapshot] -> Data: {}".format(data))
                r = request.Request(
                    url = "{}/machine/plugins?comp=xdisksvc&type=publishsnapshot".format(BASE_URI),
                    headers = {
                        "Content-Type": "application/json",
                        "Content-Length": len(data)
                    }
                )
            else:
                data = json.dumps(payload)
                print("[SnapshotRequest.end_snapshot] -> Data: {}".format(data))
                r = request.Request(
                    url = "{}/machine/plugins?comp=xdisksvc&type=publishsnapshot".format(BASE_URI),
                    headers = {
                        "Content-Type": "application/json"
                    }
                )
            conn = request.urlopen(r, timeout = 10, data = data)
            # if IS_PYTHON3:
            #     conn = request.urlopen(r, timeout = 10, data = data)
            # else:
            #     conn = request.urlopen(r, timeout = 10)
            # NOTE(review): conn.status is Python 3 only; Python 2 responses
            # expose getcode() — confirm the py2 path is still exercised.
            if conn.status != 200:
                resp = conn.read()
                print("[SnapshotRequest.end_snapshot] -> unexpected status code: {}, Body: {}".format(conn.status, resp))
                errors.append("ENDSNAP_UNEXPECTED_STATUS_{}".format(conn.status))
        except HTTPError as herr:
            print("[SnapshotRequest.end_snapshot] -> unexpected status code: {}, reason: {}".format(herr.code, herr.reason))
            errors.append("ENDSNAP_UNEXPECTED_STATUS_{}".format(herr.code))
        except Exception as e:
            print("[SnapshotRequest.end_snapshot] -> unexpected error occured: {}\n{}".format(e, traceback.format_exc()))
            errors.append("ENDSNAP_UNEXPECTED_EXC")
        return errors
    def take_snapshot(self):
        """Orchestrate one snapshot: freeze mounts, call startsnapshot, thaw, publish.

        Sequence: freeze_mounts -> start_snapshot (with error details if the
        freeze failed) -> thaw_safe (always, via finally) -> end_snapshot with
        an app-consistency verdict (freeze-to-remote-call window under 9s).
        """
        # self.freeze_start = datetime.utcnow()
        print("[SnapshotRequest.take_snapshot] -> Fired")
        # 0 is a sentinel meaning "never frozen"; replaced by a datetime on success.
        frozen_at = 0
        call_remote_end = 0
        remote_call_success = False
        snapshot_error_code = None
        snapshot_error_msg = None
        try:
            errors, error_codes = self.freeze_mounts()
            x_errors = []
            if len(errors) > 0:
                print("[Snapshot_Request.take_snapshot] -> self.freeze_mounts() failed")
                print("{}".format("\n".join(errors)))
                x_errors = self.start_snapshot(error_code = error_codes[0], error_message = errors[0])
                snapshot_error_code = error_codes[0]
                snapshot_error_msg = errors[0]
            else:
                print("[Snapshot_Request.take_snapshot] -> self.freeze_mounts() success")
                frozen_at = datetime.utcnow()
                x_errors = self.start_snapshot()
            if len(x_errors) > 0:
                print("[Snapshot_Request.take_snapshot] -> calling xdisksvc failed with: {}".format(x_errors[0]))
                snapshot_error_code = x_errors[0]
                snapshot_error_msg = snapshot_error_code
            else:
                print("[Snapshot_Request.take_snapshot] -> calling xdisksvc succeeded")
                # NOTE(review): when freeze failed but start_snapshot succeeded,
                # remote_call_success becomes True while frozen_at is still the
                # int sentinel 0 — frozen_at.timestamp() below would raise
                # AttributeError. Confirm whether that path is reachable.
                remote_call_success = True
        except Exception as e:
            print("[SnapshotRequest.take_snapshot] -> unexpected exception: {}\n{}".format(e, traceback.format_exc()))
            snapshot_error_code = "UNEXPECTED_SNAPSHOT_EXC"
            snapshot_error_msg = str(e)
        finally:
            call_remote_end = datetime.utcnow()
            # Thaw must run no matter what happened above.
            self.thaw_safe()
            print("[SnapshotRequest.take_snapshot] -> thaw_safe executed successfully")
        print("[SnapshotRequest.take_snapshot] -> Outta try catch!")
        body = {
            "snapshotId": self.snapshotId,
            "errMsg": "",
            # "consistencyMode": "App",
        }
        # App-consistent only if the remote call landed within 9s of the freeze.
        # NOTE(review): datetime.timestamp() is Python 3 only — confirm py2 support.
        if remote_call_success and (call_remote_end.timestamp() - frozen_at.timestamp()) < 9:
            print("[SanpshotRequest.take_snapshot] -> app consistency verified")
        else:
            print("[SnapshotRequest.take_snapshot] -> app consistency validation failed")
            body["error"] = {
                "code": snapshot_error_code,
                "message": snapshot_error_msg
            }
            body["errMsg"] = snapshot_error_msg
        self.end_snapshot(body)
def main():
    """Daemon entry point: poll xdisksvc for snapshot requests every 5 minutes.

    200 -> run a snapshot; 404 -> no pending request; anything else is logged.
    The sleep is phase-locked to `starttime` so iterations stay on a fixed
    300-second grid regardless of how long each snapshot takes.
    """
    global SCRIPT_DIR
    HandlerUtil.waagent.LoggerInit('/dev/console','/dev/stdout')
    handler = Handler(HandlerUtil.waagent.Log, HandlerUtil.waagent.Error, CommonVariables.extension_name)
    starttime = time.time()
    while True:
        try:
            res = get_snapshot_requests(handler)
            print("[main] -> res: {}".format(res))
            if res["statusCode"] == 200:
                take_new_snapshot(handler, res["data"])
            elif res["statusCode"] == 404:
                handler.log("[main] -> no new snapshot requests at this time")
            else:
                handler.log("[main] -> invalid response code: ", res["statusCode"])
        except Exception as e:
            handler.log("[main] -> Unexpected expcetion occured", e)
        # Sleep until the next 300s boundary relative to process start.
        time.sleep(300.0 - ((time.time() - starttime) % 300.0))

if __name__ == '__main__' :
    main()
${pythonVersionList}; do cmnd="/usr/bin/${pythonVersion}" if [ -f "${cmnd}" ]; then echo "[$(date -u +"%F %H:%M:%S:%N")] ${pythonVersion} path exists" >> $logfile nohup $cmnd main/handle_host_daemon.py & pid=$(ps --pid $! | tail -1 | awk '{ print $1 }') echo $pid | tee $PIDFILE if echo $pid | grep -Eq '^[0-9]+$'; then rc=0 fi fi if [ $rc -eq 0 ] then break fi done if [ $rc -ne 0 ] && [ -f "`which python`" ] then echo "[$(date -u +"%F %H:%M:%S:%N")] python path exists" >> $logfile nohup /usr/bin/env python main/handle_host_daemon.py & pid=$(ps --pid $! | tail -1 | awk '{ print $1 }') echo $pid | tee $PIDFILE if echo $pid | grep -Eq '^[0-9]+$'; then rc=0 fi fi if [ $rc -ne 0 ] && [ -f "${pythonPath}" ] then echo "[$(date -u +"%F %H:%M:%S:%N")] python path exists" >> $logfile nohup $pythonPath main/handle_host_daemon.py & pid=$(ps --pid $! | tail -1 | awk '{ print $1 }') echo $pid | tee $PIDFILE if echo $pid | grep -Eq '^[0-9]+$'; then rc=0 fi fi if [ $rc -eq 3 ] then echo "[$(date -u +"%F %H:%M:%S:%N")] python version unknown" >> $logfile fi echo "[$(date -u +"%F %H:%M:%S:%N")] $rc returned from handle_host_daemon.py" >> $logfile exit $rc ================================================ FILE: VMBackup/main/hostsnapshotter.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os try: import urlparse as urlparser except ImportError: import urllib.parse as urlparser import traceback import datetime try: import ConfigParser as ConfigParsers except ImportError: import configparser as ConfigParsers import multiprocessing as mp import json from common import CommonVariables from HttpUtil import HttpUtil from Utils import Status from Utils import HostSnapshotObjects from Utils import HandlerUtil from fsfreezer import FsFreezer import sys class HostSnapshotter(object): """description of class""" def __init__(self, logger, hostIp): self.logger = logger self.configfile='/etc/azure/vmbackup.conf' self.snapshoturi = 'http://' + hostIp + '/metadata/recsvc/snapshot/dosnapshot?api-version=2017-12-01' self.presnapshoturi = 'http://' + hostIp + '/metadata/recsvc/snapshot/presnapshot?api-version=2017-12-01' self.hutil = HandlerUtil.HandlerUtility(HandlerUtil.waagent.Log, HandlerUtil.waagent.Error, CommonVariables.extension_name) def snapshotall(self, paras, freezer, g_fsfreeze_on, taskId): result = None blob_snapshot_info_array = [] all_failed = True is_inconsistent = False unable_to_sleep = False meta_data = paras.backup_metadata if(self.snapshoturi is None): self.logger.log("Failed to do the snapshot because snapshoturi is none",False,'Error') all_failed = True try: snapshoturi_obj = urlparser.urlparse(self.snapshoturi) if(snapshoturi_obj is None or snapshoturi_obj.hostname is None): self.logger.log("Failed to parse the snapshoturi",False,'Error') all_failed = True else: diskIds = [] body_content = '' headers = {} headers['Backup'] = 'true' headers['Content-type'] = 'application/json' headers['UserAgent'] = 'VMSnapshot' settings = [] if (paras.includeLunList != None and paras.includeLunList.count != 0): diskIds = paras.includeLunList if(paras.wellKnownSettingFlags != None): for flag in paras.wellKnownSettingFlags: temp_dict = {} temp_dict[CommonVariables.key] = flag temp_dict[CommonVariables.value] = paras.wellKnownSettingFlags[flag] 
    def snapshotall(self, paras, freezer, g_fsfreeze_on, taskId):
        """POST a dosnapshot request to the host, thawing mounts once the call returns.

        Returns (blob_snapshot_info_array, all_failed, is_inconsistent, unable_to_sleep).
        Telemetry status codes: 555 no HTTP response, 556 BHS not running,
        557 empty success body, 558 exception.
        """
        result = None
        blob_snapshot_info_array = []
        all_failed = True
        is_inconsistent = False
        unable_to_sleep = False
        meta_data = paras.backup_metadata
        if(self.snapshoturi is None):
            self.logger.log("Failed to do the snapshot because snapshoturi is none",False,'Error')
            all_failed = True
        try:
            snapshoturi_obj = urlparser.urlparse(self.snapshoturi)
            if(snapshoturi_obj is None or snapshoturi_obj.hostname is None):
                self.logger.log("Failed to parse the snapshoturi",False,'Error')
                all_failed = True
            else:
                diskIds = []
                body_content = ''
                headers = {}
                headers['Backup'] = 'true'
                headers['Content-type'] = 'application/json'
                headers['UserAgent'] = 'VMSnapshot'
                settings = []
                # NOTE(review): `paras.includeLunList.count != 0` compares the bound
                # list method (always truthy) — likely meant len(...) != 0; as
                # written this only guards against includeLunList being None.
                if (paras.includeLunList != None and paras.includeLunList.count != 0):
                    diskIds = paras.includeLunList
                if(paras.wellKnownSettingFlags != None):
                    for flag in paras.wellKnownSettingFlags:
                        temp_dict = {}
                        temp_dict[CommonVariables.key] = flag
                        temp_dict[CommonVariables.value] = paras.wellKnownSettingFlags[flag]
                        settings.append(temp_dict)
                # ADE (disk encryption) details travel both as settings and metadata.
                if(paras.isVMADEEnabled == True and paras.diskEncryptionSettings):
                    settings.append({CommonVariables.key:CommonVariables.isOsDiskADEEncrypted, CommonVariables.value:paras.isOsDiskADEEncrypted})
                    settings.append({CommonVariables.key:CommonVariables.areDataDisksADEEncrypted, CommonVariables.value:paras.areDataDisksADEEncrypted})
                    meta_data.append({CommonVariables.key:CommonVariables.diskEncryptionSettings, CommonVariables.value:paras.diskEncryptionSettings})
                hostDoSnapshotRequestBodyObj = HostSnapshotObjects.HostDoSnapshotRequestBody(taskId, diskIds, settings, paras.snapshotTaskToken, meta_data, paras.instantAccessDurationMinutes)
                body_content = json.dumps(hostDoSnapshotRequestBodyObj, cls = HandlerUtil.ComplexEncoder)
                # Only the redacted body is logged; the raw one may contain secrets.
                redactedRequestBodyObj = self.hutil.redact_sensitive_encryption_details(hostDoSnapshotRequestBodyObj)
                redacted_body_content = json.dumps(redactedRequestBodyObj, cls = HandlerUtil.ComplexEncoder)
                self.logger.log('Headers : ' + str(headers))
                self.logger.log('Host Request body : ' + str(redacted_body_content))
                http_util = HttpUtil(self.logger)
                self.logger.log("start calling the snapshot rest api")
                # initiate http call for blob-snapshot and get http response
                self.logger.log('****** 5. Snaphotting (Host) Started')
                result, httpResp, errMsg,responseBody = http_util.HttpCallGetResponse('POST', snapshoturi_obj, body_content, headers = headers, responseBodyRequired = True, isHostCall = True)
                self.logger.log('****** 6. Snaphotting (Host) Completed')
                self.logger.log("dosnapshot responseBody: " + responseBody)
                #performing thaw
                if g_fsfreeze_on :
                    time_before_thaw = datetime.datetime.now()
                    thaw_result, unable_to_sleep = freezer.thaw_safe()
                    time_after_thaw = datetime.datetime.now()
                    HandlerUtil.HandlerUtility.add_to_telemetery_data("ThawTime", str(time_after_thaw-time_before_thaw))
                    self.logger.log('T:S thaw result ' + str(thaw_result))
                    if(thaw_result is not None and len(thaw_result.errors) > 0):
                        is_inconsistent = True
                # Http response check(After thawing)
                if(httpResp != None):
                    HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.hostStatusCodeDoSnapshot, str(httpResp.status))
                    if(int(httpResp.status) == 200 or int(httpResp.status) == 201) and (responseBody == None or responseBody == "") :
                        self.logger.log("DoSnapshot: responseBody is empty but http status code is success")
                        HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.hostStatusCodeDoSnapshot, str(557))
                        all_failed = True
                    elif(int(httpResp.status) == 200 or int(httpResp.status) == 201):
                        blob_snapshot_info_array, all_failed = self.get_snapshot_info(responseBody)
                    if(httpResp.status == 500 and not responseBody.startswith("{ \"error\"")):
                        HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.hostStatusCodeDoSnapshot, str(556))
                        all_failed = True
                else:
                    # HttpCall failed
                    HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.hostStatusCodeDoSnapshot, str(555))
                    self.logger.log("dosnapshot Hitting wrong WireServer IP")
        except Exception as e:
            errorMsg = "Failed to do the snapshot in host with error: %s, stack trace: %s" % (str(e), traceback.format_exc())
            self.logger.log(errorMsg, False, 'Error')
            HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.hostStatusCodeDoSnapshot, str(558))
            all_failed = True
        return blob_snapshot_info_array, all_failed, is_inconsistent, unable_to_sleep
Snaphotting (Host) Completed') self.logger.log("dosnapshot responseBody: " + responseBody) #performing thaw if g_fsfreeze_on : time_before_thaw = datetime.datetime.now() thaw_result, unable_to_sleep = freezer.thaw_safe() time_after_thaw = datetime.datetime.now() HandlerUtil.HandlerUtility.add_to_telemetery_data("ThawTime", str(time_after_thaw-time_before_thaw)) self.logger.log('T:S thaw result ' + str(thaw_result)) if(thaw_result is not None and len(thaw_result.errors) > 0): is_inconsistent = True # Http response check(After thawing) if(httpResp != None): HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.hostStatusCodeDoSnapshot, str(httpResp.status)) if(int(httpResp.status) == 200 or int(httpResp.status) == 201) and (responseBody == None or responseBody == "") : self.logger.log("DoSnapshot: responseBody is empty but http status code is success") HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.hostStatusCodeDoSnapshot, str(557)) all_failed = True elif(int(httpResp.status) == 200 or int(httpResp.status) == 201): blob_snapshot_info_array, all_failed = self.get_snapshot_info(responseBody) if(httpResp.status == 500 and not responseBody.startswith("{ \"error\"")): HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.hostStatusCodeDoSnapshot, str(556)) all_failed = True else: # HttpCall failed HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.hostStatusCodeDoSnapshot, str(555)) self.logger.log("dosnapshot Hitting wrong WireServer IP") except Exception as e: errorMsg = "Failed to do the snapshot in host with error: %s, stack trace: %s" % (str(e), traceback.format_exc()) self.logger.log(errorMsg, False, 'Error') HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.hostStatusCodeDoSnapshot, str(558)) all_failed = True return blob_snapshot_info_array, all_failed, is_inconsistent, unable_to_sleep def pre_snapshot(self, paras, taskId, fetch_disk_details = False): statusCode = 555 if(self.presnapshoturi 
    def pre_snapshot(self, paras, taskId, fetch_disk_details = False):
        """POST the presnapshot pre-flight probe to the host.

        Optionally requests disk (ADE encryption) details, which are written
        back onto *paras*. Returns (statusCode, responseBody). Status codes:
        555 no/failed HTTP response, 556 BHS not running, 557 empty success
        body, 558 exception; otherwise the real HTTP status.

        NOTE(review): `responseBody` is only bound after the HTTP call — if an
        exception fires before that, the final `return` raises NameError.
        """
        statusCode = 555
        if(self.presnapshoturi is None):
            self.logger.log("Failed to do the snapshot because presnapshoturi is none",False,'Error')
            all_failed = True
        try:
            presnapshoturi_obj = urlparser.urlparse(self.presnapshoturi)
            if(presnapshoturi_obj is None or presnapshoturi_obj.hostname is None):
                self.logger.log("Failed to parse the presnapshoturi",False,'Error')
                all_failed = True
            else:
                headers = {}
                headers['Backup'] = 'true'
                headers['Content-type'] = 'application/json'
                headers['UserAgent'] = 'VMSnapshot'
                # if the vm is ade enabled and if the diskEncryptionSettings are not yet populated, then we need to fetch the disk details
                # or when the fetch_disk_details flag is set to true
                if(fetch_disk_details == True or (paras.isVMADEEnabled == True and not paras.diskEncryptionSettings)):
                    if(fetch_disk_details != True):
                        self.logger.log("Fetching disk details as the VM is ADE enabled and diskEncryptionSettings are not yet populated")
                        fetch_disk_details = True
                    preSnapshotSettings = []
                    temp_dict = {}
                    temp_dict[CommonVariables.key] = CommonVariables.isVMADEEnabled
                    temp_dict[CommonVariables.value] = paras.isVMADEEnabled
                    preSnapshotSettings.append(temp_dict)
                    hostPreSnapshotRequestBodyObj = HostSnapshotObjects.HostPreSnapshotRequestBody(taskId, paras.snapshotTaskToken, preSnapshotSettings)
                else:
                    hostPreSnapshotRequestBodyObj = HostSnapshotObjects.HostPreSnapshotRequestBody(taskId, paras.snapshotTaskToken)
                body_content = json.dumps(hostPreSnapshotRequestBodyObj, cls = HandlerUtil.ComplexEncoder)
                self.logger.log('Headers : ' + str(headers))
                self.logger.log('Host Request body : ' + str(body_content))
                http_util = HttpUtil(self.logger)
                self.logger.log("start calling the presnapshot rest api")
                # initiate http call for blob-snapshot and get http response
                result, httpResp, errMsg,responseBody = http_util.HttpCallGetResponse('POST', presnapshoturi_obj, body_content, headers = headers, responseBodyRequired = True, isHostCall = True)
                # Best-effort logging of well-known diagnostic fields.
                if responseBody:
                    try:
                        response_json = json.loads(responseBody)
                        if "bhsVersion" in response_json:
                            self.logger.log("PreSnapshotResponse: bhsVersion: " + str(response_json["bhsVersion"]))
                        if "nodeId" in response_json:
                            self.logger.log("PreSnapshotResponse: nodeId: " + str(response_json["nodeId"]))
                        if "responseTime" in response_json:
                            self.logger.log("PreSnapshotResponse: responseTime: " + str(response_json["responseTime"]))
                        if "result" in response_json:
                            self.logger.log("PreSnapshotResponse: result: " + str(response_json["result"]))
                    except Exception as e:
                        self.logger.log("PreSnapshotResponse: Failed to parse responseBody: " + str(e))
                if(httpResp != None):
                    statusCode = httpResp.status
                    self.logger.log("PreSnapshot: Status Code: " + str(statusCode))
                    if(int(statusCode) == 200 or int(statusCode) == 201) and (responseBody == None or responseBody == "") :
                        self.logger.log("PreSnapshot:responseBody is empty but http status code is success")
                        statusCode = 557
                    elif(responseBody != None):
                        if(paras.isVMADEEnabled == True and fetch_disk_details == True):
                            # Copy the returned encryption details back onto paras.
                            response = json.loads(responseBody)
                            paras.isOsDiskADEEncrypted = response.get(CommonVariables.isOsDiskADEEncrypted)
                            paras.areDataDisksADEEncrypted = response.get(CommonVariables.areDataDisksADEEncrypted)
                            paras.diskEncryptionSettings = response.get(CommonVariables.diskEncryptionSettings)
                            self.logger.log("PreSnapshotResponse: isOsDiskADEEncrypted: "+ str(paras.isOsDiskADEEncrypted))
                            self.logger.log("PreSnapshotResponse: areDataDisksADEEncrypted: "+ str(paras.areDataDisksADEEncrypted))
                            if paras.diskEncryptionSettings is not None:
                                self.logger.log("PreSnapshotResponse: DiskEncryptionSettings: "+ str(len(paras.diskEncryptionSettings)))
                            else:
                                self.logger.log("PreSnapshotResponse: DiskEncryptionSettings are null")
                        else:
                            self.logger.log("PreSnapshotResponse: VM is either not ADE Enabled or disk details were not requested")
                    elif(httpResp.status == 500 and not responseBody.startswith("{ \"error\"")):
                        self.logger.log("BHS is not runnning on host machine")
                        statusCode = 556
                else:
                    # HttpCall failed
                    statusCode = 555
                    self.logger.log("presnapshot Hitting wrong WireServer IP")
        except Exception as e:
            errorMsg = "Failed to do the pre snapshot in host with error: %s, stack trace: %s" % (str(e), traceback.format_exc())
            self.logger.log(errorMsg, False, 'Error')
            statusCode = 558
        HandlerUtil.HandlerUtility.add_to_telemetery_data(CommonVariables.hostStatusCodePreSnapshot, str(statusCode))
        return statusCode, responseBody
    def get_snapshot_info(self, responseBody):
        """Deserialize the host dosnapshot response into BlobSnapshotInfo objects.

        Returns (blobsnapshotinfo_array, all_failed); all_failed is False as
        soon as any entry reports isSuccessful == 'true'. Deserialization
        errors are logged and yield whatever was parsed so far.
        """
        blobsnapshotinfo_array = []
        all_failed = True
        try:
            if(responseBody != None):
                json_reponseBody = json.loads(responseBody)
                epochTime = datetime.datetime(1970, 1, 1, 0, 0, 0)
                for snapshot_info in json_reponseBody['snapshotInfo']:
                    self.logger.log("From Host- IsSuccessful:{0}, SnapshotUri:{1}, ErrorMessage:{2}, StatusCode:{3}".format(snapshot_info['isSuccessful'], snapshot_info['snapshotUri'], snapshot_info['errorMessage'], snapshot_info['statusCode']))
                    ddSnapshotIdentifierInfo = None
                    if('ddSnapshotIdentifier' in snapshot_info and snapshot_info['ddSnapshotIdentifier'] != None):
                        creationTimeString = snapshot_info['ddSnapshotIdentifier']['creationTime']
                        self.logger.log("creationTime string from BHS : {0} ".format(creationTimeString))
                        # Timestamp may arrive with or without fractional seconds.
                        try:
                            creationTimeObj = datetime.datetime.strptime(creationTimeString, "%Y-%m-%dT%H:%M:%S.%fZ")
                        except:
                            creationTimeObj = datetime.datetime.strptime(creationTimeString, "%Y-%m-%dT%H:%M:%SZ")
                        self.logger.log("Converting the creationTime string received in UTC format to UTC Ticks")
                        # Milliseconds since the Unix epoch, as a string.
                        delta = creationTimeObj - epochTime
                        timestamp = self.get_total_seconds(delta)
                        creationTimeUTCTicks = str(int(timestamp * 1000))
                        instantAccessDurationMinutes = None
                        if 'instantAccessDurationMinutes' in snapshot_info['ddSnapshotIdentifier']:
                            instantAccessDurationMinutes = snapshot_info['ddSnapshotIdentifier']['instantAccessDurationMinutes']
                        ddSnapshotIdentifierInfo = HostSnapshotObjects.DDSnapshotIdentifier(creationTimeUTCTicks , snapshot_info['ddSnapshotIdentifier']['id'], snapshot_info['ddSnapshotIdentifier']['token'], instantAccessDurationMinutes)
                        self.logger.log("ddSnapshotIdentifier Information from Host- creationTime : {0}, id : {1}, token : {2}, instantAccessDurationMinutes : {3}".format( ddSnapshotIdentifierInfo.creationTime, ddSnapshotIdentifierInfo.id, ddSnapshotIdentifierInfo.token, ddSnapshotIdentifierInfo.instantAccessDurationMinutes if ddSnapshotIdentifierInfo.instantAccessDurationMinutes is not None else 'Not Set'))
                    else:
                        self.logger.log("ddSnapshotIdentifier absent/None in Host Response")
                    blobsnapshotinfo_array.append(HostSnapshotObjects.BlobSnapshotInfo(snapshot_info['isSuccessful'], snapshot_info['snapshotUri'], snapshot_info['errorMessage'], snapshot_info['statusCode'], ddSnapshotIdentifierInfo))
                    if (snapshot_info['isSuccessful'] == 'true'):
                        all_failed = False
        except Exception as e:
            errorMsg = " deserialization of response body failed with error: %s, stack trace: %s" % (str(e), traceback.format_exc())
            self.logger.log(errorMsg)
        return blobsnapshotinfo_array, all_failed
self.logger.log("ddSnapshotIdentifier Information from Host- creationTime : {0}, id : {1}, token : {2}, instantAccessDurationMinutes : {3}".format( ddSnapshotIdentifierInfo.creationTime, ddSnapshotIdentifierInfo.id, ddSnapshotIdentifierInfo.token, ddSnapshotIdentifierInfo.instantAccessDurationMinutes if ddSnapshotIdentifierInfo.instantAccessDurationMinutes is not None else 'Not Set')) else: self.logger.log("ddSnapshotIdentifier absent/None in Host Response") blobsnapshotinfo_array.append(HostSnapshotObjects.BlobSnapshotInfo(snapshot_info['isSuccessful'], snapshot_info['snapshotUri'], snapshot_info['errorMessage'], snapshot_info['statusCode'], ddSnapshotIdentifierInfo)) if (snapshot_info['isSuccessful'] == 'true'): all_failed = False except Exception as e: errorMsg = " deserialization of response body failed with error: %s, stack trace: %s" % (str(e), traceback.format_exc()) self.logger.log(errorMsg) return blobsnapshotinfo_array, all_failed def get_total_seconds(self, delta): # Check if total_seconds method exists in current Python version if hasattr(delta, 'total_seconds'): return delta.total_seconds() else: self.logger.log("Calculating total seconds manually for version compatibility.") return delta.days * 86400 + delta.seconds + delta.microseconds / 1e6 ================================================ FILE: VMBackup/main/mounts.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class Mount:
    """One mounted device entry, as assembled from lsblk/mount output."""

    def __init__(self, name, type, fstype, mount_point):
        # Raw attributes, stored exactly as supplied.
        self.name = name
        self.type = type
        self.fstype = fstype
        self.mount_point = mount_point
        # "<mount_point>_<device name>" identifies an entry even when several
        # devices report the same mount point.
        self.unique_name = str(self.mount_point) + "_" + str(self.name)
deviceNameParts = mount_point_info[1].split("/") uniqueName = str(mountPoint) + "_" + str(deviceNameParts[len(deviceNameParts)-1]) fsType = mount_point_info[2] if((mountPoint in lsblk_mount_points) and (mountPoint not in added_mount_point_names)): if (self.should_skip_fstype(str(fsType))): logger.log("######## mounts list item Skipped due to fsType, mountPoint "+str(mountPoint)+", fsType "+str(fsType)+" and unique-name "+str(uniqueName), True) else: lsblk_mounts_index = 0 try: lsblk_mounts_index = lsblk_unique_names.index(uniqueName) except ValueError as e: logger.log("######## UniqueName not found in lsblk list :" + str(uniqueName), True) lsblk_mounts_index = lsblk_mount_points.index(mountPoint) mountObj = lsblk_mounts[lsblk_mounts_index] if(mountObj.fstype is None or mountObj.fstype == "" or mountObj.fstype == " "): logger.log("fstype empty from lsblk for mount" + str(mountPoint), True) mountObj.fstype = fsType self.mounts.append(mountObj) added_mount_point_names.append(mountPoint) logger.log("mounts list item added, mount point "+str(mountObj.mount_point)+", device-name "+str(mountObj.name)+", fs-type "+str(mountObj.fstype)+", unique-name "+str(mountObj.unique_name), True) # Append all the lsblk devices corresponding to lsblk_mounts_not_in_mount list mount-points for mount_point in lsblk_mounts_not_in_mount: if((mount_point in lsblk_mount_points) and (mount_point not in added_mount_point_names)): self.mounts.append(lsblk_mounts[lsblk_mount_points.index(mount_point)]) added_mount_point_names.append(mount_point) logger.log("mounts list item added from lsblk_mounts_not_in_mount, mount point "+str(mount_point), True) added_mount_point_names.reverse() logger.log("added_mount_point_names :" + str(added_mount_point_names), True) # Reverse the mounts list self.mounts.reverse() def should_skip_fstype(self, fstype): if (fstype == 'ext3' or fstype == 'ext4' or fstype == 'xfs' or fstype == 'btrfs'): return False else: return True 
================================================ FILE: VMBackup/main/parameterparser.py ================================================
#!/usr/bin/env python
#
# VMBackup extension - snapshot request parameter parser
# (this header previously said "CustomScript extension" - a copy/paste slip;
# the file belongs to VMBackup)
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from common import CommonVariables
import base64
import json
import sys

class ParameterParser(object):
    # Flattens the extension handler's public/protected settings (plus the
    # base64-encoded "objectStr" payloads inside them) into plain attributes
    # consumed by the snapshot/backup flow.
    def __init__(self, protected_settings, public_settings, backup_logger):
        """
        Populate parser attributes from the handler settings dictionaries.

        protected_settings / public_settings are the decoded extension
        settings dicts; backup_logger is the extension logger (second
        positional arg True means "also log locally").

        TODO: we should validate the parameter first
        """
        self.blobs = []
        self.backup_metadata = None
        self.public_config_obj = None
        self.private_config_obj = None
        # NOTE(review): this overwrites the [] assigned above; the first
        # value is never read. Kept as-is for byte-compatibility.
        self.blobs = None
        self.customSettings = None
        self.isVMADEEnabled = False
        self.snapshotTaskToken = ''
        self.includedDisks = None
        self.dynamicConfigsFromCRP = None
        self.disk_encryption_details = []
        self.isOsDiskADEEncrypted = False
        self.areDataDisksADEEncrypted = False
        self.diskEncryptionSettings = {}
        # Feature flags that may be toggled by dynamicConfigsFromCRP below.
        self.wellKnownSettingFlags = {CommonVariables.isSnapshotTtlEnabled: False,
                                      CommonVariables.useMccfToFetchDsasForAllDisks: False,
                                      CommonVariables.useMccfForLad: False,
                                      CommonVariables.enableSnapshotExtensionPolling: False,
                                      CommonVariables.isVmmdBlobIncluded : False}
        # Case-insensitive lookup from incoming CRP setting names to the
        # canonical keys of wellKnownSettingFlags. NOTE(review):
        # isVmmdBlobIncluded is deliberately absent here - it is driven by
        # includedDisks below, not by dynamicConfigsFromCRP.
        settingKeysMapping= {}
        settingKeysMapping[CommonVariables.isSnapshotTtlEnabled.lower()] = CommonVariables.isSnapshotTtlEnabled
        settingKeysMapping[CommonVariables.useMccfToFetchDsasForAllDisks.lower()] = CommonVariables.useMccfToFetchDsasForAllDisks
        settingKeysMapping[CommonVariables.useMccfForLad.lower()] = CommonVariables.useMccfForLad
        settingKeysMapping[CommonVariables.enableSnapshotExtensionPolling.lower()] = CommonVariables.enableSnapshotExtensionPolling
        self.includeLunList = []
        #To be shared with HP
        self.instantAccessDurationMinutes = None
        """ get the public configuration """
        self.commandToExecute = public_settings.get(CommonVariables.command_to_execute)
        self.taskId = public_settings.get(CommonVariables.task_id)
        self.locale = public_settings.get(CommonVariables.locale)
        self.logsBlobUri = public_settings.get(CommonVariables.logs_blob_uri)
        self.statusBlobUri = public_settings.get(CommonVariables.status_blob_uri)
        self.commandStartTimeUTCTicks = public_settings.get(CommonVariables.commandStartTimeUTCTicks)
        self.vmType = public_settings.get(CommonVariables.vmType)
        # customSettings: prefer a non-empty public value, else protected.
        if(CommonVariables.customSettings in public_settings.keys() and public_settings.get(CommonVariables.customSettings) is not None and public_settings.get(CommonVariables.customSettings) != ""):
            backup_logger.log("Reading customSettings from public_settings", True)
            self.customSettings = public_settings.get(CommonVariables.customSettings)
        elif(CommonVariables.customSettings in protected_settings.keys()):
            backup_logger.log("Reading customSettings from protected_settings", True)
            self.customSettings = protected_settings.get(CommonVariables.customSettings)
        # The public "objectStr" is a base64-encoded JSON document; decode it
        # and pull out backup metadata. On Python 3, b64decode returns bytes
        # that must be decoded to text before json.loads.
        self.publicObjectStr = public_settings.get(CommonVariables.object_str)
        if(self.publicObjectStr is not None and self.publicObjectStr != ""):
            if sys.version_info > (3,):
                decoded_public_obj_string = base64.b64decode(self.publicObjectStr)
                decoded_public_obj_string = decoded_public_obj_string.decode('ascii')
            else:
                decoded_public_obj_string = base64.standard_b64decode(self.publicObjectStr)
            # Strip whitespace and any stray wrapping single quotes before parsing.
            decoded_public_obj_string = decoded_public_obj_string.strip()
            decoded_public_obj_string = decoded_public_obj_string.strip('\'')
            self.public_config_obj = json.loads(decoded_public_obj_string)
            self.backup_metadata = self.public_config_obj['backupMetadata']
        # Fall back to protected settings for the log/status blob URIs.
        if(self.logsBlobUri is None or self.logsBlobUri == ""):
            self.logsBlobUri = protected_settings.get(CommonVariables.logs_blob_uri)
        if(self.statusBlobUri is None or self.statusBlobUri == ""):
            self.statusBlobUri = protected_settings.get(CommonVariables.status_blob_uri)
        # NOTE(review): the three lookups below dereference
        # self.public_config_obj unconditionally; if publicObjectStr was
        # empty it is still None here and .keys() would raise - presumably
        # CRP always supplies objectStr. Confirm against the caller.
        if(CommonVariables.snapshotTaskToken in self.public_config_obj.keys()):
            self.snapshotTaskToken = self.public_config_obj[CommonVariables.snapshotTaskToken]
        elif(CommonVariables.snapshotTaskToken in protected_settings.keys()):
            self.snapshotTaskToken = protected_settings.get(CommonVariables.snapshotTaskToken)
        if(CommonVariables.includedDisks in self.public_config_obj.keys()):
            self.includedDisks = self.public_config_obj[CommonVariables.includedDisks]
        if("dynamicConfigsFromCRP" in self.public_config_obj):
            self.dynamicConfigsFromCRP = self.public_config_obj['dynamicConfigsFromCRP']
        """ first get the protected configuration """
        # The protected "objectStr" carries the blob SAS URIs; same
        # base64 + strip + JSON dance as the public one.
        self.privateObjectStr = protected_settings.get(CommonVariables.object_str)
        if(self.privateObjectStr is not None and self.privateObjectStr != ""):
            if sys.version_info > (3,):
                decoded_private_obj_string = base64.b64decode(self.privateObjectStr)
                decoded_private_obj_string = decoded_private_obj_string.decode('ascii')
            else:
                decoded_private_obj_string = base64.standard_b64decode(self.privateObjectStr)
            decoded_private_obj_string = decoded_private_obj_string.strip()
            decoded_private_obj_string = decoded_private_obj_string.strip('\'')
            self.private_config_obj = json.loads(decoded_private_obj_string)
            self.blobs = self.private_config_obj['blobSASUri']
        # Build the LUN include-list from includedDisks; -1 stands for the
        # OS disk. Failures are logged and swallowed (best-effort).
        try:
            if(self.includedDisks != None):
                if(CommonVariables.dataDiskLunList in self.includedDisks.keys() and self.includedDisks[CommonVariables.dataDiskLunList] != None):
                    self.includeLunList = self.includedDisks[CommonVariables.dataDiskLunList]
                if(CommonVariables.isOSDiskIncluded in self.includedDisks.keys() and self.includedDisks[CommonVariables.isOSDiskIncluded] == True):
                    self.includeLunList.append(-1)
                backup_logger.log("LUN list - " + str(self.includeLunList), True)
                if(CommonVariables.isVmmdBlobIncluded in self.includedDisks.keys() and self.includedDisks[CommonVariables.isVmmdBlobIncluded] == True):
                    self.wellKnownSettingFlags[CommonVariables.isVmmdBlobIncluded] = True
                if(CommonVariables.isVMADEEnabled in self.includedDisks.keys() and self.includedDisks[CommonVariables.isVMADEEnabled] == True):
                    self.isVMADEEnabled = True
        except Exception as e:
            errorMsg = "Exception occurred while populating includeLunList, Exception: %s" % (str(e))
            backup_logger.log(errorMsg, True)
        # Apply CRP-pushed dynamic settings onto the well-known flags;
        # unknown names are logged and ignored.
        if(self.dynamicConfigsFromCRP != None):
            try:
                backup_logger.log("settings received " + str(self.dynamicConfigsFromCRP), True)
                for config in self.dynamicConfigsFromCRP:
                    if CommonVariables.key in config and CommonVariables.value in config:
                        config_key = config[CommonVariables.key].lower()
                        if(config_key in settingKeysMapping):
                            self.wellKnownSettingFlags[settingKeysMapping[config_key]] = config[CommonVariables.value]
                        else:
                            backup_logger.log("The received " + str(config[CommonVariables.key]) + " is not an expected setting name.", True)
                    else:
                        backup_logger.log("The received dynamicConfigsFromCRP is not in expected format.", True)
            except Exception as e:
                errorMsg = "Exception occurred while populating settings, Exception: %s" % (str(e))
                backup_logger.log(errorMsg, True)
        backup_logger.log("settings to be sent " + str(self.wellKnownSettingFlags), True)
        # Optional instant-access duration, again best-effort.
        try:
            if(self.includedDisks != None):
                if(CommonVariables.instantAccessDurationMinutes in self.includedDisks.keys() and self.includedDisks[CommonVariables.instantAccessDurationMinutes] != None):
                    self.instantAccessDurationMinutes = self.includedDisks[CommonVariables.instantAccessDurationMinutes]
                    backup_logger.log("InstantAccessDurationMinutes = " + str(self.instantAccessDurationMinutes), True)
        except Exception as e:
            errorMsg = "Exception occurred while extracting instantAccessDurationMinutes, Exception: %s" % (str(e))
            backup_logger.log(errorMsg, True)
================================================ FILE:
class AbstractPatching(object):
    """
    Skeleton base class for the distro-specific patching classes.

    Records the distro information supplied by the caller and seeds every
    tool-path attribute with a /usr-merged Linux layout; concrete
    subclasses override whichever paths differ on their distribution.
    """

    def __init__(self, distro_info):
        self.distro_info = distro_info
        # Default locations of the external binaries the backup flow invokes.
        for attr_name, tool_path in (
            ('base64_path', '/usr/bin/base64'),
            ('bash_path', '/bin/bash'),
            ('blkid_path', '/usr/bin/blkid'),
            ('cat_path', '/bin/cat'),
            ('cryptsetup_path', '/usr/sbin/cryptsetup'),
            ('dd_path', '/usr/bin/dd'),
            ('e2fsck_path', '/sbin/e2fsck'),
            ('echo_path', '/usr/bin/echo'),
            ('lsblk_path', '/usr/bin/lsblk'),
            ('lsscsi_path', '/usr/bin/lsscsi'),
            ('mkdir_path', '/usr/bin/mkdir'),
            ('mount_path', '/usr/bin/mount'),
            ('openssl_path', '/usr/bin/openssl'),
            ('resize2fs_path', '/sbin/resize2fs'),
            ('umount_path', '/usr/bin/umount'),
        ):
            setattr(self, attr_name, tool_path)

    def install_extras(self):
        """Hook for subclasses that need to install additional packages."""
        pass


class DefaultPatching(AbstractPatching):
    """Fallback patching class used when no distro-specific class matches."""

    def __init__(self, logger, distro_info):
        super(DefaultPatching, self).__init__(distro_info)
        self.logger = logger
        # Classic (pre /usr-merge) layout: core tools under /bin and /sbin.
        for attr_name, tool_path in (
            ('base64_path', '/usr/bin/base64'),
            ('bash_path', '/bin/bash'),
            ('blkid_path', '/sbin/blkid'),
            ('cat_path', '/bin/cat'),
            ('cryptsetup_path', '/sbin/cryptsetup'),
            ('dd_path', '/bin/dd'),
            ('e2fsck_path', '/sbin/e2fsck'),
            ('echo_path', '/bin/echo'),
            ('lsblk_path', '/bin/lsblk'),
            ('lsscsi_path', '/usr/bin/lsscsi'),
            ('mkdir_path', '/bin/mkdir'),
            ('mount_path', '/bin/mount'),
            ('openssl_path', '/usr/bin/openssl'),
            ('resize2fs_path', '/sbin/resize2fs'),
            ('umount_path', '/bin/umount'),
        ):
            setattr(self, attr_name, tool_path)

    def install_extras(self):
        """
        install the sg_dd because the default dd do not support the
        sparse write
        """
        pass
class FreeBSDPatching(AbstractPatching):
    """FreeBSD tool locations for the backup extension."""

    def __init__(self, logger, distro_info):
        super(FreeBSDPatching, self).__init__(distro_info)
        self.logger = logger
        # FreeBSD: ports-installed shells under /usr/local, core tools in /bin, /sbin.
        for attr_name, tool_path in (
            ('base64_path', '/usr/local/bin/base64'),
            ('bash_path', '/usr/local/bin/bash'),
            ('blkid_path', '/sbin/blkid'),
            ('cat_path', '/bin/cat'),
            ('cryptsetup_path', '/sbin/cryptsetup'),
            ('dd_path', '/bin/dd'),
            ('e2fsck_path', '/sbin/e2fsck'),
            ('echo_path', '/bin/echo'),
            ('lsblk_path', '/bin/lsblk'),
            ('lsscsi_path', '/usr/bin/lsscsi'),
            ('mkdir_path', '/bin/mkdir'),
            ('mount_path', '/sbin/mount'),
            ('openssl_path', '/usr/bin/openssl'),
            ('resize2fs_path', '/sbin/resize2fs'),
            ('umount_path', '/sbin/umount'),
        ):
            setattr(self, attr_name, tool_path)

    def install_extras(self):
        """
        install the sg_dd because the default dd do not support the
        sparse write
        """
        pass
class KaliPatching(AbstractPatching):
    """Kali Linux tool locations (classic split-/usr layout)."""

    def __init__(self, logger, distro_info):
        super(KaliPatching, self).__init__(distro_info)
        self.logger = logger
        # Kali keeps the core utilities under /bin and /sbin.
        for attr_name, tool_path in (
            ('base64_path', '/usr/bin/base64'),
            ('bash_path', '/bin/bash'),
            ('blkid_path', '/sbin/blkid'),
            ('cat_path', '/bin/cat'),
            ('cryptsetup_path', '/sbin/cryptsetup'),
            ('dd_path', '/bin/dd'),
            ('e2fsck_path', '/sbin/e2fsck'),
            ('echo_path', '/bin/echo'),
            ('lsblk_path', '/bin/lsblk'),
            ('lsscsi_path', '/usr/bin/lsscsi'),
            ('mkdir_path', '/bin/mkdir'),
            ('mount_path', '/bin/mount'),
            ('openssl_path', '/usr/bin/openssl'),
            ('resize2fs_path', '/sbin/resize2fs'),
            ('umount_path', '/bin/umount'),
        ):
            setattr(self, attr_name, tool_path)

    def install_extras(self):
        """
        install the sg_dd because the default dd do not support the
        sparse write
        """
        pass
class NSBSDPatching(AbstractPatching):
    """
    Stormshield NS-BSD patching class.

    NS-BSD has no usable system resolver for this extension, so the
    constructor reads the firewall's configured DNS servers, maps each
    server name to an IP via /etc/hosts, and installs a dnspython
    resolver as the process-wide resolver.
    """

    resolver = None

    def __init__(self, logger, distro_info):
        super(NSBSDPatching, self).__init__(distro_info)
        self.logger = logger
        self.usr_flag = 0
        self.mount_path = '/sbin/mount'
        try:
            import dns.resolver
        except ImportError:
            raise Exception("Python DNS resolver not available. Cannot proceed!")
        self.resolver = dns.resolver.Resolver()
        servers = []
        # Firewall-specific tool that prints the configured DNS servers,
        # one "name=" entry per line (header stripped by tail).
        getconf_cmd = "/usr/Firewall/sbin/getconf /usr/Firewall/ConfigFiles/dns Servers | tail -n +2"
        getconf_p = subprocess.Popen(getconf_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
        output, _ = getconf_p.communicate()
        # BUGFIX: communicate() returns bytes on Python 3; the old
        # str(output) produced the literal "b'...'" repr, whose newlines are
        # backslash-n text, so the split below never yielded real lines.
        if isinstance(output, bytes):
            output = output.decode('utf-8', 'replace')
        for server in output.split("\n"):
            if server == '':
                break
            server = server[:-1]  # remove last '='
            grep_cmd = "/usr/bin/grep '{}' /etc/hosts".format(server) + " | awk '{print $1}'"
            grep_p = subprocess.Popen(grep_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
            ip, _ = grep_p.communicate()
            # Same bytes-vs-str fix as above.
            if isinstance(ip, bytes):
                ip = ip.decode('utf-8', 'replace')
            ip = ip.rstrip()
            servers.append(ip)
        self.resolver.nameservers = servers
        # Route all dnspython lookups in this process through our resolver.
        dns.resolver.override_system_resolver(self.resolver)

    def install_extras(self):
        """No extra packages are installed on NS-BSD."""
        pass
class SuSEPatching(AbstractPatching):
    """SuSE/SLES tool paths; SLES 11 predates the /usr merge."""

    def __init__(self, logger, distro_info):
        super(SuSEPatching, self).__init__(distro_info)
        self.logger = logger
        # SLES 11 keeps the core utilities under /bin and /sbin; later
        # releases moved them under /usr.
        legacy = (distro_info[1] == "11")
        core = '/bin' if legacy else '/usr/bin'
        self.base64_path = '/usr/bin/base64'
        self.bash_path = '/bin/bash'
        self.blkid_path = '/sbin/blkid' if legacy else '/usr/bin/blkid'
        self.cat_path = '/bin/cat'
        self.cryptsetup_path = '/sbin/cryptsetup' if legacy else '/usr/sbin/cryptsetup'
        self.dd_path = core + '/dd'
        self.e2fsck_path = '/sbin/e2fsck'
        self.echo_path = core + '/echo'
        self.lsblk_path = core + '/lsblk'
        self.lsscsi_path = '/usr/bin/lsscsi'
        self.mkdir_path = core + '/mkdir'
        self.mount_path = core + '/mount'
        self.openssl_path = '/usr/bin/openssl'
        self.resize2fs_path = '/sbin/resize2fs'
        self.umount_path = core + '/umount'

    def install_extras(self):
        """Install the packages the backup flow shells out to, via zypper."""
        for extra in ('cryptsetup', 'lsscsi'):
            rc = subprocess.call(['zypper', 'install', '-l', extra])
            self.logger.log("installation for " + extra + 'result is ' + str(rc))
class UbuntuPatching(AbstractPatching):
    """Ubuntu tool paths and package installation (classic split-/usr layout)."""

    def __init__(self, logger, distro_info):
        super(UbuntuPatching, self).__init__(distro_info)
        self.logger = logger
        self.base64_path = '/usr/bin/base64'
        self.bash_path = '/bin/bash'
        self.blkid_path = '/sbin/blkid'
        self.cat_path = '/bin/cat'
        self.cryptsetup_path = '/sbin/cryptsetup'
        self.dd_path = '/bin/dd'
        self.e2fsck_path = '/sbin/e2fsck'
        self.echo_path = '/bin/echo'
        self.lsblk_path = '/bin/lsblk'
        self.lsscsi_path = '/usr/bin/lsscsi'
        self.mkdir_path = '/bin/mkdir'
        self.mount_path = '/bin/mount'
        self.openssl_path = '/usr/bin/openssl'
        self.resize2fs_path = '/sbin/resize2fs'
        self.umount_path = '/bin/umount'

    def install_extras(self):
        """
        Install the packages the backup flow shells out to, via apt-get.

        (Historical note: sg_dd was considered because the stock dd does
        not support sparse writes.)
        """
        # FIX: the old code branched on Ubuntu 12.04 but assigned the
        # identical package list in both branches - the dead distinction
        # has been removed.
        common_extras = ['cryptsetup-bin', 'lsscsi']
        for extra in common_extras:
            self.logger.log("installation for " + extra + 'result is ' + str(subprocess.call(['apt-get', 'install', '-y', extra])))
# Define the function in case waagent(<2.0.4) doesn't have DistInfo()
def DistInfo():
    """
    Best-effort distro detection.

    Returns a list whose first element is the distro name; BSD and
    linux_distribution() paths also include a version. On any failure the
    fallback ['Abstract', '1.0'] is returned.
    """
    try:
        if 'FreeBSD' in platform.system():
            release = re.sub('\\-.*$', '', str(platform.release()))
            return ['FreeBSD', release]
        if 'NS-BSD' in platform.system():
            release = re.sub('\\-.*$', '', str(platform.release()))
            return ['NS-BSD', release]
        if 'linux_distribution' in dir(platform):
            distinfo = list(platform.linux_distribution(full_distribution_name=0))
            if(distinfo[0] == ''):
                # Name came back empty: fall back to /etc/os-release.
                with open("/etc/os-release", "r") as osfile:
                    for line in osfile:
                        lists = str(line).split("=")
                        if(lists[0] == "NAME"):
                            distname = lists[1].split("\"")
                            distinfo[0] = distname[1]
                            if(distinfo[0].lower() == "sles"):
                                distinfo[0] = "SuSE"
            # remove trailing whitespace in distro name
            distinfo[0] = distinfo[0].strip()
            return distinfo
        if 'Linux' in platform.system():
            # Python 3.8+ removed linux_distribution(); sniff the kernel
            # version string instead.
            distinfo = ["Default"]
            if "ubuntu" in platform.version().lower():
                distinfo[0] = "Ubuntu"
            elif 'suse' in platform.version().lower():
                distinfo[0] = "SuSE"
            elif 'centos' in platform.version().lower():
                distinfo[0] = "centos"
            elif 'debian' in platform.version().lower():
                distinfo[0] = "debian"
            elif 'oracle' in platform.version().lower():
                distinfo[0] = "oracle"
            elif 'redhat' in platform.version().lower() or 'rhel' in platform.version().lower():
                distinfo[0] = "redhat"
            elif 'kali' in platform.version().lower():
                distinfo[0] = "Kali"
            return distinfo
        else:
            return platform.dist()
    except Exception:
        # BUGFIX: this handler previously called logger.log(), but no
        # module-level `logger` exists here, so the fallback path itself
        # raised NameError instead of returning the default.
        return ['Abstract', '1.0']

def GetMyPatching(logger):
    """
    Return (patching_instance, patching_class_name, original_distro_name).
    NOTE: Logging is not initialized at this point.
    """
    dist_info = DistInfo()
    if 'Linux' in platform.system():
        Distro = dist_info[0]
    else:
        # I know this is not Linux!
        if 'FreeBSD' in platform.system():
            Distro = platform.system()
        if 'NS-BSD' in platform.system():
            Distro = platform.system()
    # Normalize the name before building the class name.
    Distro = Distro.replace("-", "")
    Distro = Distro.strip('"')
    Distro = Distro.strip(' ')
    orig_distro = Distro
    patching_class_name = Distro + 'Patching'
    if patching_class_name not in globals():
        # No exact match: map known name fragments onto a supported class.
        if ('SuSE'.lower() in Distro.lower()):
            Distro = 'SuSE'
        elif ('Ubuntu'.lower() in Distro.lower()):
            Distro = 'Ubuntu'
        elif ('centos'.lower() in Distro.lower() or 'big-ip'.lower() in Distro.lower()):
            Distro = 'centos'
        elif ('debian'.lower() in Distro.lower()):
            Distro = 'debian'
        elif ('oracle'.lower() in Distro.lower()):
            Distro = 'oracle'
        elif ('redhat'.lower() in Distro.lower()):
            Distro = 'redhat'
        elif ("Kali".lower() in Distro.lower()):
            Distro = 'Kali'
        elif ('FreeBSD'.lower() in Distro.lower() or 'gaia'.lower() in Distro.lower() or 'panos'.lower() in Distro.lower()):
            Distro = 'FreeBSD'
        else:
            Distro = 'Default'
        patching_class_name = Distro + 'Patching'
    patchingInstance = globals()[patching_class_name](logger, dist_info)
    return patchingInstance, patching_class_name, orig_distro
class centosPatching(redhatPatching):
    """CentOS tool paths; the listed 6.x releases predate the /usr merge."""

    def __init__(self, logger, distro_info):
        super(centosPatching, self).__init__(logger, distro_info)
        self.logger = logger
        legacy = distro_info[1] in ("6.8", "6.7", "6.6", "6.5", "6.9", "6.3")
        core = '/bin' if legacy else '/usr/bin'
        # usr_flag records whether the distro uses the merged-/usr layout.
        self.usr_flag = 0 if legacy else 1
        self.base64_path = '/usr/bin/base64'
        self.bash_path = '/bin/bash' if legacy else '/usr/bin/bash'
        self.blkid_path = '/sbin/blkid' if legacy else '/usr/bin/blkid'
        self.cat_path = '/bin/cat'
        self.cryptsetup_path = '/sbin/cryptsetup' if legacy else '/usr/sbin/cryptsetup'
        self.dd_path = core + '/dd'
        self.e2fsck_path = '/sbin/e2fsck'
        self.echo_path = core + '/echo'
        self.lsblk_path = core + '/lsblk'
        self.lsscsi_path = '/usr/bin/lsscsi'
        self.mkdir_path = core + '/mkdir'
        self.mount_path = core + '/mount'
        self.openssl_path = '/usr/bin/openssl'
        self.resize2fs_path = '/sbin/resize2fs'
        self.umount_path = core + '/umount'

    def install_extras(self):
        """Install the packages the backup flow shells out to, via yum."""
        for extra in ('cryptsetup', 'lsscsi'):
            rc = subprocess.call(['yum', 'install', '-y', extra])
            self.logger.log("installation for " + extra + 'result is ' + str(rc))
class debianPatching(AbstractPatching):
    """Debian tool locations (classic split-/usr layout)."""

    def __init__(self, logger, distro_info):
        super(debianPatching, self).__init__(distro_info)
        self.logger = logger
        # Debian keeps the core utilities under /bin and /sbin.
        for attr_name, tool_path in (
            ('base64_path', '/usr/bin/base64'),
            ('bash_path', '/bin/bash'),
            ('blkid_path', '/sbin/blkid'),
            ('cat_path', '/bin/cat'),
            ('cryptsetup_path', '/sbin/cryptsetup'),
            ('dd_path', '/bin/dd'),
            ('e2fsck_path', '/sbin/e2fsck'),
            ('echo_path', '/bin/echo'),
            ('lsblk_path', '/bin/lsblk'),
            ('lsscsi_path', '/usr/bin/lsscsi'),
            ('mkdir_path', '/bin/mkdir'),
            ('mount_path', '/bin/mount'),
            ('openssl_path', '/usr/bin/openssl'),
            ('resize2fs_path', '/sbin/resize2fs'),
            ('umount_path', '/bin/umount'),
        ):
            setattr(self, attr_name, tool_path)

    def install_extras(self):
        """No extra packages are installed for Debian."""
        pass
class oraclePatching(redhatPatching):
    """Oracle Linux tool paths; 6.x releases predate the /usr merge."""

    def __init__(self, logger, distro_info):
        super(oraclePatching, self).__init__(logger, distro_info)
        self.logger = logger
        legacy = (distro_info is not None and len(distro_info) > 0
                  and distro_info[1].startswith("6."))
        core = '/bin' if legacy else '/usr/bin'
        self.base64_path = '/usr/bin/base64'
        self.bash_path = '/bin/bash' if legacy else '/usr/bin/bash'
        self.blkid_path = '/sbin/blkid' if legacy else '/usr/bin/blkid'
        self.cat_path = '/bin/cat'
        self.cryptsetup_path = '/sbin/cryptsetup' if legacy else '/usr/sbin/cryptsetup'
        self.dd_path = core + '/dd'
        self.e2fsck_path = '/sbin/e2fsck'
        self.echo_path = core + '/echo'
        # SELinux helpers live under /usr/sbin on both layouts.
        self.getenforce_path = '/usr/sbin/getenforce'
        self.setenforce_path = '/usr/sbin/setenforce'
        self.lsblk_path = core + '/lsblk'
        self.lsscsi_path = '/usr/bin/lsscsi'
        self.mkdir_path = core + '/mkdir'
        self.mount_path = core + '/mount'
        self.openssl_path = '/usr/bin/openssl'
        self.resize2fs_path = '/sbin/resize2fs'
        self.umount_path = core + '/umount'

    def install_extras(self):
        """Install the packages the backup flow shells out to, via yum."""
        for extra in ('cryptsetup', 'lsscsi'):
            rc = subprocess.call(['yum', 'install', '-y', extra])
            self.logger.log("installation for " + extra + 'result is ' + str(rc))
class redhatPatching(AbstractPatching):
    """Red Hat family tool paths; RHEL 6.x predates the /usr merge."""

    def __init__(self, logger, distro_info):
        super(redhatPatching, self).__init__(distro_info)
        self.logger = logger
        legacy = (distro_info is not None and len(distro_info) > 0
                  and distro_info[1].startswith("6."))
        core = '/bin' if legacy else '/usr/bin'
        # usr_flag records whether the distro uses the merged-/usr layout.
        self.usr_flag = 0 if legacy else 1
        self.base64_path = '/usr/bin/base64'
        self.bash_path = '/bin/bash' if legacy else '/usr/bin/bash'
        self.blkid_path = '/sbin/blkid' if legacy else '/usr/bin/blkid'
        self.cat_path = '/bin/cat'
        self.cryptsetup_path = '/sbin/cryptsetup' if legacy else '/usr/sbin/cryptsetup'
        self.dd_path = core + '/dd'
        self.e2fsck_path = '/sbin/e2fsck'
        self.echo_path = core + '/echo'
        # SELinux helpers live under /usr/sbin on both layouts.
        self.getenforce_path = '/usr/sbin/getenforce'
        self.setenforce_path = '/usr/sbin/setenforce'
        self.lsblk_path = core + '/lsblk'
        self.lsscsi_path = '/usr/bin/lsscsi'
        self.mkdir_path = core + '/mkdir'
        self.mount_path = core + '/mount'
        self.openssl_path = '/usr/bin/openssl'
        self.resize2fs_path = '/sbin/resize2fs'
        self.umount_path = core + '/umount'

    def install_extras(self):
        """Install the packages the backup flow shells out to, via yum."""
        for extra in ('cryptsetup', 'lsscsi'):
            rc = subprocess.call(['yum', 'install', '-y', extra])
            self.logger.log("installation for " + extra + 'result is ' + str(rc))
================================================ FILE: VMBackup/main/safefreeze/Makefile ================================================ CC := gcc SRCDIR := src LIBDIR := lib BINDIRNEW := binNew BINDIROLD := bin BINDIR := $(BINDIRNEW) INCDIR := include BUILDDIR := build TARGET := $(BINDIR)/safefreeze SRCEXT := c SOURCES := $(shell find $(SRCDIR) -type f -name *.$(SRCEXT)) OBJECTS := $(patsubst $(SRCDIR)/%,$(BUILDDIR)/%,$(SOURCES:.$(SRCEXT)=.o)) CFLAGS := -g LDFLAGS := -static -static-libgcc INC := -I $(INCDIR) LIB := -L $(LIBDIR) all : $(TARGET) $(TARGET): $(OBJECTS) @echo "Linking..." @mkdir -p $(BINDIR) $(CC) $^ $(LDFLAGS) -o $(TARGET) $(LIB) $(BUILDDIR)/%.o: $(SRCDIR)/%.$(SRCEXT) @mkdir -p $(BUILDDIR) @echo "Compiling..." $(CC) $(CFLAGS) $(INC) -c -o $@ $< clean: @echo "Cleaning..." $(RM) -r $(BUILDDIR) $(BINDIR) .PHONY: clean ================================================ FILE: VMBackup/main/safefreeze/src/safefreeze.c ================================================ // // Copyright 2016 Microsoft Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // #include #include #include #include #include #include #include #include #include #include #include #include #define JUMPWITHSTATUS(x) \ { \ status = (x); \ if (status) goto CLEANUP; \ } void logger(const char *logstr,...) 
{ time_t mytime; struct tm * timeinfo; char buffer[80]; time(&mytime); timeinfo = localtime(&mytime); strftime(buffer, 80, "%F %X", timeinfo); va_list arg; int done; printf("%s ", buffer); va_start(arg, logstr); done = vfprintf(stdout, logstr, arg); va_end(arg); } int gThaw = 0; void globalSignalHandler(int signum) { if (signum == SIGUSR1) { gThaw = 1; } } void printUsage() { logger("Usage: safefreeze TimeoutInSeconds MountPoint1 [MountPoint2 [MountPoint3 [..]]]\n"); } int main(int argc, char *argv[]) { int status = EXIT_SUCCESS; int timeout = 0; int numFileSystems = 0; int *fileSystemDescriptors = NULL; int i = 0; if (argc < 3) { printUsage(); JUMPWITHSTATUS(EXIT_FAILURE); } if ((timeout = atoi(argv[1])) <= 0) { printUsage(); JUMPWITHSTATUS(EXIT_FAILURE); } numFileSystems = argc - 2; fileSystemDescriptors = (int *) malloc(sizeof(int) * numFileSystems); for (i = 0; i < numFileSystems; i++) { fileSystemDescriptors[i] = -1; } for (i = 0; i < numFileSystems; i++) { char *mountPoint = argv[i + 2]; if ((fileSystemDescriptors[i] = open(mountPoint, O_RDONLY | O_NONBLOCK)) < 0) { int errsv = errno; logger("Failed to open: %s with error: %d and error message: %s\n", mountPoint, fileSystemDescriptors[i], strerror(errsv)); JUMPWITHSTATUS(EXIT_FAILURE); } struct stat sb; if (fstat(fileSystemDescriptors[i], &sb) == -1) { int errsv = errno; logger("Failed to stat: %s with error message: %s\n", mountPoint, strerror(errsv)); JUMPWITHSTATUS(EXIT_FAILURE); } if ((sb.st_mode & S_IFDIR) == 0) { logger("Path not a directory: %s\n", mountPoint); JUMPWITHSTATUS(EXIT_FAILURE); } } struct sigaction globalSignalAction = {0}; globalSignalAction.sa_handler = globalSignalHandler; if (sigaction(SIGHUP, &globalSignalAction, NULL) || sigaction(SIGINT, &globalSignalAction, NULL) || sigaction(SIGQUIT, &globalSignalAction, NULL) || sigaction(SIGABRT, &globalSignalAction, NULL) || sigaction(SIGPIPE, &globalSignalAction, NULL) || sigaction(SIGTERM, &globalSignalAction, NULL) || sigaction(SIGUSR1, 
&globalSignalAction, NULL) || sigaction(SIGUSR2, &globalSignalAction, NULL) || sigaction(SIGTSTP, &globalSignalAction, NULL) || sigaction(SIGTTIN, &globalSignalAction, NULL) || sigaction(SIGTTOU, &globalSignalAction, NULL) ) { logger("Failed to setup signal handlers\n"); JUMPWITHSTATUS(EXIT_FAILURE); } logger("****** 2. Binary Freeze Started \n"); for (i = 0; i < numFileSystems; i++) { char *mountPoint = argv[i + 2]; logger("Freezing: %s\n", mountPoint); if (ioctl(fileSystemDescriptors[i], FIFREEZE, 0) != 0) { int errsv = errno; logger("Failed to FIFREEZE: %s with error message: %s\n", mountPoint, strerror(errsv)); JUMPWITHSTATUS(EXIT_FAILURE); } } logger("****** 3. Binary Freeze Completed \n"); if (kill(getppid(), SIGUSR1) != 0) { logger("Failed to send FreezeCompletion to parent process\n"); JUMPWITHSTATUS(EXIT_FAILURE); } time_t starttime,currenttime; currenttime=time(NULL); starttime=time(NULL); for (i = 0; i < timeout; i++) { if (gThaw == 1 ) { logger("****** 8. Binary Thaw Signal Received \n"); break; } else { sleep(1); logger("sleep for 1 second \n"); } } currenttime=time(NULL); if (gThaw != 1 && currenttime > starttime+timeout-1) { logger("Failed to receive timely Thaw from parent process\n"); JUMPWITHSTATUS(EXIT_FAILURE); } else if (gThaw != 1) { logger("Inconsistent snapshot because of SLEEP failure \n"); JUMPWITHSTATUS(2); } CLEANUP: if (fileSystemDescriptors != NULL) { for (i = numFileSystems-1 ; i >= 0; i--) { if (fileSystemDescriptors[i] >= 0) { char *mountPoint = argv[i + 2]; logger("Thawing: %s\n", mountPoint); if (ioctl(fileSystemDescriptors[i], FITHAW, 0) != 0) { logger("Failed to FITHAW: %s with error message : %s\n", mountPoint, strerror(errno)); status = EXIT_FAILURE; } close(fileSystemDescriptors[i]); fileSystemDescriptors[i] = -1; } } free(fileSystemDescriptors); fileSystemDescriptors = NULL; } return status; } ================================================ FILE: VMBackup/main/safefreezeArm64/Makefile 
================================================ CC := aarch64-linux-gnu-gcc SRCDIR := src LIBDIR := lib BINDIRNEW := binNew BINDIROLD := bin BINDIR := $(BINDIRNEW) INCDIR := include BUILDDIR := build TARGET := $(BINDIR)/safefreeze SRCEXT := c SOURCES := $(shell find $(SRCDIR) -type f -name *.$(SRCEXT)) OBJECTS := $(patsubst $(SRCDIR)/%,$(BUILDDIR)/%,$(SOURCES:.$(SRCEXT)=.o)) CFLAGS := -g LDFLAGS := -static -static-libgcc INC := -I $(INCDIR) LIB := -L $(LIBDIR) all : $(TARGET) $(TARGET): $(OBJECTS) @echo "Linking..." @mkdir -p $(BINDIR) $(CC) $^ $(LDFLAGS) -o $(TARGET) $(LIB) $(BUILDDIR)/%.o: $(SRCDIR)/%.$(SRCEXT) @mkdir -p $(BUILDDIR) @echo "Compiling..." $(CC) $(CFLAGS) $(INC) -c -o $@ $< clean: @echo "Cleaning..." $(RM) -r $(BUILDDIR) $(BINDIR) .PHONY: clean ================================================ FILE: VMBackup/main/safefreezeArm64/src/safefreeze.c ================================================ // // Copyright 2016 Microsoft Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // #include #include #include #include #include #include #include #include #include #include #include #include #define JUMPWITHSTATUS(x) \ { \ status = (x); \ if (status) goto CLEANUP; \ } void logger(const char *logstr,...) 
{ time_t mytime; struct tm * timeinfo; char buffer[80]; time(&mytime); timeinfo = localtime(&mytime); strftime(buffer, 80, "%F %X", timeinfo); va_list arg; int done; printf("%s ", buffer); va_start(arg, logstr); done = vfprintf(stdout, logstr, arg); va_end(arg); } int gThaw = 0; void globalSignalHandler(int signum) { if (signum == SIGUSR1) { gThaw = 1; } } void printUsage() { logger("Usage: safefreeze TimeoutInSeconds MountPoint1 [MountPoint2 [MountPoint3 [..]]]\n"); } int main(int argc, char *argv[]) { int status = EXIT_SUCCESS; int timeout = 0; int numFileSystems = 0; int *fileSystemDescriptors = NULL; int i = 0; if (argc < 3) { printUsage(); JUMPWITHSTATUS(EXIT_FAILURE); } if ((timeout = atoi(argv[1])) <= 0) { printUsage(); JUMPWITHSTATUS(EXIT_FAILURE); } numFileSystems = argc - 2; fileSystemDescriptors = (int *) malloc(sizeof(int) * numFileSystems); for (i = 0; i < numFileSystems; i++) { fileSystemDescriptors[i] = -1; } for (i = 0; i < numFileSystems; i++) { char *mountPoint = argv[i + 2]; if ((fileSystemDescriptors[i] = open(mountPoint, O_RDONLY | O_NONBLOCK)) < 0) { int errsv = errno; logger("Failed to open: %s with error: %d and error message: %s\n", mountPoint, fileSystemDescriptors[i], strerror(errsv)); JUMPWITHSTATUS(EXIT_FAILURE); } struct stat sb; if (fstat(fileSystemDescriptors[i], &sb) == -1) { int errsv = errno; logger("Failed to stat: %s with error message: %s\n", mountPoint, strerror(errsv)); JUMPWITHSTATUS(EXIT_FAILURE); } if ((sb.st_mode & S_IFDIR) == 0) { logger("Path not a directory: %s\n", mountPoint); JUMPWITHSTATUS(EXIT_FAILURE); } } struct sigaction globalSignalAction = {0}; globalSignalAction.sa_handler = globalSignalHandler; if (sigaction(SIGHUP, &globalSignalAction, NULL) || sigaction(SIGINT, &globalSignalAction, NULL) || sigaction(SIGQUIT, &globalSignalAction, NULL) || sigaction(SIGABRT, &globalSignalAction, NULL) || sigaction(SIGPIPE, &globalSignalAction, NULL) || sigaction(SIGTERM, &globalSignalAction, NULL) || sigaction(SIGUSR1, 
&globalSignalAction, NULL) || sigaction(SIGUSR2, &globalSignalAction, NULL) || sigaction(SIGTSTP, &globalSignalAction, NULL) || sigaction(SIGTTIN, &globalSignalAction, NULL) || sigaction(SIGTTOU, &globalSignalAction, NULL) ) { logger("Failed to setup signal handlers\n"); JUMPWITHSTATUS(EXIT_FAILURE); } logger("****** 2. Binary Freeze Started \n"); for (i = 0; i < numFileSystems; i++) { char *mountPoint = argv[i + 2]; logger("Freezing: %s\n", mountPoint); if (ioctl(fileSystemDescriptors[i], FIFREEZE, 0) != 0) { int errsv = errno; logger("Failed to FIFREEZE: %s with error message: %s\n", mountPoint, strerror(errsv)); JUMPWITHSTATUS(EXIT_FAILURE); } } logger("****** 3. Binary Freeze Completed \n"); if (kill(getppid(), SIGUSR1) != 0) { logger("Failed to send FreezeCompletion to parent process\n"); JUMPWITHSTATUS(EXIT_FAILURE); } time_t starttime,currenttime; currenttime=time(NULL); starttime=time(NULL); for (i = 0; i < timeout; i++) { if (gThaw == 1 ) { logger("****** 8. Binary Thaw Signal Received \n"); break; } else { sleep(1); logger("sleep for 1 second \n"); } } currenttime=time(NULL); if (gThaw != 1 && currenttime > starttime+timeout-1) { logger("Failed to receive timely Thaw from parent process\n"); JUMPWITHSTATUS(EXIT_FAILURE); } else if (gThaw != 1) { logger("Inconsistent snapshot because of SLEEP failure \n"); JUMPWITHSTATUS(2); } CLEANUP: if (fileSystemDescriptors != NULL) { for (i = numFileSystems-1 ; i >= 0; i--) { if (fileSystemDescriptors[i] >= 0) { char *mountPoint = argv[i + 2]; logger("Thawing: %s\n", mountPoint); if (ioctl(fileSystemDescriptors[i], FITHAW, 0) != 0) { logger("Failed to FITHAW: %s with error message : %s\n", mountPoint, strerror(errno)); status = EXIT_FAILURE; } close(fileSystemDescriptors[i]); fileSystemDescriptors[i] = -1; } } free(fileSystemDescriptors); fileSystemDescriptors = NULL; } return status; } ================================================ FILE: VMBackup/main/taskidentity.py ================================================ 
#!/usr/bin/env python # # VM Backup extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import subprocess import xml import xml.dom.minidom class TaskIdentity: def __init__(self): self.store_identity_file = './task_identity_FD76C85E-406F-4CFA-8EB0-CF18B123365C' def save_identity(self,identity): with open(self.store_identity_file,'w') as f: f.write(identity) def stored_identity(self): identity_stored = None if(os.path.exists(self.store_identity_file)): with open(self.store_identity_file,'r') as f: identity_stored = f.read() return identity_stored ================================================ FILE: VMBackup/main/tempPlugin/VMSnapshotScriptPluginConfig.json ================================================ { "pluginName" : "ScriptRunner", "preScriptLocation" : "", "postScriptLocation" : "", "preScriptParams" : ["", ""], "postScriptParams" : ["", ""], "preScriptNoOfRetries" : 0, "postScriptNoOfRetries" : 0, "timeoutInSeconds" : 30, "continueBackupOnFailure" : true, "fsFreezeEnabled" : true } ================================================ FILE: VMBackup/main/tempPlugin/postScript.sh ================================================ #!/bin/bash instance=$1 # variables used for returning the status of the scripts success=0 error=1 warning=2 retVal=$success exit $retVal ================================================ FILE: VMBackup/main/tempPlugin/preScript.sh ================================================ #!/bin/bash 
instance=$1
# variables used for returning the status of the scripts
success=0
error=1
warning=2
retVal=$success
exit $retVal

================================================ FILE: VMBackup/main/tempPlugin/vmbackup.conf ================================================
[SnapshotThread]
fsfreeze: True

================================================ FILE: VMBackup/main/workloadPatch/CustomScripts/customscript.sql ================================================

================================================ FILE: VMBackup/main/workloadPatch/DefaultScripts/logbackup.sql ================================================
ALTER SYSTEM ARCHIVE LOG CURRENT;
ALTER DATABASE BACKUP CONTROLFILE TO '&1/control.ctl';
QUIT;

================================================ FILE: VMBackup/main/workloadPatch/DefaultScripts/postMysqlMaster.sql ================================================
SET GLOBAL read_only = OFF;
UNLOCK TABLES;

================================================ FILE: VMBackup/main/workloadPatch/DefaultScripts/postMysqlSlave.sql ================================================
START SLAVE;SELECT SLEEP(5);
SET GLOBAL read_only = OFF;
UNLOCK TABLES;

================================================ FILE: VMBackup/main/workloadPatch/DefaultScripts/postOracleMaster.sql ================================================
REM ================================================================================
REM File: postOracleMaster.sql
REM Date: 16-Sep 2020
REM Type: Oracle SQL*Plus script
REM Author: Microsoft CAE team
REM
REM Description:
REM Oracle SQL*Plus script called as an Azure Backup "post" script to
REM run immediately following a backup snapshot.
REM
REM SQL*Plus is executed in RESTRICTED LEVEL 2 mode, which means that
REM commands like HOST and SPOOL are not permitted, but commands like
REM START are permitted.
REM REM Modifications: REM TGorman 05oct22 v0.1 - remove external dependency on AZMESSAGE procedure REM TGorman 13dec22 v0.2 - support for DATABASE_ROLE = 'STANDBY' REM ================================================================================ REM REM ******************************************************************************** REM store script version into SQL*Plus substitution variable... REM ******************************************************************************** define V_SCRIPT_VERSION="0.2" REM REM ******************************************************************************** REM Format standard output to be terse... REM ******************************************************************************** SET ECHO OFF FEEDBACK OFF TIMING OFF PAGESIZE 0 LINESIZE 130 TRIMOUT ON TRIMSPOOL ON VERIFY OFF REM REM ******************************************************************************** REM Uncomment the following SET command to make commands, status feedback, and REM timings visible for debugging... REM ******************************************************************************** REM SET ECHO ON FEEDBACK ON TIMING ON REM REM ******************************************************************************** REM Connect this SQL*Plus session to the current database instance as SYSBACKUP... REM (be sure to leave one blank line before the CONNECT command) REM REM If databases are 11g or older, then please replace the following line with REM "CONNECT / AS SYSDBA", because the SYSBACKUP role was introduced in 12c... REM ******************************************************************************** CONNECT / AS SYSBACKUP REM CONNECT / AS SYSDBA REM REM ******************************************************************************** REM Retrieve the status of the Oracle database instance, and exit from SQL*Plus REM with SUCCESS exit status if database instance is not OPEN... 
REM ******************************************************************************** WHENEVER OSERROR EXIT SUCCESS WHENEVER SQLERROR EXIT SUCCESS COL STATUS NEW_VALUE V_STATUS SELECT 'STATUS='||STATUS AS STATUS FROM V$INSTANCE; EXEC IF '&&V_STATUS' <> 'STATUS=OPEN' THEN RAISE NOT_LOGGED_ON; END IF; REM REM ******************************************************************************** REM Next, if SQL*Plus has not exited as a result of the last command, now ensure that REM the failure of any command results in a FAILURE exit status from SQL*Plus... REM ******************************************************************************** WHENEVER OSERROR EXIT FAILURE WHENEVER SQLERROR EXIT FAILURE REM REM ******************************************************************************** REM Display the LOG_MODE of the database to be captured by the calling Python code... REM ******************************************************************************** SELECT 'LOG_MODE='||LOG_MODE AS LOG_MODE FROM V$DATABASE; REM REM ******************************************************************************** REM Display the DATABASE_ROLE of the database to be captured by the calling Python code... REM ******************************************************************************** SELECT 'DATABASE_ROLE='||database_role AS DATABASE_ROLE FROM V$DATABASE; REM REM ******************************************************************************** REM Enable emitting DBMS_OUTPUT to standard output... REM ******************************************************************************** SET SERVEROUTPUT ON SIZE 1000000 REM REM ******************************************************************************** REM Attempt to take the database out from BACKUP mode, which will succeed only if the REM database is presently in ARCHIVELOG mode and if the database was already in REM BACKUP mode... 
REM ******************************************************************************** DECLARE -- v_errcontext varchar2(128); v_timestamp varchar2(32); v_database_role varchar2(32); noArchiveLogMode exception; notInBackup exception; pragma exception_init(noArchiveLogMode, -1123); pragma exception_init(noArchiveLogMode, -1142); -- BEGIN -- v_errcontext := 'query DATABASE_ROLE'; SELECT DATABASE_ROLE INTO v_database_role FROM V$DATABASE; -- if v_database_role = 'PRIMARY' then -- v_errcontext := 'END BACKUP'; SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup post-script v&&V_SCRIPT_VERSION: starting ' || v_errcontext || '...'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup post-script v&&V_SCRIPT_VERSION: starting ' || v_errcontext || '...'); -- execute immediate 'ALTER DATABASE END BACKUP'; -- SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' succeeded'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' succeeded'); -- end if; -- EXCEPTION -- when noArchiveLogMode then -- SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' in NOARCHIVELOG failed - continuing backup...'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' in NOARCHIVELOG failed - continuing backup...'); -- when notInBackup then SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - WARN - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed as no datafiles in BACKUP mode - continuing backup...'); 
SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'WARN - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed as no datafiles in BACKUP mode - continuing backup...'); -- when others then SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - FAIL - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'FAIL - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed'); raise; -- END; / REM REM ******************************************************************************** REM Force a switch of the online redo logfiles, which will force a full checkpoint, REM and then archive the current logfile... REM ******************************************************************************** DECLARE -- v_errcontext varchar2(128); v_timestamp varchar2(32); v_database_role varchar2(32); noArchiveLogMode exception; pragma exception_init(noArchiveLogMode, -258); -- BEGIN -- v_errcontext := 'query DATABASE_ROLE'; SELECT DATABASE_ROLE INTO v_database_role FROM V$DATABASE; -- if v_database_role = 'PRIMARY' then -- v_errcontext := 'ARCHIVE LOG CURRENT'; SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup post-script v&&V_SCRIPT_VERSION: starting ' || v_errcontext || '...'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup post-script v&&V_SCRIPT_VERSION: starting ' || v_errcontext || '...'); -- execute immediate 'ALTER SYSTEM ARCHIVE LOG CURRENT'; -- SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' succeeded'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' succeeded'); -- end if; -- EXCEPTION -- when 
noArchiveLogMode then begin -- v_errcontext := 'SWITCH LOGFILE'; SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup post-script v&&V_SCRIPT_VERSION: starting ' || v_errcontext || '...'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup post-script v&&V_SCRIPT_VERSION: starting ' || v_errcontext || '...'); -- execute immediate 'ALTER SYSTEM SWITCH LOGFILE'; -- SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' succeeded'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' succeeded'); -- exception when others then SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - FAIL - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'FAIL - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed'); raise; end; -- when others then SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - FAIL - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'FAIL - AzBackup post-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed'); raise; -- END; / REM REM ******************************************************************************** REM Exit from Oracle SQL*Plus with SUCCESS exit status... 
REM ********************************************************************************
EXIT SUCCESS

================================================ FILE: VMBackup/main/workloadPatch/DefaultScripts/postPostgresMaster.sql ================================================
SELECT pg_stop_backup();

================================================ FILE: VMBackup/main/workloadPatch/DefaultScripts/preMysqlMaster.sql ================================================
FLUSH TABLES WITH READ LOCK;
SET GLOBAL read_only = ON;
set @query = concat("SELECT \"serverlevel\" INTO OUTFILE ",@outfile);
prepare stmt from @query;
execute stmt;deallocate prepare stmt;
SELECT SLEEP(@timeout);

================================================ FILE: VMBackup/main/workloadPatch/DefaultScripts/preMysqlSlave.sql ================================================
STOP SLAVE;SELECT SLEEP(5);
FLUSH TABLES WITH READ LOCK;
SET GLOBAL read_only = ON;
set @query = concat("SELECT \"serverlevel\" INTO OUTFILE ",@outfile);
prepare stmt from @query;
execute stmt;deallocate prepare stmt;
SELECT SLEEP(@timeout);

================================================ FILE: VMBackup/main/workloadPatch/DefaultScripts/preOracleMaster.sql ================================================
REM ================================================================================
REM File: preOracleMaster.sql
REM Date: 16-Sep 2020
REM Type: Oracle SQL*Plus script
REM Author: Microsoft CAE team
REM
REM Description:
REM Oracle SQL*Plus script called as an Azure Backup "pre" script, to
REM be run immediately prior to a backup snapshot.
REM
REM SQL*Plus is executed in RESTRICTED LEVEL 2 mode, which means that
REM commands like HOST and SPOOL are not permitted, but commands like
REM START are permitted.
REM REM Modifications: REM TGorman 05oct22 v0.1 - remove external dependency on AZMESSAGE procedure REM TGorman 13dec22 v0.2 - support for DATABASE_ROLE = 'STANDBY' REM ================================================================================ REM REM ******************************************************************************** REM store script version into SQL*Plus substitution variable... REM ******************************************************************************** define V_SCRIPT_VERSION="0.2" REM REM ******************************************************************************** REM Format standard output to be terse... REM ******************************************************************************** SET ECHO OFF FEEDBACK OFF TIMING OFF PAGESIZE 0 LINESIZE 130 TRIMOUT ON TRIMSPOOL ON VERIFY OFF REM REM ******************************************************************************** REM Uncomment the following SET command to make commands, status feedback, and REM timings visible for debugging... REM ******************************************************************************** REM SET ECHO ON FEEDBACK ON TIMING ON REM REM ******************************************************************************** REM Connect this SQL*Plus session to the current database instance as SYSBACKUP... REM (be sure to leave one blank line before the CONNECT command) REM REM If databases are 11g or older, then please replace the following line with REM "CONNECT / AS SYSDBA", because the SYSBACKUP role was introduced in 12c... REM ******************************************************************************** CONNECT / AS SYSBACKUP REM CONNECT / AS SYSDBA REM REM ******************************************************************************** REM Retrieve the status of the Oracle database instance, and exit from SQL*Plus REM with SUCCESS exit status if database instance is not OPEN... 
REM ******************************************************************************** WHENEVER OSERROR EXIT SUCCESS WHENEVER SQLERROR EXIT SUCCESS COL STATUS NEW_VALUE V_STATUS SELECT 'STATUS='||STATUS AS STATUS FROM V$INSTANCE; EXEC IF '&&V_STATUS' <> 'STATUS=OPEN' THEN RAISE NOT_LOGGED_ON; END IF; REM REM ******************************************************************************** REM Next, if SQL*Plus has not exited as a result of the last command, now ensure that REM the failure of any command results in a FAILURE exit status from SQL*Plus... REM ******************************************************************************** WHENEVER OSERROR EXIT FAILURE WHENEVER SQLERROR EXIT FAILURE REM REM ******************************************************************************** REM Display the LOG_MODE of the database to be captured by the calling Python code... REM ******************************************************************************** SELECT 'LOG_MODE='||LOG_MODE AS LOG_MODE FROM V$DATABASE; REM REM ******************************************************************************** REM Display the DATABASE_ROLE of the database to be captured by the calling Python code... REM ******************************************************************************** SELECT 'DATABASE_ROLE='||database_role AS DATABASE_ROLE FROM V$DATABASE; REM REM ******************************************************************************** REM Enable emitting DBMS_OUTPUT to standard output... REM ******************************************************************************** SET SERVEROUTPUT ON SIZE 1000000 REM REM ******************************************************************************** REM Force a switch of the online redo logfiles, which will force a full checkpoint, REM and then archive the current logfile... 
REM ******************************************************************************** DECLARE -- v_errcontext varchar2(128); v_timestamp varchar2(32); v_database_role varchar2(32); noArchiveLogMode exception; pragma exception_init(noArchiveLogMode, -258); -- BEGIN -- v_errcontext := 'query DATABASE_ROLE'; SELECT DATABASE_ROLE INTO v_database_role FROM V$DATABASE; -- if v_database_role = 'PRIMARY' then -- v_errcontext := 'ARCHIVE LOG CURRENT'; SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: starting ' || v_errcontext || '...'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: starting ' || v_errcontext || '...'); -- execute immediate 'ALTER SYSTEM ARCHIVE LOG CURRENT'; -- SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' succeeded'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' succeeded'); -- end if; -- EXCEPTION -- when noArchiveLogMode then begin -- v_errcontext := 'SWITCH LOGFILE'; SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: starting ' || v_errcontext || '...'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: starting ' || v_errcontext || '...'); -- execute immediate 'ALTER SYSTEM SWITCH LOGFILE'; -- SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' succeeded'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || 
' succeeded'); -- exception when others then SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - FAIL - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'FAIL - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed'); raise; end; -- when others then SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - FAIL - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'FAIL - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed'); raise; -- END; / REM REM ******************************************************************************** REM Attempt to put the database into BACKUP mode, which will succeed only if the REM database is presently in ARCHIVELOG mode REM ******************************************************************************** DECLARE -- v_errcontext varchar2(128); v_timestamp varchar2(32); v_database_role varchar2(32); noArchiveLogMode exception; pragma exception_init(noArchiveLogMode, -1123); -- BEGIN -- v_errcontext := 'query DATABASE_ROLE'; SELECT DATABASE_ROLE INTO v_database_role FROM V$DATABASE; -- if v_database_role = 'PRIMARY' then -- v_errcontext := 'BEGIN BACKUP'; SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: starting ' || v_errcontext || '...'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: starting ' || v_errcontext || '...'); -- execute immediate 'ALTER DATABASE BEGIN BACKUP'; -- SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' 
succeeded'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' succeeded'); -- end if; -- EXCEPTION -- when noArchiveLogMode then -- SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' in NOARCHIVELOG failed - continuing backup...'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'INFO - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' in NOARCHIVELOG failed - continuing backup...'); -- when others then SELECT TO_CHAR(SYSDATE, 'YYYY-MM-DD HH24:MI:SS') INTO v_timestamp FROM DUAL; DBMS_OUTPUT.PUT_LINE(v_timestamp || ' - FAIL - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed'); SYS.DBMS_SYSTEM.KSDWRT(SYS.DBMS_SYSTEM.ALERT_FILE, 'FAIL - AzBackup pre-script v&&V_SCRIPT_VERSION: ' || v_errcontext || ' failed'); raise; -- END; / REM REM ******************************************************************************** REM Exit from Oracle SQL*Plus with SUCCESS exit status... 
REM ******************************************************************************** EXIT SUCCESS ================================================ FILE: VMBackup/main/workloadPatch/DefaultScripts/prePostgresMaster.sql ================================================ SELECT pg_start_backup('AzureBackup'); ================================================ FILE: VMBackup/main/workloadPatch/DefaultScripts/timeoutDaemon.sh ================================================ #!/usr/bin/env sh arc=0 comand="$2" cred_string="$3" timeout="$4" scriptPath="$5" sleep $timeout if [ "$1" = "oracle" ] then cmd="$comand/sqlplus -S -R 2 /nolog @$scriptPath/postOracleMaster.sql" exec $cmd elif [ "$1" = "postgres" ] then cmd="$comand/psql $cred_string -f $scriptPath/postPostgresMaster.sql" exec $cmd else echo "`date`- incorrect workload name" fi exit $arc ================================================ FILE: VMBackup/main/workloadPatch/LogBackupPatch.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import sys import threading import os from time import sleep try: import ConfigParser as ConfigParsers except ImportError: import configparser as ConfigParsers import subprocess class LogBackupPatch: def __init__(self): self.name = "" self.cred_string = "" self.baseLocation = "" self.parameterFilePath = "" self.oracleParameter = {} self.backupSource = "" self.crontabLocation = "" self.command = "" self.confParser() self.crontabEntry() def crontabEntry(self): if os.path.exists(self.crontabLocation): crontabFile = open(self.crontabLocation, 'r') crontabCheck = crontabFile.read() else: crontabCheck = "NO CRONTAB" if 'oracle' in self.name.lower(): if 'OracleLogBackup' in str(crontabCheck): return else: os.system("echo \"*/15 * * * * python " + os.path.join(os.getcwd(), "main/workloadPatch/WorkloadUtils/OracleLogBackup.py\"") + " >> /var/spool/cron/root") return def confParser(self): configfile = '/etc/azure/workload.conf' if os.path.exists(configfile): config = ConfigParsers.ConfigParser() config.read(configfile) if config.has_section("logbackup"): if config.has_option("workload", 'workload_name'): self.name = config.get("workload", 'workload_name') else: return None if config.has_option("workload", 'command'): self.command = config.get("workload", 'command') if config.has_option("workload", 'credString'): self.cred_string = config.get("workload", 'credString') if config.has_option("logbackup", 'parameterFilePath'): self.parameterFilePath = config.get("logbackup", 'parameterFilePath') else: return None if config.has_option("logbackup", 'baseLocation'): self.baseLocation = config.get("logbackup", 'baseLocation') else: return None if config.has_option("logbackup", 'crontabLocation'): self.crontabLocation = config.get("logbackup", 'crontabLocation') else: return ================================================ FILE: VMBackup/main/workloadPatch/WorkloadPatch.py ================================================ #!/usr/bin/python # # Copyright 2015 Microsoft Corporation # # 
Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Requires Python 2.4+ import sys import Utils.HandlerUtil import threading import os from time import sleep import re try: import ConfigParser as ConfigParsers except ImportError: import configparser as ConfigParsers import subprocess from common import CommonVariables from workloadPatch.LogBackupPatch import LogBackupPatch class ErrorDetail: def __init__(self, errorCode, errorMsg): self.errorCode = errorCode self.errorMsg = errorMsg class WorkloadPatch: def __init__(self, logger): self.logger = logger self.name = None self.supported_workload = ["oracle", "mysql", "mariadb", "postgres"] self.command = "" self.dbnames = [] self.cred_string = "" self.ipc_folder = None self.error_details = [] self.enforce_slave_only = 0 self.role = "master" self.child = [] self.timeout = "90" self.linux_user = "root" self.sudo_user = "sudo" self.outfile = "" self.logbackup = "" self.custom_scripts_enabled = 0 self.scriptpath= "DefaultScripts" self.temp_script_folder= "/etc/azure" self.configuration_path = "" self.confParser() self.pre_database_status = "" self.pre_log_mode = "" self.post_database_status = "" self.post_log_mode = "" self.instance_list = [] def readOracleList(self,filePath): re_db = re.compile(r'^(?P(\w+)):(?P(/|\w+|\.)+)(:(\w*))?') with open(filePath, 'r') as f: for line in f: line = line.strip() re_db_match = re_db.search(line) if re_db_match: db = re_db_match.group('DB') path = re_db_match.group('PATH') curr_dict = {"sid": db, "home": 
path, "preSuccess": False, "postSuccess": False, "noArchive": False, "dbOpen": True} self.instance_list.append(curr_dict) def pre(self): try: self.logger.log("WorkloadPatch: Entering workload pre call") self.createTempScriptsFolder() if self.role == "master" and int(self.enforce_slave_only) == 0: if self.configuration_path: self.preInstance() elif len(self.dbnames) == 0 : #pre at server level create fork process for child and append self.preMaster() else: self.preMasterDB() # create fork process for child elif self.role == "slave": if len(self.dbnames) == 0 : #pre at server level create fork process for child and append self.preSlave() else: self.preSlaveDB() # create fork process for child else: self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadInvalidRole, "invalid role name in config")) except Exception as e: self.logger.log("WorkloadPatch: exception in pre" + str(e)) self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadPreError, "Exception in pre")) def post(self): try: self.logger.log("WorkloadPatch: Entering workload post call") if self.role == "master": if len(self.instance_list) != 0: self.postInstance() elif len(self.dbnames) == 0: #post at server level to turn off readonly mode self.postMaster() else: self.postMasterDB() elif self.role == "slave": if len(self.dbnames) == 0 : #post at server level to turn on slave self.postSlave() else: self.postSlaveDB() else: self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadInvalidRole, "invalid role name in config")) #Remove the temporary scripts folder created self.removeTempScriptsFolder() except Exception as e: self.logger.log("WorkloadPatch: exception in post" + str(e)) #Remove the temporary scripts folder created self.removeTempScriptsFolder() self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadPostError, "exception in processing of postscript")) def preMaster(self): global preSuccess self.logger.log("WorkloadPatch: Entering pre mode for master") 
if self.ipc_folder != None: self.outfile = os.path.join(self.ipc_folder, "azbackupIPC.txt") if os.path.exists(self.outfile): os.remove(self.outfile) else: self.logger.log("WorkloadPatch: File for IPC does not exist at pre") preSuccess = False if 'mysql' in self.name.lower() or 'mariadb' in self.name.lower(): self.logger.log("WorkloadPatch: Create connection string for premaster mysql") if self.outfile == "": self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadIPCDirectoryMissing, "IPC directory missing")) return None prescript = os.path.join(self.temp_script_folder, self.scriptpath + "/preMysqlMaster.sql") arg = self.sudo_user+" "+self.command+self.name+" "+self.cred_string+" -e\"set @timeout="+self.timeout+";set @outfile=\\\"\\\\\\\""+self.outfile+"\\\\\\\"\\\";source "+prescript+";\"" binary_thread = threading.Thread(target=self.thread_for_sql, args=[arg]) binary_thread.start() self.waitForPreScriptCompletion() elif 'oracle' in self.name.lower(): self.logger.log("WorkloadPatch: Pre- Inside oracle pre") preOracle = self.command + "sqlplus" + " -S -R 2 /nolog @" + os.path.join(self.temp_script_folder, self.scriptpath + "/preOracleMaster.sql ") args = "su - "+self.linux_user+" -c "+"\'"+preOracle+"\'" self.logger.log("WorkloadPatch: argument passed for pre script:"+str(args)) process = subprocess.Popen(args, stdout=subprocess.PIPE, shell=True) wait_counter = 5 while process.poll() == None and wait_counter>0: wait_counter -= 1 sleep(2) while True: line= process.stdout.readline() line = Utils.HandlerUtil.HandlerUtility.convert_to_string(line) if(line != ''): self.logger.log("WorkloadPatch: pre completed with output "+line.rstrip(), True) else: break if('BEGIN BACKUP succeeded' in line): preSuccess = True break if('LOG_MODE=' in line): line = line.replace('\n','') line_split = line.split('=') self.logger.log("WorkloadPatch: log mode set is "+line_split[1], True) if(line_split[1] == "ARCHIVELOG"): self.pre_log_mode = "ARCHIVELOG" 
self.logger.log("WorkloadPatch: Archive log mode for oracle") else: self.pre_log_mode = "NOARCHIVELOG" self.logger.log("WorkloadPatch: No archive log mode for oracle") if('STATUS=' in line): line = line.replace('\n', '') line_split = line.split('=') self.logger.log("WorkloadPatch: database status is "+line_split[1], True) if(line_split[1] == "OPEN"): self.pre_database_status = "OPEN" self.logger.log("WorkloadPatch: Database is open") else:##handle other DB status if required self.pre_database_status = "NOTOPEN" self.logger.log("WorkloadPatch: Database is not open") if(self.pre_log_mode == "NOARCHIVELOG" and self.pre_database_status == "OPEN"): self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadDatabaseInNoArchiveLog, "Workload in no archive log mode")) if(preSuccess == True): self.logger.log("WorkloadPatch: pre success is true") self.timeoutDaemon() elif(self.pre_database_status == "NOTOPEN"): self.logger.log("WorkloadPatch: Database in closed status, backup can be app consistent") else: self.logger.log("WorkloadPatch: Pre failed for oracle") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadPreError, "Workload Pre failed")) self.logger.log("WorkloadPatch: Pre- Exiting pre mode for master") elif 'postgres' in self.name.lower(): self.logger.log("WorkloadPatch: Pre- Inside postgres pre") prePostgres = self.command + "psql " + self.cred_string + " -f " + os.path.join(os.getcwd(), "main/workloadPatch/"+self.scriptpath+"/prePostgresMaster.sql") args = "su - "+self.linux_user+" -c "+"\'"+prePostgres+"\'" self.logger.log("WorkloadPatch: argument passed for pre script:"+str(self.linux_user)+" "+str(self.command)) process = subprocess.Popen(args,stdout=subprocess.PIPE, shell=True) wait_counter = 5 while process.poll() == None and wait_counter>0: wait_counter -= 1 sleep(2) while True: line= process.stdout.readline() line = Utils.HandlerUtil.HandlerUtility.convert_to_string(line) if(line != ''): self.logger.log("WorkloadPatch: pre completed 
with output "+line.rstrip(), True) else: break self.timeoutDaemon() self.logger.log("WorkloadPatch: Pre- Exiting pre mode for master postgres") #Add new workload support here else: self.logger.log("WorkloadPatch: Unsupported workload name") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadInvalidWorkloadName, "Workload Not supported")) def postMaster(self): global daemonProcess self.logger.log("WorkloadPatch: Entering post mode for master") try: if self.ipc_folder != None and self.ipc_folder != "": #IPCm based workloads if os.path.exists(self.outfile): os.remove(self.outfile) else: self.logger.log("WorkloadPatch: File for IPC does not exist at post") if len(self.child) == 0 or self.child[0].poll() is not None: self.logger.log("WorkloadPatch: Not app consistent backup") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadQuiescingTimeout,"not app consistent")) elif self.child[0].poll() is None: self.logger.log("WorkloadPatch: pre connection still running. Sending kill signal") self.child[0].kill() else: #non IPC based workloads if (self.pre_database_status != "NOTOPEN") and (daemonProcess is None or daemonProcess.poll() is not None): self.logger.log("WorkloadPatch: Not app consistent backup") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadQuiescingTimeout,"not app consistent")) elif daemonProcess.poll() is None: self.logger.log("WorkloadPatch: pre connection still running. 
Sending kill signal") daemonProcess.kill() except Exception as e: self.logger.log("WorkloadPatch: exception in daemon process indentification" + str(e)) postSuccess = False if 'mysql' in self.name.lower() or 'mariadb' in self.name.lower(): self.logger.log("WorkloadPatch: Create connection string for post master") postscript = os.path.join(self.temp_script_folder, self.scriptpath + "/postMysqlMaster.sql") args = self.sudo_user+" "+self.command+self.name+" "+self.cred_string+" < "+postscript self.logger.log("WorkloadPatch: command to execute: "+str(self.sudo_user)+" "+str(self.command)) post_child = subprocess.Popen(args,stdout=subprocess.PIPE,stdin=subprocess.PIPE,shell=True,stderr=subprocess.PIPE) elif 'oracle' in self.name.lower(): self.logger.log("WorkloadPatch: Post- Inside oracle post") postOracle = self.command + "sqlplus" + " -S -R 2 /nolog @" + os.path.join(self.temp_script_folder, self.scriptpath + "/postOracleMaster.sql ") args = "su - "+self.linux_user+" -c "+"\'"+postOracle+"\'" self.logger.log("WorkloadPatch: argument passed for post script:"+str(args)) process = subprocess.Popen(args, stdout=subprocess.PIPE, shell=True) wait_counter = 5 while process.poll()==None and wait_counter>0: wait_counter -= 1 sleep(2) while True: line= process.stdout.readline() line = Utils.HandlerUtil.HandlerUtility.convert_to_string(line) if(line != ''): self.logger.log("WorkloadPatch: post completed with output "+line.rstrip(), True) else: break if 'END BACKUP succeeded' in line: self.logger.log("WorkloadPatch: post succeeded") postSuccess = True break if('LOG_MODE=' in line): line = line.replace('\n','') line_split = line.split('=') self.logger.log("WorkloadPatch: log mode set is "+line_split[1], True) if(line_split[1] == "ARCHIVELOG"): self.post_log_mode = "ARCHIVELOG" self.logger.log("WorkloadPatch: Archive log mode for oracle") else: self.post_log_mode = "NOARCHIVELOG" self.logger.log("WorkloadPatch: No archive log mode for oracle") if('STATUS=' in line): line = 
line.replace('\n', '') line_split = line.split('=') self.logger.log("WorkloadPatch: database status is "+line_split[1], True) if(line_split[1] == "OPEN"): self.post_database_status = "OPEN" self.logger.log("WorkloadPatch: Database is open") else:##handle other DB status if required self.post_database_status = "NOTOPEN" self.logger.log("WorkloadPatch: Database is not open") if((self.pre_log_mode == "NOARCHIVELOG" and self.post_log_mode == "ARCHIVELOG") or (self.pre_log_mode == "ARCHIVELOG" and self.post_log_mode == "NOARCHIVELOG")): self.logger.log("WorkloadPatch: Database log mode changed during backup") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadLogModeChanged, "Database log mode changed during backup")) if(postSuccess == False): if(self.pre_database_status == "NOTOPEN" and self.post_database_status == "NOTOPEN"): self.logger.log("WorkloadPatch: Database in closed status, backup is app consistent") elif((self.pre_database_status == "OPEN" and self.post_database_status == "NOTOPEN") or (self.pre_database_status == "NOTOPEN" and self.post_database_status == "OPEN")): self.logger.log("WorkloadPatch: Database status changed during backup") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadDatabaseStatusChanged, "Database status changed during backup")) else: self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadPostError, "Workload Post failed")) self.logger.log("WorkloadPatch: Post- Completed") self.callLogBackup() elif 'postgres' in self.name.lower(): self.logger.log("WorkloadPatch: Post- Inside postgres post") postPostgres = self.command + "psql " + self.cred_string + " -f " + os.path.join(os.getcwd(), "main/workloadPatch/"+self.scriptpath+"/postPostgresMaster.sql") args = "su - "+self.linux_user+" -c "+"\'"+postPostgres+"\'" self.logger.log("WorkloadPatch: argument passed for post script:"+str(self.linux_user)+" "+str(self.command)) process = subprocess.Popen(args,stdout=subprocess.PIPE, shell=True) 
wait_counter = 5 while process.poll()==None and wait_counter>0: wait_counter -= 1 sleep(2) self.logger.log("WorkloadPatch: Post- Completed") #Add new workload support here else: self.logger.log("WorkloadPatch: Unsupported workload name") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadInvalidWorkloadName, "Workload Not supported")) def preSlave(self): self.logger.log("WorkloadPatch: Entering pre mode for sloave") if self.ipc_folder != None: self.outfile = os.path.join(self.ipc_folder, "azbackupIPC.txt") if os.path.exists(self.outfile): os.remove(self.outfile) else: self.logger.log("WorkloadPatch: File for IPC does not exist at pre") if 'mysql' in self.name.lower() or 'mariadb' in self.name.lower(): self.logger.log("WorkloadPatch: Create connection string for preslave mysql") if self.outfile == "": self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadIPCDirectoryMissing, "IPC directory missing")) return None prescript = os.path.join(self.temp_script_folder, self.scriptpath + "/preMysqlSlave.sql") arg = self.sudo_user+" "+self.command+self.name+" "+self.cred_string+" -e\"set @timeout="+self.timeout+";set @outfile=\\\"\\\\\\\""+self.outfile+"\\\\\\\"\\\";source "+prescript+";\"" binary_thread = threading.Thread(target=self.thread_for_sql, args=[arg]) binary_thread.start() self.waitForPreScriptCompletion() #Add new workload support here else: self.logger.log("WorkloadPatch: Unsupported workload name") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadInvalidWorkloadName, "Workload Not supported")) def postSlave(self): self.logger.log("WorkloadPatch: Entering post mode for slave") if self.ipc_folder != None and self.ipc_folder != "":#IPCm based workloads if os.path.exists(self.outfile): os.remove(self.outfile) else: self.logger.log("WorkloadPatch: File for IPC does not exist at post") if len(self.child) == 0 or self.child[0].poll() is not None: self.logger.log("WorkloadPatch: Not app consistent backup") 
self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadQuiescingTimeout,"not app consistent")) return elif self.child[0].poll() is None: self.logger.log("WorkloadPatch: pre connection still running. Sending kill signal") self.child[0].kill() else: #non IPC based workloads if (self.pre_database_status != "NOTOPEN") and (daemonProcess is None or daemonProcess.poll() is not None): self.logger.log("WorkloadPatch: Not app consistent backup") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadQuiescingTimeout,"not app consistent")) return elif daemonProcess.poll() is None: self.logger.log("WorkloadPatch: pre connection still running. Sending kill signal") daemonProcess.kill() if 'mysql' in self.name.lower() or 'mariadb' in self.name.lower(): self.logger.log("WorkloadPatch: Create connection string for post slave") postscript = os.path.join(self.temp_script_folder, self.scriptpath + "/postMysqlSlave.sql") args = self.sudo_user+" "+self.command+self.name+" "+self.cred_string+" < "+postscript self.logger.log("WorkloadPatch: command to execute: "+str(args)) post_child = subprocess.Popen(args,stdout=subprocess.PIPE,stdin=subprocess.PIPE,shell=True,stderr=subprocess.PIPE) #Add new workload support here else: self.logger.log("WorkloadPatch: Unsupported workload name") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadInvalidWorkloadName, "Workload Not supported")) def preInstance(self): if 'oracle' in self.name.lower(): self.readOracleList(self.configuration_path) for index in range(len(self.instance_list)): oracleInstance = self.instance_list[index] oracle_home = oracleInstance["home"] commandPath = os.path.join(oracle_home,'bin') + "/" self.preMasterOracleInstance(commandPath, index) def postInstance(self): if 'oracle' in self.name.lower(): for index in range(len(self.instance_list)): oracleInstance = self.instance_list[index] oracle_home = oracleInstance["home"] commandPath = os.path.join(oracle_home,'bin') + "/" if 
((oracleInstance["preSuccess"] == True or oracleInstance["dbOpen"] == False)): self.postMasterOracleInstance(commandPath, index) else: if (oracleInstance["noArchive"] == True): self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadDatabaseInNoArchiveLog, "Workload in no archive log mode")) self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadPreError, "Workload Pre failed for SID: " + oracleInstance["sid"])) def preMasterOracleInstance(self, commandPath, instanceIndex): global preSuccess self.logger.log("WorkloadPatch: Entering pre mode for master") preSuccess = False oracleInstance = self.instance_list[instanceIndex] self.logger.log("WorkloadPatch: Pre- Inside oracle pre for instance with SID: " + oracleInstance["sid"] + " HOME: " + oracleInstance["home"]) preOracle = commandPath + "sqlplus" + " -S -R 2 /nolog @" + os.path.join(self.temp_script_folder, self.scriptpath + "/preOracleMaster.sql ") envExport = "export ORACLE_SID=" + oracleInstance["sid"] + "; export ORACLE_HOME=" + oracleInstance["home"] + "; export PATH=" + oracleInstance["home"] + "/bin:${PATH}; export ORACLE_UNQNAME=" + oracleInstance["sid"] + "; " args = "su - "+self.linux_user+" -c "+"\'"+ envExport + preOracle+"\'" self.logger.log("WorkloadPatch: argument passed for pre script:"+str(args)) process = subprocess.Popen(args, stdout=subprocess.PIPE, shell=True) self.instance_list[instanceIndex]["pid"] = process.pid wait_counter = 5 while process.poll() == None and wait_counter>0: wait_counter -= 1 sleep(2) while True: line= process.stdout.readline() line = Utils.HandlerUtil.HandlerUtility.convert_to_string(line) if(line != ''): self.logger.log("WorkloadPatch: pre completed with output "+line.rstrip(), True) else: break if('BEGIN BACKUP succeeded' in line): preSuccess = True break if('LOG_MODE=' in line): line = line.replace('\n','') line_split = line.split('=') self.logger.log("WorkloadPatch: log mode set is "+line_split[1], True) if(line_split[1] == "ARCHIVELOG"): 
self.pre_log_mode = "ARCHIVELOG" self.logger.log("WorkloadPatch: Archive log mode for oracle") else: self.pre_log_mode = "NOARCHIVELOG" self.logger.log("WorkloadPatch: No archive log mode for oracle") if('STATUS=' in line): line = line.replace('\n', '') line_split = line.split('=') self.logger.log("WorkloadPatch: database status is "+line_split[1], True) if(line_split[1] == "OPEN"): self.pre_database_status = "OPEN" self.logger.log("WorkloadPatch: Database is open") else:##handle other DB status if required self.pre_database_status = "NOTOPEN" self.instance_list[instanceIndex]["dbOpen"] = False self.logger.log("WorkloadPatch: Database is not open") if(self.pre_log_mode == "NOARCHIVELOG" and self.pre_database_status == "OPEN"): self.instance_list[instanceIndex]["noArchive"] = True if(preSuccess == True): self.logger.log("WorkloadPatch: pre success is true") self.instance_list[instanceIndex]["preSuccess"] = True self.timeoutDaemonOracleInstance(instanceIndex, commandPath) elif(self.pre_database_status == "NOTOPEN"): self.logger.log("WorkloadPatch: Database in closed status, backup can be app consistent") else: self.logger.log("WorkloadPatch: Pre failed for oracle") self.logger.log("WorkloadPatch: Pre- Exiting pre mode for master") def postMasterOracleInstance(self, commandPath, instanceIndex): global daemonProcess if "daemonProcess" in self.instance_list[instanceIndex]: daemonProcess = self.instance_list[instanceIndex]["daemonProcess"] self.logger.log("WorkloadPatch: Entering post mode for master") try: if (self.instance_list[instanceIndex]["dbOpen"] == True) and (daemonProcess is None or daemonProcess.poll() is not None): self.logger.log("WorkloadPatch: Not app consistent backup") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadQuiescingTimeout,"not app consistent")) elif daemonProcess.poll() is None: self.logger.log("WorkloadPatch: pre connection still running. 
Sending kill signal") daemonProcess.kill() except Exception as e: self.logger.log("WorkloadPatch: exception in daemon process indentification" + str(e)) postSuccess = False oracleInstance = self.instance_list[instanceIndex] self.logger.log("WorkloadPatch: Post- Inside oracle post for instance with SID: " + oracleInstance["sid"] + " HOME: " + oracleInstance["home"]) postOracle = commandPath + "sqlplus" + " -S -R 2 /nolog @" + os.path.join(self.temp_script_folder, self.scriptpath + "/postOracleMaster.sql ") envExport = "export ORACLE_SID=" + oracleInstance["sid"] + "; export ORACLE_HOME=" + oracleInstance["home"] + "; export PATH=" + oracleInstance["home"] + "/bin:${PATH}; export ORACLE_UNQNAME=" + oracleInstance["sid"] + "; " args = "su - "+self.linux_user+" -c "+"\'"+ envExport + postOracle+"\'" self.logger.log("WorkloadPatch: argument passed for post script:"+str(args)) process = subprocess.Popen(args, stdout=subprocess.PIPE, shell=True) wait_counter = 5 while process.poll()==None and wait_counter>0: wait_counter -= 1 sleep(2) while True: line= process.stdout.readline() line = Utils.HandlerUtil.HandlerUtility.convert_to_string(line) if(line != ''): self.logger.log("WorkloadPatch: post completed with output "+line.rstrip(), True) else: break if 'END BACKUP succeeded' in line: self.logger.log("WorkloadPatch: post succeeded") postSuccess = True self.instance_list[instanceIndex]["postSuccess"] = True break if('LOG_MODE=' in line): line = line.replace('\n','') line_split = line.split('=') self.logger.log("WorkloadPatch: log mode set is "+line_split[1], True) if(line_split[1] == "ARCHIVELOG"): self.post_log_mode = "ARCHIVELOG" self.logger.log("WorkloadPatch: Archive log mode for oracle") else: self.post_log_mode = "NOARCHIVELOG" self.logger.log("WorkloadPatch: No archive log mode for oracle") if('STATUS=' in line): line = line.replace('\n', '') line_split = line.split('=') self.logger.log("WorkloadPatch: database status is "+line_split[1], True) if(line_split[1] == 
"OPEN"): self.post_database_status = "OPEN" self.logger.log("WorkloadPatch: Database is open") else:##handle other DB status if required self.post_database_status = "NOTOPEN" self.logger.log("WorkloadPatch: Database is not open") if((oracleInstance["noArchive"] == True and self.post_log_mode == "ARCHIVELOG") or (oracleInstance["noArchive"] == False and self.post_log_mode == "NOARCHIVELOG")): self.logger.log("WorkloadPatch: Database log mode changed during backup") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadLogModeChanged, "Database log mode changed during backup")) if(postSuccess == False): if(oracleInstance["dbOpen"] == False and self.post_database_status == "NOTOPEN"): self.logger.log("WorkloadPatch: Database in closed status, backup is app consistent") elif((oracleInstance["dbOpen"] == True and self.post_database_status == "NOTOPEN") or (oracleInstance["dbOpen"] == False and self.post_database_status == "OPEN")): self.logger.log("WorkloadPatch: Database status changed during backup") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadDatabaseStatusChanged, "Database status changed during backup")) else: self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadPostError, "Workload Post failed")) self.logger.log("WorkloadPatch: Post- Completed") self.callLogBackup() def preMasterDB(self): pass def preSlaveDB(self): pass def postMasterDB(self): pass def postSlaveDB(self): pass def confParser(self): self.logger.log("WorkloadPatch: Entering workload config parsing") configfile = '/etc/azure/workload.conf' try: if os.path.exists(configfile): config = ConfigParsers.ConfigParser() config.read(configfile) if config.has_section("workload"): self.logger.log("WorkloadPatch: config section present for workloads ") if config.has_option("workload", 'workload_name'): name = config.get("workload", 'workload_name') if name in self.supported_workload: self.name = name self.logger.log("WorkloadPatch: config workload command "+ 
self.name) else: return None else: return None if config.has_option("workload", 'command_path'): self.command = config.get("workload", 'command_path') self.command = self.command+"/" self.logger.log("WorkloadPatch: config workload command "+ self.command) if config.has_option("workload", 'credString'): self.cred_string = config.get("workload", 'credString') self.logger.log("WorkloadPatch: config workload cred_string found") elif not config.has_option("workload", 'linux_user'): self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadAuthorizationMissing, "Cred and linux user string missing")) if config.has_option("workload", 'role'): self.role = config.get("workload", 'role') self.logger.log("WorkloadPatch: config workload role "+ self.role) if config.has_option("workload", 'enforceSlaveOnly'): self.enforce_slave_only = config.get("workload", 'enforceSlaveOnly') self.logger.log("WorkloadPatch: config workload enforce_slave_only "+ self.enforce_slave_only) if config.has_option("workload", 'ipc_folder'): self.ipc_folder = config.get("workload", 'ipc_folder') self.logger.log("WorkloadPatch: config ipc folder "+ self.ipc_folder) if config.has_option("workload", 'timeout'): timeout = config.get("workload", 'timeout') if timeout != "" and timeout != None: self.timeout = timeout self.logger.log("WorkloadPatch: config timeout of pre script "+ self.timeout) if config.has_option("workload", 'linux_user'): self.linux_user = config.get("workload", 'linux_user') self.logger.log("WorkloadPatch: config linux user of pre script "+ self.linux_user) self.sudo_user = "sudo -u "+self.linux_user if config.has_option("workload", 'dbnames'): dbnames_list = config.get("workload", 'dbnames') #mydb1;mydb2;mydb3 self.dbnames = dbnames_list.split(';') if config.has_option("workload", 'customScriptEnabled'): self.custom_scripts_enabled = config.get("workload", 'customScriptEnabled') self.logger.log("WorkloadPatch: config workload customer using custom script "+ 
self.custom_scripts_enabled) if int(self.custom_scripts_enabled) == 1: self.scriptpath= "CustomScripts" if config.has_option("workload", 'configuration_path'): self.configuration_path = config.get("workload", 'configuration_path') self.logger.log("WorkloadPatch: config workload customer having multiple instances mentioned at path "+ self.configuration_path) if config.has_section("logbackup"): self.logbackup = "enable" self.logger.log("WorkloadPatch: Logbackup Enabled") else: self.logger.log("WorkloadPatch: workload config section missing. File system consistent backup") else: self.logger.log("WorkloadPatch: workload config file missing. File system consistent backup") except Exception as e: self.logger.log("WorkloadPatch: exception in workload conf file parsing") if(self.name != None): self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadConfParsingError, "exception in workloadconfig parsing")) def createTempScriptsFolder(self): self.logger.log("WorkloadPatch: Creating temporary scripts folder") try: originalScriptsPath = os.path.join(os.getcwd(), "main/workloadPatch/"+self.scriptpath) newScriptsPath = os.path.join(self.temp_script_folder, self.scriptpath) if (os.path.exists(self.temp_script_folder) == False): self.logger.log("WorkloadPatch: Script folder directory path not found..creating") os.makedirs(self.temp_script_folder) if (os.path.exists(newScriptsPath)): self.logger.log("WorkloadPatch: Existing temporary scripts folder found..removing") self.removeTempScriptsFolder() copyProcess = subprocess.Popen(['cp','-ar',originalScriptsPath,self.temp_script_folder]) copyProcess.wait() changeOwnerProcess = subprocess.Popen(['chown','-R',self.linux_user,newScriptsPath], stdout=subprocess.PIPE) changeOwnerProcess.wait() permissionProcess = subprocess.Popen(['chmod','-R','500',newScriptsPath], stdout=subprocess.PIPE) permissionProcess.wait() self.logger.log("WorkloadPatch: Script files copied to temporary scripts folder present at " + newScriptsPath) 
except Exception as e: self.logger.log("WorkloadPatch: exception in creating temporary scripts folder: " + str(e)) def removeTempScriptsFolder(self): self.logger.log("WorkloadPatch: Removing temporary scripts folder") try: newScriptsPath = os.path.join(self.temp_script_folder, self.scriptpath) removalProcess = subprocess.Popen(['rm','-rf',newScriptsPath], stdout=subprocess.PIPE) removalProcess.wait() self.logger.log("WorkloadPatch: Removed temporary scripts folder") except Exception as e: self.logger.log("WorkloadPatch: exception in removing temporary scripts folder: " + str(e)) def populateErrors(self): if len(self.error_details) > 0: errdetail = self.error_details[0] return errdetail else: return None def waitForPreScriptCompletion(self): if self.ipc_folder != None: wait_counter = 5 while len(self.child) == 0 and wait_counter > 0: self.logger.log("WorkloadPatch: child not created yet", True) wait_counter -= 1 sleep(2) if wait_counter > 0: self.logger.log("WorkloadPatch: sql subprocess Created "+str(self.child[0].pid)) else: self.logger.log("WorkloadPatch: sql connection failed") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadConnectionError, "sql connection failed")) return None wait_counter = 60 while os.path.exists(self.outfile) == False and wait_counter > 0: self.logger.log("WorkloadPatch: Waiting for sql to complete") wait_counter -= 1 sleep(2) if wait_counter > 0: self.logger.log("WorkloadPatch: pre at server level completed") else: self.logger.log("WorkloadPatch: pre failed to quiesce") self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadQuiescingError, "pre failed to quiesce")) return None def timeoutDaemon(self): global daemonProcess argsDaemon = "su - "+self.linux_user+" -c " + "'" + os.path.join(self.temp_script_folder, self.scriptpath + "/timeoutDaemon.sh")+" "+self.name+" "+self.command+" \""+self.cred_string+"\" "+self.timeout+" "+os.path.join(self.temp_script_folder, self.scriptpath + "'") devnull = 
open(os.devnull, 'w') daemonProcess = subprocess.Popen(argsDaemon, stdout=devnull, stderr=devnull, shell=True) wait_counter = 5 while (daemonProcess is None or daemonProcess.poll() is not None) and wait_counter > 0: self.logger.log("WorkloadPatch: daemonProcess not created yet", True) wait_counter -= 1 sleep(1) if wait_counter > 0: self.logger.log("WorkloadPatch: daemonProcess Created "+str(daemonProcess.pid)) else: while True: line= daemonProcess.stdout.readline() line = Utils.HandlerUtil.HandlerUtility.convert_to_string(line) if(line != ''): self.logger.log("WorkloadPatch: daemon process creation failed "+line.rstrip(), True) else: break self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadConnectionError, "sql connection failed")) return None def timeoutDaemonOracleInstance(self, instanceIndex, commandPath): global daemonProcess argsDaemon = "su - "+self.linux_user+" -c " + "'" + os.path.join(self.temp_script_folder, self.scriptpath + "/timeoutDaemon.sh")+" "+self.name+" "+commandPath+" \""+self.cred_string+"\" "+self.timeout+" "+os.path.join(self.temp_script_folder, self.scriptpath + "'") devnull = open(os.devnull, 'w') oracleInstance = self.instance_list[instanceIndex] envExport = "export ORACLE_SID=" + oracleInstance["sid"] + "; export ORACLE_HOME=" + oracleInstance["home"] + "; export PATH=" + oracleInstance["home"] + "/bin:${PATH}; export ORACLE_UNQNAME=" + oracleInstance["sid"] + "; " argsDaemon = envExport + argsDaemon daemonProcess = subprocess.Popen(argsDaemon, stdout=devnull, stderr=devnull, shell=True) self.instance_list[instanceIndex]["daemonProcess"] = daemonProcess wait_counter = 5 while (daemonProcess is None or daemonProcess.poll() is not None) and wait_counter > 0: self.logger.log("WorkloadPatch: daemonProcess not created yet", True) wait_counter -= 1 sleep(1) if wait_counter > 0: self.logger.log("WorkloadPatch: daemonProcess Created "+str(daemonProcess.pid)) else: while True: line= daemonProcess.stdout.readline() line = 
Utils.HandlerUtil.HandlerUtility.convert_to_string(line) if(line != ''): self.logger.log("WorkloadPatch: daemon process creation failed "+line.rstrip(), True) else: break self.error_details.append(ErrorDetail(CommonVariables.FailedWorkloadConnectionError, "sql connection failed")) return None def thread_for_sql(self,args): self.logger.log("WorkloadPatch: command to execute: "+str(args)) self.child.append(subprocess.Popen(args,stdout=subprocess.PIPE,stdin=subprocess.PIPE,shell=True,stderr=subprocess.PIPE)) sleep(1) def getRole(self): return "master" def callLogBackup(self): if 'enable' in self.logbackup.lower(): self.logger.log("WorkloadPatch: Initializing logbackup") logbackupObject = LogBackupPatch() else: return ================================================ FILE: VMBackup/main/workloadPatch/WorkloadUtils/OracleLogBackup.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import re import sys import subprocess import threading from workloadPatch.LogbackupPatch import LogBackupPatch from time import sleep from datetime import datetime # Example of Parameter File Content: # *.db_name='CDB1' def parameterFileParser(): regX = re.compile(r"\*\..+=.+") parameterFile = open(logbackup.parameterFilePath, 'r') contents = parameterFile.read() for match in regX.finditer(contents): keyParameter = match.group().split('=')[0].lstrip('*\.') valueParameter = [name.strip('\'') for name in match.group().split('=')[1].split(',')] logbackup.oracleParameter[keyParameter] = valueParameter def setLocation(): nowTimestamp = datetime.now() nowTimestamp = nowTimestamp.strftime("%Y%m%d%H%M%S") fullPath = logbackup.baseLocation + nowTimestamp os.system('mkdir -m777 '+ fullPath) return fullPath def takeBackup(): print("logbackup: Taking a backup") backupPath = setLocation() if 'oracle' in logbackup.name.lower(): backupOracle = logbackup.command + " -s / as sysdba @" + "/var/lib/waagent/Microsoft.Azure.RecoveryServices.VMSnapshotLinux-1.0.9164.0/main/workloadPatch/scripts/logbackup.sql " + backupPath argsForControlFile = ["su", "-", logbackup.cred_string, "-c", backupOracle] snapshotControlFile = subprocess.Popen(argsForControlFile) while snapshotControlFile.poll()==None: sleep(1) recoveryFileDest = logbackup.oracleParameter['db_recovery_file_dest'] dbName = logbackup.oracleParameter['db_name'] print(' logbackup: Archive log backup started at ', datetime.now().strftime("%Y%m%d%H%M%S")) os.system('cp -R -f ' + recoveryFileDest[0] + '/' + dbName[0] + '/archivelog ' + backupPath) print(' logbackup: Archive log backup complete at ', datetime.now().strftime("%Y%m%d%H%M%S")) print("logbackup: Backup Complete") def main(): global logbackup logbackup = LogBackupPatch() parameterFileParser() takeBackup() if __name__ == "__main__": main() ================================================ FILE: VMBackup/main/workloadPatch/WorkloadUtils/OracleLogRestore.py 
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2014 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import threading
import os
from time import sleep
import subprocess
from datetime import datetime
import re
try:
    import ConfigParser as ConfigParsers  # Python 2
except ImportError:
    import configparser as ConfigParsers  # Python 3


class LogRestore:
    """Restore Oracle control files and archive logs from a log-backup folder.

    Settings come from /etc/azure/workload.conf ([workload] and [logbackup]
    sections) and the Oracle parameter file named there.
    """

    def __init__(self):
        self.name = ""
        self.cred_string = ""
        self.baseLocation = ""
        self.parameterFilePath = ""
        self.oracleParameter = {}
        self.backupSource = ""
        self.crontabLocation = ""
        self.command = ""
        self.confParser()
        self.parameterFileParser()

    # Example of Parameter File Content:
    # *.db_name='CDB1'
    def parameterFileParser(self):
        """Parse `*.key=value[,value...]` lines of the Oracle parameter file
        into self.oracleParameter (key -> list of unquoted values)."""
        regX = re.compile(r"\*\..+=.+")
        # Fix: close the file deterministically (the previous version opened
        # the handle and never closed it).
        with open(self.parameterFilePath, 'r') as parameterFile:
            contents = parameterFile.read()
        for match in regX.finditer(contents):
            # '*.db_name' -> 'db_name'
            keyParameter = match.group().split('=')[0].lstrip('*\\.')
            # "'a','b'" -> ['a', 'b']
            valueParameter = [name.strip('\'') for name in match.group().split('=')[1].split(',')]
            self.oracleParameter[keyParameter] = valueParameter

    # To replace the existing control files in the DB with new control files
    def switchControlFiles(self, backupPath):
        """Overwrite every configured control file with backupPath/control.ctl."""
        parsedControlFile = self.oracleParameter['control_files']
        for location in parsedControlFile:
            os.system('rm -f ' + location)
            os.system('cp -f ' + backupPath + '/control.ctl ' + location)
            os.system('chmod a+wrx ' + location)

    # To replace the existing archive log files in the DB with new archive log file
    def switchArchiveLogFiles(self, backupPath):
        """Replace the archivelog tree under each db_recovery_file_dest with
        the copy stored in backupPath."""
        recoveryFileDest = self.oracleParameter['db_recovery_file_dest']
        dbName = self.oracleParameter['db_name']
        for location in recoveryFileDest:
            os.system('rm -R -f ' + location + '/' + dbName[0] + '/archivelog')
            os.system('cp -R -f ' + backupPath + '/archivelog ' + location + '/' + dbName[0] + '/archivelog')
            os.system('chmod -R a+wrx ' + location + '/' + dbName[0] + '/archivelog')

    # To trigger the restore of control files and archive log files
    def triggerRestore(self):
        """Restore control files and archive logs from baseLocation + backupSource."""
        backupPath = self.baseLocation + self.backupSource
        self.switchControlFiles(backupPath)
        self.switchArchiveLogFiles(backupPath)

    def confParser(self):
        """Load names and paths from /etc/azure/workload.conf.

        Returns early (leaving defaults in place) when the config file, the
        [logbackup] section, or a required option is missing.
        """
        configfile = '/etc/azure/workload.conf'
        if os.path.exists(configfile):
            config = ConfigParsers.ConfigParser()
            config.read(configfile)
            if config.has_section("logbackup"):
                # NOTE(review): has_option raises NoSectionError when the
                # [workload] section is absent -- confirm both sections are
                # always written together.
                if config.has_option("workload", 'workload_name'):
                    self.name = config.get("workload", 'workload_name')
                else:
                    return None
                if config.has_option("workload", 'command'):
                    self.command = config.get("workload", 'command')
                if config.has_option("workload", 'credString'):
                    self.cred_string = config.get("workload", 'credString')
                if config.has_option("logbackup", 'parameterFilePath'):
                    self.parameterFilePath = config.get("logbackup", 'parameterFilePath')
                else:
                    return None
                if config.has_option("logbackup", 'baseLocation'):
                    self.baseLocation = config.get("logbackup", 'baseLocation')
                else:
                    return None
                if config.has_option("logbackup", 'crontabLocation'):
                    self.crontabLocation = config.get("logbackup", 'crontabLocation')
                else:
                    return


def main():
    """Interactive entry point: list available backups, prompt for a
    timestamp, then restore that backup."""
    oracleLogRestore = LogRestore()
    os.system('ls -lrt ' + oracleLogRestore.baseLocation)
    # NOTE(review): on Python 2 input() evaluates the typed text; raw_input
    # would be safer there. Kept as-is for behavior compatibility.
    oracleLogRestore.backupSource = input("Enter the timestamp: ")
    oracleLogRestore.triggerRestore()


if __name__ == "__main__":
    main()
# # Requires Python 2.4+ ================================================ FILE: VMBackup/manifest.xml ================================================ Microsoft.Azure.RecoveryServices VMSnapshotLinux 1.0.9184.0 VmRole Windows Azure VMBackup Extension for Linux IaaS true https://github.com/Azure/azure-linux-extensions/blob/1.0/LICENSE-2_0.txt https://github.com/Azure/azure-linux-extensions/blob/1.0/LICENSE-2_0.txt https://github.com/Azure/azure-linux-extensions true Linux Microsoft Open Source Technology Center ================================================ FILE: VMBackup/mkstub.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import sys import json import os import shutil from main.common import CommonVariables def copytree(src,dst): names = os.listdir(src) if(os.path.isdir(dst) != True): os.makedirs(dst) for name in names: srcname = os.path.join(src, name) dstname = os.path.join(dst, name) if os.path.isdir(srcname): copytree(srcname, dstname) else: # Will raise a SpecialFileError for unsupported file types shutil.copy2(srcname, dstname) target_zip_file_location = './dist/' target_folder_name = CommonVariables.extension_name + '-' + str(CommonVariables.extension_version) target_zip_file_path = target_zip_file_location + target_folder_name + '.zip' final_folder_path = target_zip_file_location + target_folder_name copytree(final_folder_path, '/var/lib/waagent/' + target_folder_name) """ we should also build up a HandlerEnvironment.json """ manifest_obj = [{ "name": CommonVariables.extension_name, "seqNo": "1", "version": 1.0, "handlerEnvironment": { "logFolder": "/var/log/azure/" + CommonVariables.extension_name + "/" + str(CommonVariables.extension_version), "configFolder": "/var/lib/waagent/" + CommonVariables.extension_name + "-" + str(CommonVariables.extension_version) + "/config", "statusFolder": "/var/lib/waagent/" + CommonVariables.extension_name + "-" + str(CommonVariables.extension_version) + "/status", "heartbeatFile": "/var/lib/waagent/" + CommonVariables.extension_name + "-" + str(CommonVariables.extension_version) + "/heartbeat.log" } }] manifest_str = json.dumps(manifest_obj, sort_keys = True, indent = 4) manifest_file = open('/var/lib/waagent/' + target_folder_name + "/HandlerEnvironment.json", "w") manifest_file.write(manifest_str) manifest_file.close() ================================================ FILE: VMBackup/references ================================================ Utils/ ================================================ FILE: VMBackup/setup.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To build:
#   python setup.py sdist
# To install:
#   python setup.py install
# To register (only needed once):
#   python setup.py register
# To upload:
#   python setup.py sdist upload

try:
    from setuptools import setup
except ImportError:
    from distutils.core import setup

import os
import shutil
import tempfile
import json
import sys
import subprocess
import shutil
import time
from subprocess import call
from zipfile import ZipFile
from main.common import CommonVariables

# Layout of the extension payload relative to this script.
packages_array = []
main_folder = 'main'
main_entry = main_folder + '/handle.sh'
binary_entry = main_folder + '/safefreeze'
arm64_binary_entry = main_folder + '/safefreezeArm64'
packages_array.append(main_folder)
plugin_folder = main_folder + '/tempPlugin'
plugin_conf = main_folder + '/VMSnapshotPluginHost.conf'
severity_json = main_folder + '/LogSeverity.json'
patch_folder = main_folder + '/patch'
packages_array.append(patch_folder)
workloadpatch_folder = main_folder + '/workloadPatch'
workloadutils_folder = main_folder + '/workloadPatch/WorkloadUtils'
workloadscripts_folder = main_folder + '/workloadPatch/DefaultScripts'
workload_customscripts_folder = main_folder + '/workloadPatch/CustomScripts'
sqlfilelists = os.listdir(workloadscripts_folder)
custom_sqlfilelists = os.listdir(workload_customscripts_folder)
packages_array.append(workloadpatch_folder)
manifest = "manifest.xml"

""" copy the dependency to the local """
""" copy the utils lib to local """
target_utils_path = main_folder + '/' + CommonVariables.utils_path_name
#if os.path.isdir(target_utils_path):
#    shutil.rmtree(target_utils_path)
#print('copying')
#shutil.copytree ('../' + CommonVariables.utils_path_name, target_utils_path)
#print('copying end')
packages_array.append(target_utils_path)

""" copy the NodeBased lib to local """
target_snapshot_service_path = main_folder + '/' + CommonVariables.snapshot_service_path_name
packages_array.append(target_snapshot_service_path)
polling_service_metadata = target_snapshot_service_path + '/service_metadata.json'
polling_service_readme = target_snapshot_service_path + '/README.md'

""" generate the HandlerManifest.json file. """
# Maps the five extension lifecycle commands onto handle.sh.
manifest_obj = [{
    "name": CommonVariables.extension_name,
    "version": CommonVariables.extension_version,
    "handlerManifest": {
        "installCommand": main_entry + " install",
        "uninstallCommand": main_entry + " uninstall",
        "updateCommand": main_entry + " update",
        "enableCommand": main_entry + " enable",
        "disableCommand": main_entry + " disable",
        "rebootAfterInstall": False,
        "reportHeartbeat": False
    }
}]
manifest_str = json.dumps(manifest_obj, sort_keys = True, indent = 4)
manifest_file = open("HandlerManifest.json", "w")
manifest_file.write(manifest_str)
manifest_file.close()

""" generate the safe freeze binary """
# Run `make` in the safefreeze folder, waiting at most ~5 s for it to finish.
cur_dir = os.getcwd()
os.chdir("./main/safefreeze")
chil = subprocess.Popen(["make"], stdout=subprocess.PIPE)
process_wait_time = 5
while(process_wait_time > 0 and chil.poll() is None):
    time.sleep(1)
    process_wait_time -= 1
os.chdir(cur_dir)

'''
due to the lack of cross-compilation support in Mariner. It would not be able
to get the binaries, so will be using the older binaries
To Do : Once Mariner starts supporting cross-compilation, will have to modify
any necessary scripts to enable the generation of safefreezeARM64 binaries.
'''
""" generate the ARM64 safe freeze binary """
cur_dir = os.getcwd()
os.chdir("./main/safefreezeArm64")
chil = subprocess.Popen(["make"], stdout=subprocess.PIPE)
process_wait_time = 5
while(process_wait_time > 0 and chil.poll() is None):
    time.sleep(1)
    process_wait_time -= 1
os.chdir(cur_dir)

""" setup script, to package the files up """
setup(name = CommonVariables.extension_name,
      version = CommonVariables.extension_zip_version,
      description=CommonVariables.extension_description,
      license='Apache License 2.0',
      author='Microsoft Corporation',
      author_email='andliu@microsoft.com',
      url='https://github.com/Azure/azure-linux-extensions',
      classifiers = ['Development Status :: 5 - Production/Stable',
                     'Programming Language :: Python',
                     'Programming Language :: Python :: 2',
                     'Programming Language :: Python :: 2.7',
                     'Programming Language :: Python :: 3',
                     'Programming Language :: Python :: 3.3',
                     'Programming Language :: Python :: 3.4',
                     'License :: OSI Approved :: Apache Software License',
                     'Programming Language :: SQL',
                     'Programming Language :: PL/SQL'],
      packages = packages_array
      )

""" unzip the package files and re-package it. """
target_zip_file_location = './dist/'
target_folder_name = CommonVariables.extension_name + '-' + CommonVariables.extension_zip_version
target_zip_file_path = target_zip_file_location + target_folder_name + '.zip'
target_zip_file = ZipFile(target_zip_file_path)
target_zip_file.extractall(target_zip_file_location)

def dos2unix(src):
    """Run the dos2unix tool on src, discarding its output.

    NOTE(review): the devnull handle is opened per call and never closed.
    """
    args = ["dos2unix", src]
    devnull = open(os.devnull, 'w')
    child = subprocess.Popen(args, stdout=devnull, stderr=devnull)
    print('dos2unix %s ' % (src))
    child.wait()

# NOTE(review): this shadows the built-in zip() for the rest of the script.
def zip(src, dst):
    """Normalize line endings of every file under src and zip the tree
    into dst, with archive names relative to src."""
    zf = ZipFile("%s" % (dst), "w")
    abs_src = os.path.abspath(src)
    for dirname, subdirs, files in os.walk(src):
        for filename in files:
            absname = os.path.abspath(os.path.join(dirname, filename))
            dos2unix(absname)
            arcname = absname[len(abs_src) + 1:]
            print('zipping %s as %s' % (os.path.join(dirname, filename), arcname))
            zf.write(absname, arcname)
    zf.close()

def copybinary(src, dst):
    """Copy a directory tree into the extracted package."""
    shutil.copytree(src, dst)

def copy(src, dst):
    """Copy a single file (with metadata) into the extracted package."""
    shutil.copy2(src, dst)

# Inject binaries, plugin, scripts and config into the extracted tree,
# then re-zip it in place of the sdist output.
final_folder_path = target_zip_file_location + target_folder_name
final_binary_path = final_folder_path + '/main/safefreeze'
final_Arm64binary_path = final_folder_path + '/main/safefreezeArm64'
final_plugin_path = final_folder_path + '/main/tempPlugin'
final_workloadscripts_path = final_folder_path + '/main/workloadPatch/DefaultScripts'
final_workload_customscripts_path = final_folder_path + '/main/workloadPatch/CustomScripts'
final_workloadutils_path = final_folder_path + '/main/workloadPatch/WorkloadUtils'
copybinary(binary_entry, final_binary_path)
copybinary(arm64_binary_entry, final_Arm64binary_path)
copybinary(plugin_folder, final_plugin_path)
copybinary(workloadscripts_folder, final_workloadscripts_path)
copybinary(workload_customscripts_folder, final_workload_customscripts_path)
copybinary(workloadutils_folder, final_workloadutils_path)
final_main_folder = final_folder_path + '/main'
final_snapshot_service_path = final_main_folder + '/' + CommonVariables.snapshot_service_path_name
copy(plugin_conf, final_main_folder)
copy(severity_json, final_main_folder)
copy(polling_service_metadata, final_snapshot_service_path)
copy(polling_service_readme, final_snapshot_service_path)
copy(manifest, final_folder_path)
copy(main_entry, final_main_folder)
zip(final_folder_path, target_zip_file_path)
import array
import base64
import os
import os.path
import re
import string
import subprocess
import sys
import imp
import shlex
import traceback
import urllib2
import urlparse
import datetime
import math


def main():
    # .NET ticks are 100ns intervals since 0001-01-01; divide by 10 to get
    # microseconds for timedelta.
    ticks = 635798839149570996
    commandStartTime = datetime.datetime(1, 1, 1) + datetime.timedelta(microseconds = ticks/10)
    utcNow = datetime.datetime.utcnow()
    timespan = utcNow-commandStartTime
    print(str(timespan.total_seconds()))
    # Manual days/seconds fold kept alongside total_seconds() for comparison.
    total_span_in_seconds = timespan.days * 24 * 60 * 60 + timespan.seconds
    print(str(total_span_in_seconds))

if __name__ == '__main__' :
    main()


================================================
FILE: VMBackup/test/install_python2.6.sh
================================================
#!/bin/bash

# Function to print messages
print_message() {
    echo "----------------------------------------"
    echo "$1"
    echo "----------------------------------------"
}

# Check if Python 2.6 is already installed
if command -v python2.6 &> /dev/null; then
    PYTHON_VERSION=$(python2.6 --version 2>&1)  # Capture version output
    print_message "Python 2.6 is already installed. Version: $PYTHON_VERSION"
    exit 0
fi

# Update the package list
print_message "Updating package list..."
sudo apt update

# Install required packages for building Python
print_message "Installing required packages..."
if ! sudo apt install -y build-essential checkinstall \
    libreadline-dev libncurses-dev libssl-dev \
    libsqlite3-dev tk-dev libgdbm-dev libc6-dev libbz2-dev; then
    echo "Error: Failed to install required packages."
    exit 1
fi

print_message "Checking for libreadline installation..."
dpkg -l | grep libreadline || echo "libreadline not found."

print_message "Changing directory to /tmp..."
# BUGFIX (ShellCheck SC2164): guard the cd — the original kept running in
# the wrong directory if the cd failed.
cd /tmp || { echo "Error: Failed to change directory to /tmp."; exit 1; }

print_message "Downloading Python 2.6.6 source code..."
if ! wget https://www.python.org/ftp/python/2.6.6/Python-2.6.6.tgz; then
    echo "Error: Failed to download Python 2.6.6 source code."
    exit 1
fi

# Extract the downloaded tarball
print_message "Extracting Python 2.6.6..."
if ! tar -xzf Python-2.6.6.tgz; then
    echo "Error: Failed to extract Python 2.6.6."
    exit 1
fi

# Change directory to the extracted folder (guarded, SC2164)
cd Python-2.6.6 || { echo "Error: Failed to change directory to Python-2.6.6."; exit 1; }

print_message "Configuring Python build with optimizations..."
if ! ./configure --enable-optimizations; then
    echo "Error: Configuration of Python build failed."
    exit 1
fi

# Compile the source code
print_message "Compiling Python 2.6.6. This may take a while..."
if ! make; then
    echo "Error: Compilation of Python 2.6.6 failed."
    exit 1
fi

print_message "Installing Python 2.6..."
if ! sudo make altinstall; then
    echo "Error: Installation of Python 2.6 failed."
    exit 1
fi

print_message "Verifying the installation of Python 2.6..."
if command -v python2.6 &> /dev/null; then
    python2.6 --version
else
    echo "Error: Python 2.6 installation was not successful."
    exit 1
fi

print_message "Creating a symbolic link for python2..."
if ! sudo ln -s /usr/local/bin/python2.6 /usr/bin/python2; then
    echo "Error: Failed to create a symbolic link for python2."
    exit 1
fi

print_message "Python 2.6 installation completed successfully."
================================================ FILE: VMEncryption/.vscode/settings.json ================================================ { "python.linting.pylintEnabled": false, "python.linting.flake8Enabled": true, "python.linting.flake8Args": ["--max-line-length=300"], "python.linting.enabled": true } ================================================ FILE: VMEncryption/MANIFEST.in ================================================ include HandlerManifest.json manifest.xml extension_shim.sh recursive-include main/oscrypto/91ade *.sh recursive-include main/oscrypto/91ade *.rules recursive-include main/oscrypto/rhel_68/encryptpatches *.patch recursive-include main/oscrypto/centos_68/encryptpatches *.patch recursive-include main/oscrypto/ubuntu_1604/encryptpatches *.patch recursive-include main/oscrypto/ubuntu_1604/encryptscripts *.sh recursive-include main/oscrypto/ubuntu_1404/encryptpatches *.patch recursive-include main/oscrypto/ubuntu_1404/encryptscripts *.sh prune test ================================================ FILE: VMEncryption/ReleaseNotes.txt ================================================ (0.1.0.999345) -Fix disable after EFA encryption and stop-start -Fix Ubuntu 14 unmount oldroot sequence -Fix missing Python-six module for CentOS 6.8 ================================================ FILE: VMEncryption/Test-AzureRmVMDiskEncryptionExtension.ps1 ================================================ Param( [Parameter(Mandatory=$true)] [string] $SubscriptionId, [Parameter(Mandatory=$true)] [string] $AadClientId, [Parameter(Mandatory=$true)] [string] $AadClientSecret, [Parameter(Mandatory=$true)] [string] $ResourcePrefix, [Parameter(Mandatory=$true)] [string] $Username, [Parameter(Mandatory=$true)] [string] $Password, [string] $ExtensionName="AzureDiskEncryptionForLinux", [string] $SshPubKey, [string] $SshPrivKeyPath, [string] $Location="eastus", [string] $VolumeType="data", [string] $GalleryImage="RedHat:RHEL:7.2", [string] $VMSize="Standard_D2", [switch] 
$DryRun=$false, [switch] $Force=$false ) $ErrorActionPreference = "Stop" Set-AzureRmContext -SubscriptionId $SubscriptionId Write-Host "Set AzureRmContext successfully" ## Resource Group $global:ResourceGroupName = $ResourcePrefix + "ResourceGroup" if(!$DryRun) { New-AzureRmResourceGroup -Name $ResourceGroupName -Location $Location } Write-Host "Created ResourceGroup successfully: $ResourceGroupName" ## KeyVault $global:KeyVaultName = $ResourcePrefix + "KeyVault" if(!$DryRun) { $global:KeyVault = New-AzureRmKeyVault -VaultName $KeyVaultName -ResourceGroupName $ResourceGroupName -Location $Location } else { $global:KeyVault = Get-AzureRmKeyVault -VaultName $KeyVaultName -ResourceGroupName $ResourceGroupName } Write-Host "Created KeyVault successfully: $KeyVaultName" if(!$DryRun) { Set-AzureRmKeyVaultAccessPolicy -VaultName $KeyVaultName -ResourceGroupName $ResourceGroupName -ServicePrincipalName $AadClientId -PermissionsToKeys all -PermissionsToSecrets all Set-AzureRmKeyVaultAccessPolicy -VaultName $KeyVaultName -ResourceGroupName $ResourceGroupName -EnabledForDiskEncryption } Write-Host "Set AzureRmKeyVaultAccessPolicy successfully" if(!$DryRun) { Add-AzureKeyVaultKey -VaultName $KeyVaultName -Name "keyencryptionkey" -Destination Software } Write-Host "Added AzureRmKeyVaultKey successfully" $global:KeyEncryptionKey = Get-AzureKeyVaultKey -VaultName $KeyVault.OriginalVault.Name -Name "keyencryptionkey" Write-Host "Fetched KeyEncryptionKey successfully" ## Storage $global:StorageName = ($ResourcePrefix + "Storage").ToLower() $global:StorageType = "Standard_GRS" $global:ContainerName = "vhds" if(!$DryRun) { $global:StorageAccount = New-AzureRmStorageAccount -ResourceGroupName $ResourceGroupName -Name $StorageName -Type $StorageType -Location $Location } else { $global:StorageAccount = Get-AzureRmStorageAccount -ResourceGroupName $ResourceGroupName -Name $StorageName } Write-Host "Created StorageAccount successfully: $StorageName" ## Network $global:PublicIpName = 
$ResourcePrefix + "PublicIp" $global:InterfaceName = $ResourcePrefix + "NetworkInterface" $global:SubnetName = $ResourcePrefix + "Subnet" $global:VNetName = $ResourcePrefix + "VNet" $global:VNetAddressPrefix = "10.0.0.0/16" $global:VNetSubnetAddressPrefix = "10.0.0.0/24" $global:DomainNameLabel = ($ResourcePrefix + "VM").ToLower() if(!$DryRun) { $global:PublicIp = New-AzureRmPublicIpAddress -Name $PublicIpName -ResourceGroupName $ResourceGroupName -Location $Location -AllocationMethod Dynamic -DomainNameLabel $DomainNameLabel } else { $global:PublicIp = Get-AzureRmPublicIpAddress -Name $PublicIpName -ResourceGroupName $ResourceGroupName } Write-Host "Created PublicIp successfully: " $PublicIp.DnsSettings.Fqdn.ToString() if(!$DryRun) { $global:SubnetConfig = New-AzureRmVirtualNetworkSubnetConfig -Name $SubnetName -AddressPrefix $VNetSubnetAddressPrefix } Write-Host "Created SubnetConfig successfully: $SubnetName" if(!$DryRun) { $global:VNet = New-AzureRmVirtualNetwork -Name $VNetName -ResourceGroupName $ResourceGroupName -Location $Location -AddressPrefix $VNetAddressPrefix -Subnet $SubnetConfig } else { $global:VNet = Get-AzureRmVirtualNetwork -Name $VNetName -ResourceGroupName $ResourceGroupName $global:SubnetConfig = Get-AzureRmVirtualNetworkSubnetConfig -Name $SubnetName -VirtualNetwork $VNet } Write-Host "Created AzureRmVirtualNetwork successfully: $VNetName" if(!$DryRun) { $global:Interface = New-AzureRmNetworkInterface -Name $InterfaceName -ResourceGroupName $ResourceGroupName -Location $Location -SubnetId $VNet.Subnets[0].Id -PublicIpAddressId $PublicIp.Id } else { $global:Interface = Get-AzureRmNetworkInterface -Name $InterfaceName -ResourceGroupName $ResourceGroupName } Write-Host "Created AzureNetworkInterface successfully: $InterfaceName" ## Compute $global:VMName = $ResourcePrefix + "VM" $global:ComputerName = $ResourcePrefix + "VM" $global:OSDiskName = $VMName + "OsDisk" $global:OSDiskUri = $StorageAccount.PrimaryEndpoints.Blob.ToString() + "vhds/" + 
$OSDiskName + ".vhd" $global:DataDisk1Name = $VMName + "DataDisk1" $global:DataDisk1Uri = $StorageAccount.PrimaryEndpoints.Blob.ToString() + "vhds/" + $DataDisk1Name + ".vhd" $global:DataDisk2Name = $VMName + "DataDisk2" $global:DataDisk2Uri = $StorageAccount.PrimaryEndpoints.Blob.ToString() + "vhds/" + $DataDisk2Name + ".vhd" ## Setup local VM object $SecString = ($Password | ConvertTo-SecureString -AsPlainText -Force) $Credential = New-Object -TypeName System.Management.Automation.PSCredential -ArgumentList @($Username, $SecString) Write-Host "Created credentials successfully" $global:VirtualMachine = New-AzureRmVMConfig -VMName $VMName -VMSize $VMSize Write-Host "Created AzureRmVMConfig successfully" $VirtualMachine = Set-AzureRmVMOperatingSystem -VM $VirtualMachine -Linux -ComputerName $ComputerName -Credential $Credential Write-Host "Set AzureRmVMOperatingSystem successfully" $PublisherName = $GalleryImage.Split(":")[0] $Offer = $GalleryImage.Split(":")[1] $Skus = $GalleryImage.Split(":")[2] Write-Host "PublisherName: $PublisherName, Offer: $Offer, Skus: $Skus" $VirtualMachine = Set-AzureRmVMSourceImage -VM $VirtualMachine -PublisherName $PublisherName -Offer $Offer -Skus $Skus -Version "latest" Write-Host "Set AzureVMSourceImage successfully" $VirtualMachine = Add-AzureRmVMNetworkInterface -VM $VirtualMachine -Id $Interface.Id Write-Host "Added AzureVMNetworkInterface successfully" $VirtualMachine = Set-AzureRmVMOSDisk -VM $VirtualMachine -Name $OSDiskName -VhdUri $OSDiskUri -CreateOption FromImage Write-Host "Created AzureVMOSDisk successfully" if ($SshPubKey) { $VirtualMachine = Add-AzureRmVMSshPublicKey -VM $VirtualMachine -KeyData $SshPubKey -Path ("/home/" + $Username + "/.ssh/authorized_keys") Write-Host "Added SSH public key successfully" } ## Create the VM in Azure if(!$DryRun) { New-AzureRmVM -ResourceGroupName $ResourceGroupName -Location $Location -VM $VirtualMachine } Write-Host "Created AzureVM successfully: $VMName" $VirtualMachine = 
Get-AzureRmVM -ResourceGroupName $ResourceGroupName -Name $VMName Write-Host "Fetched VM successfully" if(!$DryRun) { Add-AzureRmVMDataDisk -VM $VirtualMachine -Name $DataDisk1Name -Caching None -DiskSizeInGB 1 -Lun 0 -VhdUri $DataDisk1Uri -CreateOption Empty Add-AzureRmVMDataDisk -VM $VirtualMachine -Name $DataDisk2Name -Caching None -DiskSizeInGB 1 -Lun 1 -VhdUri $DataDisk2Uri -CreateOption Empty } Write-Host "Added DataDisks successfully: $DataDisk1Name, $DataDisk2Name" if(!$DryRun) { Update-AzureRmVM -ResourceGroupName $ResourceGroupName -VM $VirtualMachine } Write-Host "Updated VM successfully" ## SSH preparation $global:Hostname = $PublicIp.DnsSettings.Fqdn.ToString() if ($SshPrivKeyPath -and !$DryRun) { $commandFileName = $ResourcePrefix + "Commands.txt" $commands = @" sudo mkdir /root/.ssh sudo cp .ssh/authorized_keys /root/.ssh/ sudo chmod 700 /root/.ssh sudo chmod 600 /root/.ssh/authorized_keys sudo restorecon -R -v /root/.ssh sudo echo "PermitRootLogin yes" >>/etc/ssh/sshd_config sudo service sshd restart exit "@ $commands | Out-File -Encoding ascii $commandFileName dos2unix $commandFileName cmd /c "ssh -tt -o UserKnownHostsFile=C:\Windows\System32\NUL -o StrictHostKeyChecking=no -i $SshPrivKeyPath ${Username}@${Hostname} <$commandFileName" Remove-Item $commandFileName Write-Host "Copied SSH public key for root" $commands = @" (cat <> /root/.bashrc parted /dev/sdc mklabel msdos mkpart pri ext2 0% 100% quit parted /dev/sdd mklabel msdos mkpart pri ext2 0% 100% quit mkfs.ext4 /dev/sdc1 mkfs.ext4 /dev/sdd1 UUID1="`$(blkid -s UUID -o value /dev/sdc1)" UUID2="`$(blkid -s UUID -o value /dev/sdd1)" echo "UUID=`$UUID1 /data1 ext4 defaults 0 0" >>/etc/fstab echo "UUID=`$UUID2 /data2 ext4 defaults 0 0" >>/etc/fstab mkdir /data1 mkdir /data2 mount -a exit "@ $commands | Out-File -Encoding ascii $commandFileName dos2unix $commandFileName cmd /c "ssh -o UserKnownHostsFile=C:\Windows\System32\NUL -o StrictHostKeyChecking=no -i $SshPrivKeyPath root@${Hostname} 
<$commandFileName" Remove-Item $commandFileName Write-Host "Mounted data partitions" $commands = @" sed -i 's/SELINUX=.*/SELINUX=disabled/g' /etc/selinux/config reboot "@ $commands | Out-File -Encoding ascii $commandFileName dos2unix $commandFileName cmd /c "ssh -o UserKnownHostsFile=C:\Windows\System32\NUL -o StrictHostKeyChecking=no -i $SshPrivKeyPath root@${Hostname} <$commandFileName" Remove-Item $commandFileName Start-Sleep 5 $vmRunning = $false while(!$vmRunning) { try { $tcpClient = New-Object System.Net.Sockets.TcpClient $tcpClient.Connect($Hostname, "22") $vmRunning = $true } catch { Write-Host "VM is not up yet" } } Write-Host "SELinux disabled" } ## Encryption if(!$DryRun) { $global:EncryptionEnableOutput = Set-AzureRmVMDiskEncryptionExtension ` -ExtensionName $ExtensionName ` -ResourceGroupName $ResourceGroupName ` -VMName $VMName ` -AadClientID $AadClientId ` -AadClientSecret $AadClientSecret ` -DiskEncryptionKeyVaultId $KeyVault.ResourceId ` -DiskEncryptionKeyVaultUrl $KeyVault.VaultUri ` -KeyEncryptionKeyVaultId $KeyVault.ResourceId ` -KeyEncryptionKeyURL $KeyEncryptionKey.Id ` -KeyEncryptionAlgorithm "RSA-OAEP" ` -VolumeType $VolumeType ` -SequenceVersion "1" ` -Force:$Force 3>&1 | Out-String Write-Host "Set AzureRmVMDiskEncryptionExtension successfully" $global:BackupTag = [regex]::match($EncryptionEnableOutput, '(AzureEnc.*?),').Groups[1].Value } ================================================ FILE: VMEncryption/Test-AzureRmVMDiskEncryptionExtensionDiskFormat.ps1 ================================================ Param( [Parameter(Mandatory=$true)] [string] $SubscriptionId, [Parameter(Mandatory=$true)] [string] $AadClientId, [Parameter(Mandatory=$true)] [string] $AadClientSecret, [Parameter(Mandatory=$true)] [string] $ResourcePrefix, [Parameter(Mandatory=$true)] [string] $Username, [Parameter(Mandatory=$true)] [string] $Password, [string] $ExtensionName="AzureDiskEncryptionForLinux", [string] $SshPubKey, [string] $SshPrivKeyPath, [string] 
$Location="eastus", [string] $VolumeType="data", [string] $GalleryImage="RedHat:RHEL:7.2", [string] $VMSize="Standard_D2" ) $ErrorActionPreference = "Stop" Set-AzureRmContext -SubscriptionId $SubscriptionId Write-Host "Set AzureRmContext successfully" ## Resource Group $global:ResourceGroupName = $ResourcePrefix + "ResourceGroup" New-AzureRmResourceGroup -Name $ResourceGroupName -Location $Location Write-Host "Created ResourceGroup successfully: $ResourceGroupName" ## KeyVault $global:KeyVaultName = $ResourcePrefix + "KeyVault" $global:KeyVault = New-AzureRmKeyVault -VaultName $KeyVaultName -ResourceGroupName $ResourceGroupName -Location $Location Write-Host "Created KeyVault successfully: $KeyVaultName" Set-AzureRmKeyVaultAccessPolicy -VaultName $KeyVaultName -ResourceGroupName $ResourceGroupName -ServicePrincipalName $AadClientId -PermissionsToKeys all -PermissionsToSecrets all Set-AzureRmKeyVaultAccessPolicy -VaultName $KeyVaultName -ResourceGroupName $ResourceGroupName -EnabledForDiskEncryption Write-Host "Set AzureRmKeyVaultAccessPolicy successfully" Add-AzureKeyVaultKey -VaultName $KeyVaultName -Name "keyencryptionkey" -Destination Software Write-Host "Added AzureRmKeyVaultKey successfully" $global:KeyEncryptionKey = Get-AzureKeyVaultKey -VaultName $KeyVault.OriginalVault.Name -Name "keyencryptionkey" Write-Host "Fetched KeyEncryptionKey successfully" ## Storage $global:StorageName = ($ResourcePrefix + "Storage").ToLower() $global:StorageType = "Standard_GRS" $global:ContainerName = "vhds" $global:StorageAccount = New-AzureRmStorageAccount -ResourceGroupName $ResourceGroupName -Name $StorageName -Type $StorageType -Location $Location Write-Host "Created StorageAccount successfully: $StorageName" ## Network $global:PublicIpName = $ResourcePrefix + "PublicIp" $global:InterfaceName = $ResourcePrefix + "NetworkInterface" $global:SubnetName = $ResourcePrefix + "Subnet" $global:VNetName = $ResourcePrefix + "VNet" $global:VNetAddressPrefix = "10.0.0.0/16" 
$global:VNetSubnetAddressPrefix = "10.0.0.0/24" $global:DomainNameLabel = ($ResourcePrefix + "VM").ToLower() $global:PublicIp = New-AzureRmPublicIpAddress -Name $PublicIpName -ResourceGroupName $ResourceGroupName -Location $Location -AllocationMethod Dynamic -DomainNameLabel $DomainNameLabel Write-Host "Created PublicIp successfully: " $PublicIp.DnsSettings.Fqdn.ToString() $global:SubnetConfig = New-AzureRmVirtualNetworkSubnetConfig -Name $SubnetName -AddressPrefix $VNetSubnetAddressPrefix Write-Host "Created SubnetConfig successfully: $SubnetName" $global:VNet = New-AzureRmVirtualNetwork -Name $VNetName -ResourceGroupName $ResourceGroupName -Location $Location -AddressPrefix $VNetAddressPrefix -Subnet $SubnetConfig Write-Host "Created AzureRmVirtualNetwork successfully: $VNetName" $global:Interface = New-AzureRmNetworkInterface -Name $InterfaceName -ResourceGroupName $ResourceGroupName -Location $Location -SubnetId $VNet.Subnets[0].Id -PublicIpAddressId $PublicIp.Id Write-Host "Created AzureNetworkInterface successfully: $InterfaceName" ## Compute $global:VMName = $ResourcePrefix + "VM" $global:ComputerName = $ResourcePrefix + "VM" $global:OSDiskName = $VMName + "OsDisk" $global:OSDiskUri = $StorageAccount.PrimaryEndpoints.Blob.ToString() + "vhds/" + $OSDiskName + ".vhd" $global:DataDisk1Name = $VMName + "DataDisk1" $global:DataDisk1Uri = $StorageAccount.PrimaryEndpoints.Blob.ToString() + "vhds/" + $DataDisk1Name + ".vhd" $global:DataDisk2Name = $VMName + "DataDisk2" $global:DataDisk2Uri = $StorageAccount.PrimaryEndpoints.Blob.ToString() + "vhds/" + $DataDisk2Name + ".vhd" ## Setup local VM object $SecString = ($Password | ConvertTo-SecureString -AsPlainText -Force) $Credential = New-Object -TypeName System.Management.Automation.PSCredential -ArgumentList @($Username, $SecString) Write-Host "Created credentials successfully" $global:VirtualMachine = New-AzureRmVMConfig -VMName $VMName -VMSize $VMSize Write-Host "Created AzureRmVMConfig successfully" 
$VirtualMachine = Set-AzureRmVMOperatingSystem -VM $VirtualMachine -Linux -ComputerName $ComputerName -Credential $Credential Write-Host "Set AzureRmVMOperatingSystem successfully" $PublisherName = $GalleryImage.Split(":")[0] $Offer = $GalleryImage.Split(":")[1] $Skus = $GalleryImage.Split(":")[2] Write-Host "PublisherName: $PublisherName, Offer: $Offer, Skus: $Skus" $VirtualMachine = Set-AzureRmVMSourceImage -VM $VirtualMachine -PublisherName $PublisherName -Offer $Offer -Skus $Skus -Version "latest" Write-Host "Set AzureVMSourceImage successfully" $VirtualMachine = Add-AzureRmVMNetworkInterface -VM $VirtualMachine -Id $Interface.Id Write-Host "Added AzureVMNetworkInterface successfully" $VirtualMachine = Set-AzureRmVMOSDisk -VM $VirtualMachine -Name $OSDiskName -VhdUri $OSDiskUri -CreateOption FromImage Write-Host "Created AzureVMOSDisk successfully" if ($SshPubKey) { $VirtualMachine = Add-AzureRmVMSshPublicKey -VM $VirtualMachine -KeyData $SshPubKey -Path ("/home/" + $Username + "/.ssh/authorized_keys") Write-Host "Added SSH public key successfully" } ## Create the VM in Azure New-AzureRmVM -ResourceGroupName $ResourceGroupName -Location $Location -VM $VirtualMachine Write-Host "Created AzureVM successfully: $VMName" $VirtualMachine = Get-AzureRmVM -ResourceGroupName $ResourceGroupName -Name $VMName Write-Host "Fetched VM successfully" Add-AzureRmVMDataDisk -VM $VirtualMachine -Name $DataDisk1Name -Caching None -DiskSizeInGB 10 -Lun 0 -VhdUri $DataDisk1Uri -CreateOption Empty Add-AzureRmVMDataDisk -VM $VirtualMachine -Name $DataDisk2Name -Caching None -DiskSizeInGB 10 -Lun 1 -VhdUri $DataDisk2Uri -CreateOption Empty Write-Host "Added DataDisks successfully: $DataDisk1Name, $DataDisk2Name" Update-AzureRmVM -ResourceGroupName $ResourceGroupName -VM $VirtualMachine Write-Host "Updated VM successfully" ## SSH preparation if ($SshPrivKeyPath) { $global:Hostname = $PublicIp.DnsSettings.Fqdn.ToString() $commandFileName = $ResourcePrefix + "Commands.txt" $commands = @" 
sudo mkdir /root/.ssh sudo cp .ssh/authorized_keys /root/.ssh/ sudo chmod 700 /root/.ssh sudo chmod 600 /root/.ssh/authorized_keys sudo restorecon -R -v /root/.ssh sudo echo "PermitRootLogin yes" >>/etc/ssh/sshd_config sudo service sshd restart exit "@ $commands | Out-File -Encoding ascii $commandFileName dos2unix $commandFileName cmd /c "ssh -tt -o UserKnownHostsFile=C:\Windows\System32\NUL -o StrictHostKeyChecking=no -i $SshPrivKeyPath ${Username}@${Hostname} <$commandFileName" Remove-Item $commandFileName Write-Host "Copied SSH public key for root" $commands = @" (cat <> /root/.bashrc apt-get install -yq mdadm yum install -y mdadm exit "@ $commands | Out-File -Encoding ascii $commandFileName dos2unix $commandFileName cmd /c "ssh -o UserKnownHostsFile=C:\Windows\System32\NUL -o StrictHostKeyChecking=no -i $SshPrivKeyPath root@${Hostname} <$commandFileName" Remove-Item $commandFileName Write-Host "Installed mdadm" $commands = @" mdadm --create --verbose /dev/md0 --level=0 --raid-devices=2 /dev/sdc /dev/sdd mkdir -p /etc/mdadm mdadm --detail --scan > /etc/mdadm/mdadm.conf exit "@ $commands | Out-File -Encoding ascii $commandFileName dos2unix $commandFileName cmd /c "ssh -o UserKnownHostsFile=C:\Windows\System32\NUL -o StrictHostKeyChecking=no -i $SshPrivKeyPath root@${Hostname} <$commandFileName" Remove-Item $commandFileName Write-Host "Created RAID array" $commands = @" sed -i 's/SELINUX=.*/SELINUX=disabled/g' /etc/selinux/config reboot "@ $commands | Out-File -Encoding ascii $commandFileName dos2unix $commandFileName cmd /c "ssh -o UserKnownHostsFile=C:\Windows\System32\NUL -o StrictHostKeyChecking=no -i $SshPrivKeyPath root@${Hostname} <$commandFileName" Remove-Item $commandFileName Start-Sleep 5 $vmRunning = $false while(!$vmRunning) { try { $tcpClient = New-Object System.Net.Sockets.TcpClient $tcpClient.Connect($Hostname, "22") $vmRunning = $true } catch { Write-Host "VM is not up yet" } } Write-Host "SELinux disabled" $commands = @" lsblk exit "@ $commands | 
Out-File -Encoding ascii $commandFileName dos2unix $commandFileName $stdout = cmd /c "ssh -o UserKnownHostsFile=C:\Windows\System32\NUL -o StrictHostKeyChecking=no -i $SshPrivKeyPath root@${Hostname} <$commandFileName" Remove-Item $commandFileName $global:RaidBlockDevice = "/dev/" + [regex]::Match($stdout, '(md\d+)').Captures.Groups[0].Value Write-Host "Encrypting RAID device: $RaidBlockDevice" } ## Encryption Read-Host "Press Enter to continue..." $global:Settings = @{ "AADClientID" = $AadClientId; "DiskFormatQuery" = "[{`"dev_path`":`"$RaidBlockDevice`",`"file_system`":`"ext4`",`"name`":`"encryptedraid`"}]"; "EncryptionOperation" = "EnableEncryptionFormat"; "KeyEncryptionAlgorithm" = "RSA-OAEP"; "KeyEncryptionKeyURL" = $KeyEncryptionKey.Id; "KeyVaultURL" = $KeyVault.VaultUri; "SequenceVersion" = "1"; "VolumeType" = $VolumeType; } $global:ProtectedSettings = @{ "AADClientSecret" = $AadClientSecret; } Set-AzureRmVMExtension ` -ResourceGroupName $ResourceGroupName ` -Location $Location ` -VMName $VMName ` -Name $ExtensionName ` -Publisher "Microsoft.Azure.Security" ` -Type "AzureDiskEncryptionForLinux" ` -TypeHandlerVersion "0.1" ` -Settings $Settings ` -ProtectedSettings $ProtectedSettings Write-Host "Set AzureRmVMExtension successfully" $VirtualMachine = Get-AzureRmVM -ResourceGroupName $ResourceGroupName -Name $VMName $global:InstanceView = Get-AzureRmVM -ResourceGroupName $ResourceGroupName -Name $VMName -Status $KVSecretRef = New-Object Microsoft.Azure.Management.Compute.Models.KeyVaultSecretReference -ArgumentList @($InstanceView.Extensions[0].Statuses[0].Message, $KeyVault.ResourceId) $KVKeyRef = New-Object Microsoft.Azure.Management.Compute.Models.KeyVaultKeyReference -ArgumentList @($KeyEncryptionKey.Id, $KeyVault.ResourceId) $VirtualMachine.StorageProfile.OsDisk.EncryptionSettings = New-Object Microsoft.Azure.Management.Compute.Models.DiskEncryptionSettings -ArgumentList @($KVSecretRef, $KVKeyRef, $true) Update-AzureRmVM -ResourceGroupName 
$ResourceGroupName -VM $VirtualMachine Write-Host "Updated VM successfully" ================================================ FILE: VMEncryption/VMEncryption.pyproj ================================================  Debug 2.0 334deedb-1c9a-40c8-89f2-a4ae042c18aa . . . VMEncryption VMEncryption true false true false 10.0 $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\Python Tools\Microsoft.PythonTools.targets Code Code Code Code Code Code Code Code Code Code Code ================================================ FILE: VMEncryption/extension_shim.sh ================================================ #!/usr/bin/env bash # Keeping the default command COMMAND="" PYTHON="" USAGE="$(basename "$0") [-h] [-i|--install] [-u|--uninstall] [-d|--disable] [-e|--enable] [-p|--update] [-m|--daemon] Program to find the installed python on the box and invoke a Python extension script using Python 2.7. where: -h|--help show this help text -i|--install install the extension -u|--uninstall uninstall the extension -d|--disable disable the extension -e|--enable enable the extension -p|--update update the extension -m|--daemon invoke daemon option -c|--command command to run example: # Install usage $ bash extension_shim.sh -i python ./main/handle.py -install # Custom executable python file $ bash extension_shim.sh -c ""hello.py"" -i python hello.py -install # Custom executable python file with arguments $ bash extension_shim.sh -c ""hello.py --install"" python hello.py --install " function find_python(){ local python_exec_command=$1 # Check if there is python defined. 
if command -v python >/dev/null 2>&1 ; then eval ${python_exec_command}="python" fi } # Transform long options to short ones for getopts support (getopts doesn't support long args) for arg in "$@"; do shift case "$arg" in "--help") set -- "$@" "-h" ;; "--install") set -- "$@" "-i" ;; "--update") set -- "$@" "-p" ;; "--enable") set -- "$@" "-e" ;; "--disable") set -- "$@" "-d" ;; "--uninstall") set -- "$@" "-u" ;; "--daemon") set -- "$@" "-m" ;; *) set -- "$@" "$arg" esac done if [ -z "$arg" ] then echo "$USAGE" >&2 exit 1 fi # Get the arguments while getopts "iudephc:?" o; do case "${o}" in h|\?) echo "$USAGE" exit 0 ;; i) operation="-install" ;; u) operation="-uninstall" ;; d) operation="-disable" ;; e) operation="-enable" ;; p) operation="-update" ;; m) operation="-daemon" ;; c) COMMAND="$OPTARG" ;; *) echo "$USAGE" >&2 exit 1 ;; esac done shift "$((OPTIND-1))" # If find_python is not able to find a python installed, $PYTHON will be null. find_python PYTHON if [ -z "$PYTHON" ]; then echo "No Python interpreter found on the box" >&2 exit 51 # Not Supported else PYTHON_VER=`${PYTHON} --version 2>&1` if [[ "$PYTHON_VER" =~ "Python 2.6" ]] || [[ "$PYTHON_VER" =~ "Python 2.7" ]]; then echo $PYTHON_VER else echo "Expected Python 2.7, found $PYTHON_VER" >&2 exit 51 # Not Supported fi fi ${PYTHON} ${COMMAND} ${operation} 2>&1 # DONE ================================================ FILE: VMEncryption/main/BackupLogger.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import time import datetime import traceback import urlparse import httplib import os import string class BackupLogger(object): def __init__(self, hutil): self.hutil = hutil self.current_process_id = os.getpid() """description of class""" def log(self, msg, level='Info'): log_msg = "{0}: [{1}] {2}".format(self.current_process_id, level, msg) log_msg = filter(lambda c: c in string.printable, log_msg) log_msg = log_msg.encode('ascii', 'ignore') self.hutil.log(log_msg) self.log_to_console(log_msg) def log_to_console(self, msg): try: with open('/dev/console', 'w') as f: msg = filter(lambda c: c in string.printable, msg) f.write('[AzureDiskEncryption] ' + msg + '\n') except IOError as e: pass ================================================ FILE: VMEncryption/main/BekUtil.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from Common import TestHooks import base64 import os.path """ add retry-logic to the network api call. 
""" class BekUtil(object): """ Utility functions related to the BEK VOLUME and BEK files """ def __init__(self, disk_util, logger): self.disk_util = disk_util self.logger = logger self.bek_filesystem_mount_point = '/mnt/azure_bek_disk' def generate_passphrase(self, algorithm): if TestHooks.use_hard_code_passphrase: return TestHooks.hard_code_passphrase else: with open("/dev/urandom", "rb") as _random_source: bytes = _random_source.read(127) passphrase_generated = base64.b64encode(bytes) return passphrase_generated def get_bek_passphrase_file(self, encryption_config): """ Returns the LinuxPassPhraseFileName path """ bek_filename = encryption_config.get_bek_filename() try: self.disk_util.make_sure_path_exists(self.bek_filesystem_mount_point) self.disk_util.mount_bek_volume("BEK VOLUME", self.bek_filesystem_mount_point, "fmask=077") if os.path.exists(os.path.join(self.bek_filesystem_mount_point, bek_filename)): return os.path.join(self.bek_filesystem_mount_point, bek_filename) except Exception as e: message = "Failed to get BEK from BEK VOLUME with error: {0}".format(str(e)) self.logger.log(message) return None def umount_azure_passhprase(self, encryption_config, force=False): passphrase_file = self.get_bek_passphrase_file(encryption_config) if force or (passphrase_file and os.path.exists(passphrase_file)): self.disk_util.umount(self.bek_filesystem_mount_point) ================================================ FILE: VMEncryption/main/CommandExecutor.py ================================================ #!/usr/bin/env python # # VMEncryption extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Requires Python 2.7+ # import os import os.path import shlex import sys from subprocess import * from threading import Timer class ProcessCommunicator(object): def __init__(self): self.stdout = None self.stderr = None class CommandExecutor(object): """description of class""" def __init__(self, logger): self.logger = logger def Execute(self, command_to_execute, raise_exception_on_failure=False, communicator=None, input=None, suppress_logging=False, timeout=0): if type(command_to_execute) == unicode: command_to_execute = command_to_execute.encode('ascii', 'ignore') if not suppress_logging: self.logger.log("Executing: {0}".format(command_to_execute)) args = shlex.split(command_to_execute) proc = None timer = None return_code = None try: proc = Popen(args, stdout=PIPE, stderr=PIPE, stdin=PIPE, close_fds=True) except Exception as e: if raise_exception_on_failure: raise else: if not suppress_logging: self.logger.log("Process creation failed: " + str(e)) return -1 def timeout_process(): proc.kill() self.logger.log("Command {0} didn't finish in {1} seconds. 
Timing it out".format(command_to_execute, timeout)) try: if timeout>0: timer = Timer(timeout, timeout_process) timer.start() stdout, stderr = proc.communicate(input=input) finally: if timer is not None: timer.cancel() return_code = proc.returncode if isinstance(communicator, ProcessCommunicator): communicator.stdout, communicator.stderr = stdout, stderr if int(return_code) != 0: msg = "Command {0} failed with return code {1}".format(command_to_execute, return_code) msg += "\nstdout:\n" + stdout msg += "\nstderr:\n" + stderr if not suppress_logging: self.logger.log(msg) if raise_exception_on_failure: raise Exception(msg) return return_code def ExecuteInBash(self, command_to_execute, raise_exception_on_failure=False, communicator=None, input=None, suppress_logging=False): command_to_execute = 'bash -c "{0}{1}"'.format('set -e; ' if raise_exception_on_failure else '', command_to_execute) return self.Execute(command_to_execute, raise_exception_on_failure, communicator, input, suppress_logging) ================================================ FILE: VMEncryption/main/Common.py ================================================ #!/usr/bin/env python # # Azure Disk Encryption For Linux Extension # # Copyright 2019 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class CommonVariables: utils_path_name = 'Utils' extension_name = 'AzureDiskEncryptionForLinux' extension_version = '0.1.0.999345' extension_type = extension_name extension_media_link = 'https://amextpaas.blob.core.windows.net/prod/' + extension_name + '-' + str(extension_version) + '.zip' extension_label = 'Windows Azure VMEncryption Extension for Linux IaaS' extension_description = extension_label extension_shim_filename = "extension_shim.sh" """ disk/file system related """ sector_size = 512 luks_header_size = 4096 * 512 default_block_size = 52428800 min_filesystem_size_support = 52428800 * 3 #TODO for the sles 11, we should use the ext3 default_file_system = 'ext4' format_supported_file_systems = ['ext4', 'ext3', 'ext2', 'xfs', 'btrfs'] inplace_supported_file_systems = ['ext4', 'ext3', 'ext2'] default_mount_name = 'encrypted_disk' dev_mapper_root = '/dev/mapper/' osmapper_name = 'osencrypt' azure_symlinks_dir = '/dev/disk/azure' disk_by_id_root = '/dev/disk/by-id' disk_by_uuid_root = '/dev/disk/by-uuid' encryption_key_mount_point = '/mnt/azure_bek_disk/' bek_fstab_line_template = 'LABEL=BEK\\040VOLUME {0} auto defaults,discard,nofail 0 0\n' bek_fstab_line_template_ubuntu_14 = 'LABEL=BEK\\040VOLUME {0} auto defaults,discard,nobootwait 0 0\n' etc_defaults_cryptdisks_line = '\nCRYPTDISKS_MOUNT="$CRYPTDISKS_MOUNT {0}"\n' """ parameter key names """ PassphraseFileNameKey = 'BekFileName' KeyEncryptionKeyURLKey = 'KeyEncryptionKeyURL' KeyVaultURLKey = 'KeyVaultURL' AADClientIDKey = 'AADClientID' AADClientCertThumbprintKey = 'AADClientCertThumbprint' KeyEncryptionAlgorithmKey = 'KeyEncryptionAlgorithm' encryption_algorithms = ['RSA-OAEP', 'RSA-OAEP-256', 'RSA1_5'] default_encryption_algorithm = 'RSA-OAEP' DiskFormatQuerykey = "DiskFormatQuery" PassphraseKey = 'Passphrase' """ value for VolumeType could be OS or Data """ VolumeTypeKey = 'VolumeType' AADClientSecretKey = 'AADClientSecret' SecretUriKey = 'SecretUri' SecretSeqNum = 'SecretSeqNum' VolumeTypeOS = 'OS' 
VolumeTypeData = 'Data' VolumeTypeAll = 'All' SupportedVolumeTypes = [ VolumeTypeOS, VolumeTypeData, VolumeTypeAll ] """ command types """ EnableEncryption = 'EnableEncryption' EnableEncryptionFormat = 'EnableEncryptionFormat' EnableEncryptionFormatAll = 'EnableEncryptionFormatAll' UpdateEncryptionSettings = 'UpdateEncryptionSettings' DisableEncryption = 'DisableEncryption' QueryEncryptionStatus = 'QueryEncryptionStatus' """ encryption config keys """ EncryptionEncryptionOperationKey = 'EncryptionOperation' EncryptionDecryptionOperationKey = 'DecryptionOperation' EncryptionVolumeTypeKey = 'VolumeType' EncryptionDiskFormatQueryKey = 'DiskFormatQuery' """ crypt ongoing item config keys """ OngoingItemMapperNameKey = 'MapperName' OngoingItemHeaderFilePathKey = 'HeaderFilePath' OngoingItemOriginalDevNamePathKey = 'DevNamePath' OngoingItemOriginalDevPathKey = 'DevicePath' OngoingItemPhaseKey = 'Phase' OngoingItemHeaderSliceFilePathKey = 'HeaderSliceFilePath' OngoingItemFileSystemKey = 'FileSystem' OngoingItemMountPointKey = 'MountPoint' OngoingItemDeviceSizeKey = 'Size' OngoingItemCurrentSliceIndexKey = 'CurrentSliceIndex' OngoingItemFromEndKey = 'FromEnd' OngoingItemCurrentDestinationKey = 'CurrentDestination' OngoingItemCurrentTotalCopySizeKey = 'CurrentTotalCopySize' OngoingItemCurrentLuksHeaderFilePathKey = 'CurrentLuksHeaderFilePath' OngoingItemCurrentSourcePathKey = 'CurrentSourcePath' OngoingItemCurrentBlockSizeKey = 'CurrentBlockSize' """ encryption phase devinitions """ EncryptionPhaseBackupHeader = 'BackupHeader' EncryptionPhaseCopyData = 'EncryptingData' EncryptionPhaseRecoverHeader = 'RecoverHeader' EncryptionPhaseEncryptDevice = 'EncryptDevice' EncryptionPhaseDone = 'Done' """ decryption phase constants """ DecryptionPhaseCopyData = 'DecryptingData' DecryptionPhaseDone = 'Done' """ logs related """ InfoLevel = 'Info' WarningLevel = 'Warning' ErrorLevel = 'Error' """ error codes """ extension_success_status = 'success' extension_error_status = 'error' 
process_success = 0 success = 0 os_not_supported = 51 missing_dependency = 52 configuration_error = 53 luks_format_error = 2 scsi_number_not_found = 3 device_not_blank = 4 environment_error = 5 luks_open_error = 6 mkfs_error = 7 folder_conflict_error = 8 mount_error = 9 mount_point_not_exists = 10 passphrase_too_long_or_none = 11 parameter_error = 12 create_encryption_secret_failed = 13 encrypttion_already_enabled = 14 passphrase_file_not_found = 15 command_not_support = 16 volue_type_not_support = 17 copy_data_error = 18 encryption_failed = 19 tmpfs_error = 20 backup_slice_file_error = 21 unmount_oldroot_error = 22 operation_lookback_failed = 23 unknown_error = 100 class TestHooks: search_not_only_ide = False use_hard_code_passphrase = False hard_code_passphrase = "Quattro!" class DeviceItem(object): def __init__(self): #NAME,TYPE,FSTYPE,MOUNTPOINT,LABEL,UUID,MODEL,SIZE,MAJ:MIN self.name = None self.type = None self.file_system = None self.mount_point = None self.label = None self.uuid = None self.model = None self.size = None self.majmin = None self.device_id = None self.azure_name = None def __str__(self): return ("name:" + str(self.name) + " type:" + str(self.type) + " fstype:" + str(self.file_system) + " mountpoint:" + str(self.mount_point) + " label:" + str(self.label) + " model:" + str(self.model) + " size:" + str(self.size) + " majmin:" + str(self.majmin) + " device_id:" + str(self.device_id)) + " azure_name:" + str(self.azure_name) class LvmItem(object): def __init__(self): #lv_name,vg_name,lv_kernel_major,lv_kernel_minor self.lv_name = None self.vg_name = None self.lv_kernel_major = None self.lv_kernel_minor = None def __str__(self): return ("lv_name:" + str(self.lv_name) + " vg_name:" + str(self.vg_name) + " lv_kernel_major:" + str(self.lv_kernel_major) + " lv_kernel_minor:" + str(self.lv_kernel_minor)) class CryptItem(object): def __init__(self): self.mapper_name = None self.dev_path = None self.mount_point = None self.file_system = None 
self.luks_header_path = None self.uses_cleartext_key = None self.current_luks_slot = None def __str__(self): return ("name: " + str(self.mapper_name) + " dev_path:" + str(self.dev_path) + " mount_point:" + str(self.mount_point) + " file_system:" + str(self.file_system) + " luks_header_path:" + str(self.luks_header_path) + " uses_cleartext_key:" + str(self.uses_cleartext_key) + " current_luks_slot:" + str(self.current_luks_slot)) def __eq__(self, other): """ Override method for "==" operation, useful for making CryptItem comparison a little logically consistent For example a luks_slot value of "-1" and "None" are logically equivalent, so this method, treats them the same This is done by "consolidating" both values to "None". """ if not isinstance(other, CryptItem): return NotImplemented def _consolidate_luks_header_path(crypt_item): """ if luks_header_path is absent, then it implies that the header is attached so the header path might as well be the device path (dev_path) """ if crypt_item.luks_header_path and not crypt_item.luks_header_path == "None": return crypt_item.luks_header_path return crypt_item.dev_path def _consolidate_luks_slot(crypt_item): """ -1 for luks_slot implies "None" """ if crypt_item.current_luks_slot == -1: return None return crypt_item.current_luks_slot def _consolidate_file_system(crypt_item): """ "None" and "auto" are functionally identical for "file_system" field """ if not crypt_item.file_system or crypt_item.file_system == "None": return "auto" return crypt_item.file_system def _consolidate_cleartext_key(crypt_item): """ "False", "None", "" and None are equivalent to False """ if not crypt_item.uses_cleartext_key or crypt_item.uses_cleartext_key in ["False", "None"]: return False return True return self.mapper_name == other.mapper_name and\ self.dev_path == other.dev_path and\ self.file_system == other.file_system and\ self.mount_point == other.mount_point and\ _consolidate_luks_header_path(self) == _consolidate_luks_header_path(other) 
and \ _consolidate_luks_slot(self) == _consolidate_luks_slot(other) and\ _consolidate_file_system(self) == _consolidate_file_system(other) and\ _consolidate_cleartext_key(self) == _consolidate_cleartext_key(other) ================================================ FILE: VMEncryption/main/ConfigUtil.py ================================================ #!/usr/bin/env python # # VMEncryption extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os.path from Common import * from ConfigParser import * class ConfigKeyValuePair(object): def __init__(self, prop_name, prop_value): self.prop_name = prop_name self.prop_value = prop_value class ConfigUtil(object): def __init__(self, config_file_path, section_name, logger): """ this should not create the config file with path: config_file_path """ self.config_file_path = config_file_path self.logger = logger self.azure_crypt_config_section = section_name def config_file_exists(self): return os.path.exists(self.config_file_path) def save_config(self, prop_name, prop_value): #TODO make the operation an transaction. 
config = ConfigParser() if os.path.exists(self.config_file_path): config.read(self.config_file_path) # read values from a section if not config.has_section(self.azure_crypt_config_section): config.add_section(self.azure_crypt_config_section) config.set(self.azure_crypt_config_section, prop_name, prop_value) with open(self.config_file_path, 'wb') as configfile: config.write(configfile) def save_configs(self, key_value_pairs): config = ConfigParser() if os.path.exists(self.config_file_path): config.read(self.config_file_path) # read values from a section if not config.has_section(self.azure_crypt_config_section): config.add_section(self.azure_crypt_config_section) for key_value_pair in key_value_pairs: if key_value_pair.prop_value is not None: config.set(self.azure_crypt_config_section, key_value_pair.prop_name, key_value_pair.prop_value) with open(self.config_file_path, 'wb') as configfile: config.write(configfile) def get_config(self, prop_name): # write the configs, the bek file name and so on. if os.path.exists(self.config_file_path): try: config = ConfigParser() config.read(self.config_file_path) # read values from a section prop_value = config.get(self.azure_crypt_config_section, prop_name) return prop_value except (NoSectionError, NoOptionError) as e: self.logger.log(msg="value of prop_name:{0} not found.".format(prop_name)) return None else: self.logger.log("the config file {0} not exists.".format(self.config_file_path)) return None ================================================ FILE: VMEncryption/main/DecryptionMarkConfig.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import os.path import traceback from ConfigUtil import * from Common import CommonVariables class DecryptionMarkConfig(object): def __init__(self, logger, encryption_environment): self.logger = logger self.encryption_environment = encryption_environment self.command = None self.volume_type = None self.decryption_mark_config = ConfigUtil(self.encryption_environment.azure_decrypt_request_queue_path, 'decryption_request_queue', self.logger) def get_current_command(self): return self.decryption_mark_config.get_config(CommonVariables.EncryptionEncryptionOperationKey) def config_file_exists(self): return self.decryption_mark_config.config_file_exists() def commit(self): key_value_pairs = [] command = ConfigKeyValuePair(CommonVariables.EncryptionEncryptionOperationKey, self.command) key_value_pairs.append(command) volume_type = ConfigKeyValuePair(CommonVariables.EncryptionVolumeTypeKey, self.volume_type) key_value_pairs.append(volume_type) self.decryption_mark_config.save_configs(key_value_pairs) def clear_config(self): try: if os.path.exists(self.encryption_environment.azure_decrypt_request_queue_path): os.remove(self.encryption_environment.azure_decrypt_request_queue_path) return True except OSError as e: self.logger.log("Failed to clear_queue with error: {0}, stack trace: {1}".format(e, traceback.format_exc())) return False ================================================ FILE: VMEncryption/main/DiskUtil.py ================================================ #!/usr/bin/env python # # VMEncryption extension # # Copyright 2015 Microsoft Corporation # # Licensed 
under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import subprocess import json import os import os.path import re from subprocess import Popen import shutil import traceback import uuid import glob from datetime import datetime from EncryptionConfig import EncryptionConfig from DecryptionMarkConfig import DecryptionMarkConfig from EncryptionMarkConfig import EncryptionMarkConfig from TransactionalCopyTask import TransactionalCopyTask from CommandExecutor import CommandExecutor, ProcessCommunicator from Common import CommonVariables, CryptItem, LvmItem, DeviceItem class DiskUtil(object): os_disk_lvm = None sles_cache = {} device_id_cache = {} def __init__(self, hutil, patching, logger, encryption_environment): self.encryption_environment = encryption_environment self.hutil = hutil self.distro_patcher = patching self.logger = logger self.ide_class_id = "{32412632-86cb-44a2-9b5c-50d1417354f5}" self.vmbus_sys_path = '/sys/bus/vmbus/devices' self.command_executor = CommandExecutor(self.logger) def copy(self, ongoing_item_config, status_prefix=''): copy_task = TransactionalCopyTask(logger=self.logger, disk_util=self, hutil=self.hutil, ongoing_item_config=ongoing_item_config, patching=self.distro_patcher, encryption_environment=self.encryption_environment, status_prefix=status_prefix) try: mem_fs_result = copy_task.prepare_mem_fs() if mem_fs_result != CommonVariables.process_success: return CommonVariables.tmpfs_error else: return copy_task.begin_copy() except Exception as e: message = "Failed to 
perform dd copy: {0}, stack trace: {1}".format(e, traceback.format_exc()) self.logger.log(msg=message, level=CommonVariables.ErrorLevel) finally: copy_task.clear_mem_fs() def format_disk(self, dev_path, file_system): mkfs_command = "" if file_system in CommonVariables.format_supported_file_systems: mkfs_command = "mkfs." + file_system mkfs_cmd = "{0} {1}".format(mkfs_command, dev_path) return self.command_executor.Execute(mkfs_cmd) def make_sure_path_exists(self, path): mkdir_cmd = self.distro_patcher.mkdir_path + ' -p ' + path self.logger.log("make sure path exists, executing: {0}".format(mkdir_cmd)) return self.command_executor.Execute(mkdir_cmd) def touch_file(self, path): mkdir_cmd = self.distro_patcher.touch_path + ' ' + path self.logger.log("touching file, executing: {0}".format(mkdir_cmd)) return self.command_executor.Execute(mkdir_cmd) def parse_crypttab_line(self, line): crypttab_parts = line.strip().split() if len(crypttab_parts) < 3: # Line should have enough content return None if crypttab_parts[0].startswith("#"): # Line should not be a comment return None crypt_item = CryptItem() crypt_item.mapper_name = crypttab_parts[0] crypt_item.dev_path = crypttab_parts[1] keyfile_path = crypttab_parts[2] if CommonVariables.encryption_key_mount_point not in keyfile_path and self.encryption_environment.cleartext_key_base_path not in keyfile_path: return None # if the key_file path doesn't have the encryption key file name, its probably not for us to mess with if self.encryption_environment.cleartext_key_base_path in keyfile_path: crypt_item.uses_cleartext_key = True crypttab_option_string = crypttab_parts[3] crypttab_options = crypttab_option_string.split(',') for option in crypttab_options: option_pair = option.split("=") if len(option_pair) == 2: key = option_pair[0].strip() value = option_pair[1].strip() if key == "header": crypt_item.luks_header_path = value return crypt_item def parse_azure_crypt_mount_line(self, line): crypt_item = CryptItem() 
crypt_mount_item_properties = line.strip().split() crypt_item.mapper_name = crypt_mount_item_properties[0] crypt_item.dev_path = crypt_mount_item_properties[1] crypt_item.luks_header_path = crypt_mount_item_properties[2] if crypt_mount_item_properties[2] and crypt_mount_item_properties[2] != "None" else None crypt_item.mount_point = crypt_mount_item_properties[3] crypt_item.file_system = crypt_mount_item_properties[4] crypt_item.uses_cleartext_key = True if crypt_mount_item_properties[5] == "True" else False crypt_item.current_luks_slot = int(crypt_mount_item_properties[6]) if len(crypt_mount_item_properties) > 6 else -1 return crypt_item def get_crypt_items(self): crypt_items = [] rootfs_crypt_item_found = False if self.should_use_azure_crypt_mount(): with open(self.encryption_environment.azure_crypt_mount_config_path, 'r') as f: for line in f.readlines(): if not line.strip(): continue crypt_item = self.parse_azure_crypt_mount_line(line) if crypt_item.mount_point == "/" or crypt_item.mapper_name == CommonVariables.osmapper_name: rootfs_crypt_item_found = True crypt_items.append(crypt_item) else: self.logger.log("Using crypttab instead of azure_crypt_mount file.") crypttab_path = "/etc/crypttab" fstab_items = [] with open("/etc/fstab", "r") as f: for line in f.readlines(): fstab_device, fstab_mount_point = self.parse_fstab_line(line) if fstab_device is not None: fstab_items.append((fstab_device, fstab_mount_point)) if not os.path.exists(crypttab_path): self.logger.log("{0} does not exist".format(crypttab_path)) else: with open(crypttab_path, 'r') as f: for line in f.readlines(): if not line.strip(): continue crypt_item = self.parse_crypttab_line(line) if crypt_item is None: continue if crypt_item.mapper_name == CommonVariables.osmapper_name: rootfs_crypt_item_found = True for device_path, mount_path in fstab_items: if crypt_item.mapper_name in device_path: crypt_item.mount_point = mount_path crypt_items.append(crypt_item) encryption_status = 
json.loads(self.get_encryption_status()) if encryption_status["os"] == "Encrypted" and not rootfs_crypt_item_found: crypt_item = CryptItem() crypt_item.mapper_name = CommonVariables.osmapper_name proc_comm = ProcessCommunicator() grep_result = self.command_executor.ExecuteInBash("cryptsetup status {0} | grep device:".format(crypt_item.mapper_name), communicator=proc_comm) if grep_result == 0: crypt_item.dev_path = proc_comm.stdout.strip().split()[1] else: proc_comm = ProcessCommunicator() self.command_executor.Execute("dmsetup table --target crypt", communicator=proc_comm) for line in proc_comm.stdout.splitlines(): if crypt_item.mapper_name in line: majmin = filter(lambda p: re.match(r'\d+:\d+', p), line.split())[0] src_device = filter(lambda d: d.majmin == majmin, self.get_device_items(None))[0] crypt_item.dev_path = '/dev/' + src_device.name break rootfs_dev = next((m for m in self.get_mount_items() if m["dest"] == "/")) crypt_item.file_system = rootfs_dev["fs"] if not crypt_item.dev_path: raise Exception("Could not locate block device for rootfs") crypt_item.luks_header_path = "/boot/luks/osluksheader" if not os.path.exists(crypt_item.luks_header_path): crypt_item.luks_header_path = crypt_item.dev_path crypt_item.mount_point = "/" crypt_item.uses_cleartext_key = False crypt_item.current_luks_slot = -1 crypt_items.append(crypt_item) return crypt_items def should_use_azure_crypt_mount(self): if not os.path.exists(self.encryption_environment.azure_crypt_mount_config_path): return False non_os_entry_found = False with open(self.encryption_environment.azure_crypt_mount_config_path, 'r') as f: for line in f.readlines(): if not line.strip(): continue parsed_crypt_item = self.parse_azure_crypt_mount_line(line) if parsed_crypt_item.mapper_name != CommonVariables.osmapper_name: non_os_entry_found = True # if there is a non_os_entry found we should use azure_crypt_mount. 
Otherwise we shouldn't return non_os_entry_found def add_crypt_item(self, crypt_item, key_file_path): if self.should_use_azure_crypt_mount(): return self.add_crypt_item_to_azure_crypt_mount(crypt_item) else: return self.add_crypt_item_to_crypttab(crypt_item, key_file_path) def add_crypt_item_to_crypttab(self, crypt_item, key_file): if key_file is None and crypt_item.uses_cleartext_key: line_key_file = self.encryption_environment.cleartext_key_base_path + crypt_item.mapper_name else: line_key_file = key_file crypttab_line = "\n{0} {1} {2} luks,nofail".format(crypt_item.mapper_name, crypt_item.dev_path, line_key_file) if crypt_item.luks_header_path: crypttab_line += ",header=" + crypt_item.luks_header_path with open("/etc/crypttab", "a") as wf: wf.write(crypttab_line + "\n") return True def add_crypt_item_to_azure_crypt_mount(self, crypt_item): """ TODO we should judge that the second time. format is like this: """ try: if not crypt_item.luks_header_path: crypt_item.luks_header_path = "None" mount_content_item = (crypt_item.mapper_name + " " + crypt_item.dev_path + " " + crypt_item.luks_header_path + " " + crypt_item.mount_point + " " + crypt_item.file_system + " " + str(crypt_item.uses_cleartext_key) + " " + str(crypt_item.current_luks_slot)) if os.path.exists(self.encryption_environment.azure_crypt_mount_config_path): with open(self.encryption_environment.azure_crypt_mount_config_path, 'r') as f: existing_content = f.read() if existing_content is not None and existing_content.strip() != "": new_mount_content = existing_content + "\n" + mount_content_item else: new_mount_content = mount_content_item else: new_mount_content = mount_content_item with open(self.encryption_environment.azure_crypt_mount_config_path, 'w') as wf: wf.write('\n') wf.write(new_mount_content) wf.write('\n') return True except Exception: return False def remove_crypt_item(self, crypt_item): try: if self.should_use_azure_crypt_mount(): crypt_file_path = 
self.encryption_environment.azure_crypt_mount_config_path crypt_line_parser = self.parse_azure_crypt_mount_line elif os.path.exists("/etc/crypttab"): crypt_file_path = "/etc/crypttab" crypt_line_parser = self.parse_crypttab_line else: return True filtered_mount_lines = [] with open(crypt_file_path, 'r') as f: self.logger.log("removing an entry from {0}".format(crypt_file_path)) for line in f: if not line.strip(): continue parsed_crypt_item = crypt_line_parser(line) if parsed_crypt_item is not None and parsed_crypt_item.mapper_name == crypt_item.mapper_name: self.logger.log("Removing crypt mount entry: {0}".format(line)) continue filtered_mount_lines.append(line) with open(crypt_file_path, 'w') as wf: wf.write(''.join(filtered_mount_lines)) return True except Exception as e: return False def update_crypt_item(self, crypt_item, key_file_path): self.logger.log("Updating entry for crypt item {0}".format(crypt_item)) self.remove_crypt_item(crypt_item) self.add_crypt_item(crypt_item, key_file_path) def migrate_crypt_items(self, passphrase_file): crypt_items = self.get_crypt_items() # Archive azure_crypt_mount file try: if os.path.exists(self.encryption_environment.azure_crypt_mount_config_path): self.logger.log(msg="archiving azure crypt mount file: {0}".format(self.encryption_environment.azure_crypt_mount_config_path)) time_stamp = datetime.now() new_name = "{0}_{1}".format(self.encryption_environment.azure_crypt_mount_config_path, time_stamp) os.rename(self.encryption_environment.azure_crypt_mount_config_path, new_name) else: self.logger.log(msg=("the azure crypt mount file not exist: {0}".format(self.encryption_environment.azure_crypt_mount_config_path)), level=CommonVariables.InfoLevel) except OSError as e: self.logger.log("Failed to archive encryption mount file with error: {0}, stack trace: {1}".format(e, traceback.format_exc())) for crypt_item in crypt_items: self.logger.log("Migrating crypt item: {0}".format(crypt_item)) if crypt_item.mount_point == "/" or 
CommonVariables.osmapper_name == crypt_item.mapper_name: self.logger.log("Skipping OS disk") continue if crypt_item.mount_point and crypt_item.mount_point != "None": self.logger.log(msg="restoring entry for {0} drive in fstab".format(crypt_item.mount_point), level=CommonVariables.InfoLevel) self.restore_mount_info(crypt_item.mount_point) elif crypt_item.mapper_name: self.logger.log(msg="restoring entry for {0} drive in fstab".format(crypt_item.mapper_name), level=CommonVariables.InfoLevel) self.restore_mount_info(crypt_item.mapper_name) else: self.logger.log(msg=crypt_item.dev_path + " was not in fstab when encryption was enabled, no need to restore", level=CommonVariables.InfoLevel) self.modify_fstab_entry_encrypt(crypt_item.mount_point, os.path.join(CommonVariables.dev_mapper_root, crypt_item.mapper_name)) self.add_crypt_item_to_crypttab(crypt_item, passphrase_file) def is_luks_device(self, device_path, device_header_path): """ checks if the device is set up with a luks header """ path_var = device_header_path if device_header_path else device_path cmd = 'cryptsetup isLuks ' + path_var return (int)(self.command_executor.Execute(cmd, suppress_logging=True)) == CommonVariables.process_success def create_luks_header(self, mapper_name): luks_header_file_path = self.encryption_environment.luks_header_base_path + mapper_name if not os.path.exists(luks_header_file_path): dd_command = self.distro_patcher.dd_path + ' if=/dev/zero bs=33554432 count=1 > ' + luks_header_file_path self.command_executor.ExecuteInBash(dd_command, raise_exception_on_failure=True) return luks_header_file_path def create_cleartext_key(self, mapper_name): cleartext_key_file_path = self.encryption_environment.cleartext_key_base_path + mapper_name if not os.path.exists(cleartext_key_file_path): dd_command = self.distro_patcher.dd_path + ' if=/dev/urandom bs=128 count=1 > ' + cleartext_key_file_path self.command_executor.ExecuteInBash(dd_command, raise_exception_on_failure=True) return 
cleartext_key_file_path def encrypt_disk(self, dev_path, passphrase_file, mapper_name, header_file): return_code = self.luks_format(passphrase_file=passphrase_file, dev_path=dev_path, header_file=header_file) if return_code != CommonVariables.process_success: self.logger.log(msg=('cryptsetup luksFormat failed, return_code is:{0}'.format(return_code)), level=CommonVariables.ErrorLevel) return return_code else: return_code = self.luks_open(passphrase_file=passphrase_file, dev_path=dev_path, mapper_name=mapper_name, header_file=header_file, uses_cleartext_key=False) if return_code != CommonVariables.process_success: self.logger.log(msg=('cryptsetup luksOpen failed, return_code is:{0}'.format(return_code)), level=CommonVariables.ErrorLevel) return return_code def check_fs(self, dev_path): self.logger.log("checking fs:" + str(dev_path)) check_fs_cmd = self.distro_patcher.e2fsck_path + " -f -y " + dev_path return self.command_executor.Execute(check_fs_cmd) def expand_fs(self, dev_path): expandfs_cmd = self.distro_patcher.resize2fs_path + " " + str(dev_path) return self.command_executor.Execute(expandfs_cmd) def shrink_fs(self, dev_path, size_shrink_to): """ size_shrink_to is in sector (512 byte) """ shrinkfs_cmd = self.distro_patcher.resize2fs_path + ' ' + str(dev_path) + ' ' + str(size_shrink_to) + 's' return self.command_executor.Execute(shrinkfs_cmd) def check_shrink_fs(self, dev_path, size_shrink_to): return_code = self.check_fs(dev_path) if return_code == CommonVariables.process_success: return_code = self.shrink_fs(dev_path=dev_path, size_shrink_to=size_shrink_to) return return_code else: return return_code def luks_format(self, passphrase_file, dev_path, header_file): """ return the return code of the process for error handling. 
""" self.hutil.log("dev path to cryptsetup luksFormat {0}".format(dev_path)) #walkaround for sles sp3 if self.distro_patcher.distro_info[0].lower() == 'suse' and self.distro_patcher.distro_info[1] == '11': proc_comm = ProcessCommunicator() passphrase_cmd = self.distro_patcher.cat_path + ' ' + passphrase_file self.command_executor.Execute(passphrase_cmd, communicator=proc_comm) passphrase = proc_comm.stdout cryptsetup_cmd = "{0} luksFormat {1} -q".format(self.distro_patcher.cryptsetup_path, dev_path) return self.command_executor.Execute(cryptsetup_cmd, input=passphrase) else: if header_file is not None: cryptsetup_cmd = "{0} luksFormat {1} --header {2} -d {3} -q".format(self.distro_patcher.cryptsetup_path, dev_path, header_file, passphrase_file) else: cryptsetup_cmd = "{0} luksFormat {1} -d {2} -q".format(self.distro_patcher.cryptsetup_path, dev_path, passphrase_file) return self.command_executor.Execute(cryptsetup_cmd) def luks_add_key(self, passphrase_file, dev_path, mapper_name, header_file, new_key_path): """ return the return code of the process for error handling. """ self.hutil.log("new key path: " + (new_key_path)) if not os.path.exists(new_key_path): self.hutil.error("new key does not exist") return None if header_file: cryptsetup_cmd = "{0} luksAddKey {1} {2} -d {3} -q".format(self.distro_patcher.cryptsetup_path, header_file, new_key_path, passphrase_file) else: cryptsetup_cmd = "{0} luksAddKey {1} {2} -d {3} -q".format(self.distro_patcher.cryptsetup_path, dev_path, new_key_path, passphrase_file) return self.command_executor.Execute(cryptsetup_cmd) def luks_remove_key(self, passphrase_file, dev_path, header_file): """ return the return code of the process for error handling. 
""" self.hutil.log("removing keyslot: {0}".format(passphrase_file)) if header_file: cryptsetup_cmd = "{0} luksRemoveKey {1} -d {2} -q".format(self.distro_patcher.cryptsetup_path, header_file, passphrase_file) else: cryptsetup_cmd = "{0} luksRemoveKey {1} -d {2} -q".format(self.distro_patcher.cryptsetup_path, dev_path, passphrase_file) return self.command_executor.Execute(cryptsetup_cmd) def luks_kill_slot(self, passphrase_file, dev_path, header_file, keyslot): """ return the return code of the process for error handling. """ self.hutil.log("killing keyslot: {0}".format(keyslot)) if header_file: cryptsetup_cmd = "{0} luksKillSlot {1} {2} -d {3} -q".format(self.distro_patcher.cryptsetup_path, header_file, keyslot, passphrase_file) else: cryptsetup_cmd = "{0} luksKillSlot {1} {2} -d {3} -q".format(self.distro_patcher.cryptsetup_path, dev_path, keyslot, passphrase_file) return self.command_executor.Execute(cryptsetup_cmd) def luks_add_cleartext_key(self, passphrase_file, dev_path, mapper_name, header_file): """ return the return code of the process for error handling. 
""" cleartext_key_file_path = self.encryption_environment.cleartext_key_base_path + mapper_name self.hutil.log("cleartext key path: " + (cleartext_key_file_path)) return self.luks_add_key(passphrase_file, dev_path, mapper_name, header_file, cleartext_key_file_path) def luks_dump_keyslots(self, dev_path, header_file): cryptsetup_cmd = "" if header_file: cryptsetup_cmd = "{0} luksDump {1}".format(self.distro_patcher.cryptsetup_path, header_file) else: cryptsetup_cmd = "{0} luksDump {1}".format(self.distro_patcher.cryptsetup_path, dev_path) proc_comm = ProcessCommunicator() self.command_executor.Execute(cryptsetup_cmd, communicator=proc_comm) lines = filter(lambda l: "key slot" in l.lower(), proc_comm.stdout.split("\n")) keyslots = map(lambda l: "enabled" in l.lower(), lines) return keyslots def luks_open(self, passphrase_file, dev_path, mapper_name, header_file, uses_cleartext_key): """ return the return code of the process for error handling. """ self.hutil.log("dev mapper name to cryptsetup luksOpen " + (mapper_name)) if uses_cleartext_key: passphrase_file = self.encryption_environment.cleartext_key_base_path + mapper_name self.hutil.log("keyfile: " + (passphrase_file)) if header_file: cryptsetup_cmd = "{0} luksOpen {1} {2} --header {3} -d {4} -q".format(self.distro_patcher.cryptsetup_path, dev_path, mapper_name, header_file, passphrase_file) else: cryptsetup_cmd = "{0} luksOpen {1} {2} -d {3} -q".format(self.distro_patcher.cryptsetup_path, dev_path, mapper_name, passphrase_file) return self.command_executor.Execute(cryptsetup_cmd) def luks_close(self, mapper_name): """ returns the exit code for cryptsetup process. """ self.hutil.log("dev mapper name to cryptsetup luksOpen " + (mapper_name)) cryptsetup_cmd = "{0} luksClose {1} -q".format(self.distro_patcher.cryptsetup_path, mapper_name) return self.command_executor.Execute(cryptsetup_cmd) # TODO error handling. def append_mount_info(self, dev_path, mount_point): shutil.copy2('/etc/fstab', '/etc/fstab.backup.' 
+ str(str(uuid.uuid4()))) mount_content_item = dev_path + " " + mount_point + " auto defaults 0 0" new_mount_content = "" with open("/etc/fstab", 'r') as f: existing_content = f.read() new_mount_content = existing_content + "\n" + mount_content_item with open("/etc/fstab", 'w') as wf: wf.write(new_mount_content) def is_bek_in_fstab_file(self, lines): for line in lines: fstab_device, fstab_mount_point = self.parse_fstab_line(line) if fstab_mount_point == CommonVariables.encryption_key_mount_point: return True return False def parse_fstab_line(self, line): fstab_parts = line.strip().split() if len(fstab_parts) < 2: # Line should have enough content return None, None if fstab_parts[0].startswith("#"): # Line should not be a comment return None, None fstab_device = fstab_parts[0] fstab_mount_point = fstab_parts[1] return fstab_device, fstab_mount_point def modify_fstab_entry_encrypt(self, mount_point, mapper_path): self.logger.log("modify_fstab_entry_encrypt called with mount_point={0}, mapper_path={1}".format(mount_point, mapper_path)) if not mount_point: self.logger.log("modify_fstab_entry_encrypt: mount_point is empty") return shutil.copy2('/etc/fstab', '/etc/fstab.backup.' 
+ str(str(uuid.uuid4()))) with open('/etc/fstab', 'r') as f: lines = f.readlines() relevant_line = None for i in range(len(lines)): line = lines[i] fstab_device, fstab_mount_point = self.parse_fstab_line(line) if fstab_mount_point != mount_point: # Not the line we are looking for continue self.logger.log("Found the relevant fstab line: " + line) relevant_line = line if self.should_use_azure_crypt_mount(): # in this case we just remove the line lines.pop(i) break else: new_line = relevant_line.replace(fstab_device, mapper_path) self.logger.log("Replacing that line with: " + new_line) lines[i] = new_line break if not self.is_bek_in_fstab_file(lines): lines.append(self.get_fstab_bek_line()) with open('/etc/fstab', 'w') as f: f.writelines(lines) if relevant_line is not None: with open('/etc/fstab.azure.backup', 'a+') as f: f.write("\n" + relevant_line) def get_fstab_bek_line(self): if self.distro_patcher.distro_info[0].lower() == 'ubuntu' and self.distro_patcher.distro_info[1].startswith('14'): return CommonVariables.bek_fstab_line_template_ubuntu_14.format(CommonVariables.encryption_key_mount_point) else: return CommonVariables.bek_fstab_line_template.format(CommonVariables.encryption_key_mount_point) def add_bek_to_default_cryptdisks(self): if os.path.exists("/etc/default/cryptdisks"): with open("/etc/default/cryptdisks", 'r') as f: lines = f.readlines() if not any(["azure_bek_disk" in line for line in lines]): with open("/etc/default/cryptdisks", 'a') as f: f.write('\n' + CommonVariables.etc_defaults_cryptdisks_line.format(CommonVariables.encryption_key_mount_point)) def remove_mount_info(self, mount_point): if not mount_point: self.logger.log("remove_mount_info: mount_point is empty") return shutil.copy2('/etc/fstab', '/etc/fstab.backup.' 
+ str(str(uuid.uuid4()))) filtered_contents = [] removed_lines = [] with open('/etc/fstab', 'r') as f: for line in f.readlines(): line = line.strip() pattern = '\s' + re.escape(mount_point) + '\s' if re.search(pattern, line): self.logger.log("removing fstab line: {0}".format(line)) removed_lines.append(line) continue filtered_contents.append(line) with open('/etc/fstab', 'w') as f: f.write('\n') f.write('\n'.join(filtered_contents)) f.write('\n') self.logger.log("fstab updated successfully") with open('/etc/fstab.azure.backup', 'a+') as f: f.write('\n') f.write('\n'.join(removed_lines)) f.write('\n') self.logger.log("fstab.azure.backup updated successfully") def restore_mount_info(self, mount_point_or_mapper_name): if not mount_point_or_mapper_name: self.logger.log("restore_mount_info: mount_point_or_mapper_name is empty") return shutil.copy2('/etc/fstab', '/etc/fstab.backup.' + str(str(uuid.uuid4()))) lines_to_keep_in_backup_fstab = [] lines_to_put_back_to_fstab = [] with open('/etc/fstab.azure.backup', 'r') as f: for line in f.readlines(): line = line.strip() + '\n' pattern = '\s' + re.escape(mount_point_or_mapper_name) + '\s' if re.search(pattern, line): self.logger.log("removing fstab.azure.backup line: {0}".format(line)) lines_to_put_back_to_fstab.append(line) continue lines_to_keep_in_backup_fstab.append(line) with open('/etc/fstab.azure.backup', 'w') as f: f.writelines(lines_to_keep_in_backup_fstab) self.logger.log("fstab.azure.backup updated successfully") lines_that_remain_in_fstab = [] with open('/etc/fstab', 'r') as f: for line in f.readlines(): line = line.strip() + '\n' pattern = '\s' + re.escape(mount_point_or_mapper_name) + '\s' if re.search(pattern, line): # This line should not remain in the fstab. 
self.logger.log("removing fstab line: {0}".format(line)) continue lines_that_remain_in_fstab.append(line) with open('/etc/fstab', 'w') as f: f.writelines(lines_that_remain_in_fstab + lines_to_put_back_to_fstab) self.logger.log("fstab updated successfully") def mount_bek_volume(self, bek_label, mount_point, option_string): """ mount the BEK volume """ self.make_sure_path_exists(mount_point) mount_cmd = self.distro_patcher.mount_path + ' -L "' + bek_label + '" ' + mount_point + ' -o ' + option_string return self.command_executor.Execute(mount_cmd) def mount_auto(self, dev_path_or_mount_point): """ mount the file system via fstab entry """ mount_cmd = self.distro_patcher.mount_path + ' ' + dev_path_or_mount_point return self.command_executor.Execute(mount_cmd) def mount_filesystem(self, dev_path, mount_point, file_system=None): """ mount the file system. """ self.make_sure_path_exists(mount_point) if file_system is None: mount_cmd = self.distro_patcher.mount_path + ' ' + dev_path + ' ' + mount_point else: mount_cmd = self.distro_patcher.mount_path + ' ' + dev_path + ' ' + mount_point + ' -t ' + file_system return self.command_executor.Execute(mount_cmd) def mount_crypt_item(self, crypt_item, passphrase): self.logger.log("trying to mount the crypt item:" + str(crypt_item)) self.logger.log(msg=('First trying to auto mount for the item')) mount_filesystem_result = self.mount_auto(os.path.join(CommonVariables.dev_mapper_root, crypt_item.mapper_name)) if str(crypt_item.mount_point) != 'None' and mount_filesystem_result != CommonVariables.process_success: self.logger.log(msg=('mount_point is not None and auto mount failed. 
Trying manual mount.'), level=CommonVariables.WarningLevel) mount_filesystem_result = self.mount_filesystem(os.path.join(CommonVariables.dev_mapper_root, crypt_item.mapper_name), crypt_item.mount_point, crypt_item.file_system) self.logger.log("mount file system result:{0}".format(mount_filesystem_result)) def swapoff(self): return self.command_executor.Execute('swapoff -a') def umount(self, path): umount_cmd = self.distro_patcher.umount_path + ' ' + path return self.command_executor.Execute(umount_cmd) def umount_all_crypt_items(self): for crypt_item in self.get_crypt_items(): self.logger.log("Unmounting {0}".format(os.path.join(CommonVariables.dev_mapper_root, crypt_item.mapper_name))) self.umount(os.path.join(CommonVariables.dev_mapper_root, crypt_item.mapper_name)) def mount_all(self): mount_all_cmd = self.distro_patcher.mount_path + ' -a' return self.command_executor.Execute(mount_all_cmd) def get_mount_items(self): items = [] for line in file('/proc/mounts'): line = [s.decode('string_escape') for s in line.split()] item = { "src": line[0], "dest": line[1], "fs": line[2] } items.append(item) return items def get_encryption_status(self): encryption_status = { "data": "NotEncrypted", "os": "NotEncrypted" } mount_items = self.get_mount_items() device_items = self.get_device_items(None) device_items_dict = dict([(device_item.mount_point, device_item) for device_item in device_items]) os_drive_encrypted = False data_drives_found = False all_data_drives_encrypted = True osmapper_path = os.path.join(CommonVariables.dev_mapper_root, CommonVariables.osmapper_name) if self.is_os_disk_lvm(): grep_result = self.command_executor.ExecuteInBash('pvdisplay | grep {0}'.format(osmapper_path), suppress_logging=True) if grep_result == 0 and not os.path.exists('/volumes.lvm'): self.logger.log("OS PV is encrypted") os_drive_encrypted = True special_azure_devices_to_skip = self.get_azure_devices() for mount_item in mount_items: device_item = device_items_dict.get(mount_item["dest"]) 
if device_item is not None and \ mount_item["fs"] in CommonVariables.format_supported_file_systems and \ self.is_data_disk(device_item, special_azure_devices_to_skip): data_drives_found = True if not device_item.type == "crypt": self.logger.log("Data volume {0} is mounted from {1}".format(mount_item["dest"], mount_item["src"])) all_data_drives_encrypted = False if mount_item["dest"] == "/" and \ not self.is_os_disk_lvm() and \ CommonVariables.dev_mapper_root in mount_item["src"] or \ "/dev/dm" in mount_item["src"]: self.logger.log("OS volume {0} is mounted from {1}".format(mount_item["dest"], mount_item["src"])) os_drive_encrypted = True if not data_drives_found: encryption_status["data"] = "NotMounted" elif all_data_drives_encrypted: encryption_status["data"] = "Encrypted" if os_drive_encrypted: encryption_status["os"] = "Encrypted" encryption_marker = EncryptionMarkConfig(self.logger, self.encryption_environment) decryption_marker = DecryptionMarkConfig(self.logger, self.encryption_environment) if decryption_marker.config_file_exists(): encryption_status["data"] = "DecryptionInProgress" elif encryption_marker.config_file_exists(): encryption_config = EncryptionConfig(self.encryption_environment, self.logger) volume_type = encryption_config.get_volume_type().lower() if volume_type == CommonVariables.VolumeTypeData.lower() or \ volume_type == CommonVariables.VolumeTypeAll.lower(): encryption_status["data"] = "EncryptionInProgress" if volume_type == CommonVariables.VolumeTypeOS.lower() or \ volume_type == CommonVariables.VolumeTypeAll.lower(): if not os_drive_encrypted: encryption_status["os"] = "EncryptionInProgress" elif os.path.exists(osmapper_path) and not os_drive_encrypted: encryption_status["os"] = "VMRestartPending" return json.dumps(encryption_status) def query_dev_sdx_path_by_scsi_id(self, scsi_number): p = Popen([self.distro_patcher.lsscsi_path, scsi_number], stdout=subprocess.PIPE, stderr=subprocess.PIPE) identity, err = p.communicate() # identity 
sample: [5:0:0:0] disk Msft Virtual Disk 1.0 /dev/sdc self.logger.log("lsscsi output is: {0}\n".format(identity)) vals = identity.split() if vals is None or len(vals) == 0: return None sdx_path = vals[len(vals) - 1] return sdx_path def query_dev_sdx_path_by_uuid(self, uuid): """ return /dev/disk/by-id that maps to the sdx_path, otherwise return the original path """ desired_uuid_path = os.path.join(CommonVariables.disk_by_uuid_root, uuid) for disk_by_uuid in os.listdir(CommonVariables.disk_by_uuid_root): disk_by_uuid_path = os.path.join(CommonVariables.disk_by_uuid_root, disk_by_uuid) if disk_by_uuid_path == desired_uuid_path: return os.path.realpath(disk_by_uuid_path) return desired_uuid_path def query_dev_id_path_by_sdx_path(self, sdx_path): """ return /dev/disk/by-id that maps to the sdx_path, otherwise return the original path Update: now we have realised that by-id is not a good way to refer to devices (they can change on reallocations or resizes). Try not to use this- use get_stable_path_from_sdx instead """ for disk_by_id in os.listdir(CommonVariables.disk_by_id_root): disk_by_id_path = os.path.join(CommonVariables.disk_by_id_root, disk_by_id) if os.path.realpath(disk_by_id_path) == sdx_path: return disk_by_id_path return sdx_path def get_persistent_path_by_sdx_path(self, sdx_path): """ return a stable path for this /dev/sdx device """ sdx_realpath = os.path.realpath(sdx_path) # First try finding an Azure symlink azure_name_table = self.get_block_device_to_azure_udev_table() if sdx_realpath in azure_name_table: return azure_name_table[sdx_realpath] # A mapper path is also pretty good (especially for raid or lvm) for mapper_name in os.listdir(CommonVariables.dev_mapper_root): mapper_path = os.path.join(CommonVariables.dev_mapper_root, mapper_name) if os.path.realpath(mapper_path) == sdx_realpath: return mapper_path # Then try matching a uuid symlink. 
Those are probably the best for disk_by_uuid in os.listdir(CommonVariables.disk_by_uuid_root): disk_by_uuid_path = os.path.join(CommonVariables.disk_by_uuid_root, disk_by_uuid) if os.path.realpath(disk_by_uuid_path) == sdx_realpath: return disk_by_uuid_path # Found nothing very persistent. Just return the original sdx path. # And Log it. self.logger.log(msg="Failed to find a persistent path for [{0}].".format(sdx_path), level=CommonVariables.WarningLevel) return sdx_path def get_device_path(self, dev_name): device_path = None if os.path.exists("/dev/" + dev_name): device_path = "/dev/" + dev_name elif os.path.exists("/dev/mapper/" + dev_name): device_path = "/dev/mapper/" + dev_name return device_path def get_device_id(self, dev_path): if (dev_path) in DiskUtil.device_id_cache: return DiskUtil.device_id_cache[dev_path] udev_cmd = "udevadm info -a -p $(udevadm info -q path -n {0}) | grep device_id".format(dev_path) proc_comm = ProcessCommunicator() self.command_executor.ExecuteInBash(udev_cmd, communicator=proc_comm, suppress_logging=True) match = re.findall(r'"{(.*)}"', proc_comm.stdout.strip()) DiskUtil.device_id_cache[dev_path] = match[0] if match else "" return DiskUtil.device_id_cache[dev_path] def get_device_items_property(self, dev_name, property_name): if (dev_name, property_name) in DiskUtil.sles_cache: return DiskUtil.sles_cache[(dev_name, property_name)] self.logger.log("getting property of device {0}".format(dev_name)) device_path = self.get_device_path(dev_name) property_value = "" if property_name == "SIZE": get_property_cmd = self.distro_patcher.blockdev_path + " --getsize64 " + device_path proc_comm = ProcessCommunicator() self.command_executor.Execute(get_property_cmd, communicator=proc_comm, suppress_logging=True) property_value = proc_comm.stdout.strip() elif property_name == "DEVICE_ID": property_value = self.get_device_id(device_path) else: get_property_cmd = self.distro_patcher.lsblk_path + " " + device_path + " -b -nl -o NAME," + property_name 
proc_comm = ProcessCommunicator() self.command_executor.Execute(get_property_cmd, communicator=proc_comm, raise_exception_on_failure=True, suppress_logging=True) for line in proc_comm.stdout.splitlines(): if line.strip(): disk_info_item_array = line.strip().split() if dev_name == disk_info_item_array[0]: if len(disk_info_item_array) > 1: property_value = disk_info_item_array[1] DiskUtil.sles_cache[(dev_name, property_name)] = property_value return property_value def get_block_device_to_azure_udev_table(self): table = {} if not os.path.exists(CommonVariables.azure_symlinks_dir): return table for top_level_item in os.listdir(CommonVariables.azure_symlinks_dir): top_level_item_full_path = os.path.join(CommonVariables.azure_symlinks_dir, top_level_item) if os.path.isdir(top_level_item_full_path): scsi_path = os.path.join(CommonVariables.azure_symlinks_dir, top_level_item) for symlink in os.listdir(scsi_path): symlink_full_path = os.path.join(scsi_path, symlink) table[os.path.realpath(symlink_full_path)] = symlink_full_path else: table[os.path.realpath(top_level_item_full_path)] = top_level_item_full_path return table def get_azure_symlinks(self): azure_udev_links = {} if os.path.exists(CommonVariables.azure_symlinks_dir): wdbackup = os.getcwd() os.chdir(CommonVariables.azure_symlinks_dir) for symlink in os.listdir(CommonVariables.azure_symlinks_dir): azure_udev_links[os.path.basename(symlink)] = os.path.realpath(symlink) os.chdir(wdbackup) return azure_udev_links def log_lsblk_output(self): lsblk_command = 'lsblk -o NAME,TYPE,FSTYPE,LABEL,SIZE,RO,MOUNTPOINT' proc_comm = ProcessCommunicator() self.command_executor.Execute(lsblk_command, communicator=proc_comm) self.logger.log('\n' + str(proc_comm.stdout) + '\n') def get_device_items_sles(self, dev_path): if dev_path: self.logger.log(msg=("getting blk info for: {0}".format(dev_path))) device_items_to_return = [] device_items = [] #first get all the device names if dev_path is None: lsblk_command = 'lsblk -b -nl -o NAME' 
else: lsblk_command = 'lsblk -b -nl -o NAME ' + dev_path proc_comm = ProcessCommunicator() self.command_executor.Execute(lsblk_command, communicator=proc_comm, raise_exception_on_failure=True) for line in proc_comm.stdout.splitlines(): item_value_str = line.strip() if item_value_str: device_item = DeviceItem() device_item.name = item_value_str.split()[0] device_items.append(device_item) for device_item in device_items: device_item.file_system = self.get_device_items_property(dev_name=device_item.name, property_name='FSTYPE') device_item.mount_point = self.get_device_items_property(dev_name=device_item.name, property_name='MOUNTPOINT') device_item.label = self.get_device_items_property(dev_name=device_item.name, property_name='LABEL') device_item.uuid = self.get_device_items_property(dev_name=device_item.name, property_name='UUID') device_item.majmin = self.get_device_items_property(dev_name=device_item.name, property_name='MAJ:MIN') device_item.device_id = self.get_device_items_property(dev_name=device_item.name, property_name='DEVICE_ID') device_item.azure_name = '' for symlink, target in self.get_azure_symlinks().items(): if device_item.name in target: device_item.azure_name = symlink # get the type of device model_file_path = '/sys/block/' + device_item.name + '/device/model' if os.path.exists(model_file_path): with open(model_file_path, 'r') as f: device_item.model = f.read().strip() else: self.logger.log(msg=("no model file found for device {0}".format(device_item.name))) if device_item.model == 'Virtual Disk': self.logger.log(msg="model is virtual disk") device_item.type = 'disk' else: partition_files = glob.glob('/sys/block/*/' + device_item.name + '/partition') self.logger.log(msg="partition files exists") if partition_files is not None and len(partition_files) > 0: device_item.type = 'part' size_string = self.get_device_items_property(dev_name=device_item.name, property_name='SIZE') if size_string is not None and size_string != "": device_item.size = 
int(size_string) if device_item.type is None: device_item.type = '' if device_item.size is not None: device_items_to_return.append(device_item) else: self.logger.log(msg=("skip the device {0} because we could not get size of it.".format(device_item.name))) return device_items_to_return def get_device_items(self, dev_path): if self.distro_patcher.distro_info[0].lower() == 'suse' and self.distro_patcher.distro_info[1] == '11': return self.get_device_items_sles(dev_path) else: if dev_path: self.logger.log(msg=("getting blk info for: " + str(dev_path))) if dev_path is None: lsblk_command = 'lsblk -b -n -P -o NAME,TYPE,FSTYPE,MOUNTPOINT,LABEL,UUID,MODEL,SIZE,MAJ:MIN' else: lsblk_command = 'lsblk -b -n -P -o NAME,TYPE,FSTYPE,MOUNTPOINT,LABEL,UUID,MODEL,SIZE,MAJ:MIN ' + dev_path proc_comm = ProcessCommunicator() self.command_executor.Execute(lsblk_command, communicator=proc_comm, raise_exception_on_failure=True, suppress_logging=True) device_items = [] lvm_items = self.get_lvm_items() for line in proc_comm.stdout.splitlines(): if line: device_item = DeviceItem() for disk_info_property in line.split(): property_item_pair = disk_info_property.split('=') if property_item_pair[0] == 'SIZE': device_item.size = int(property_item_pair[1].strip('"')) if property_item_pair[0] == 'NAME': device_item.name = property_item_pair[1].strip('"') if property_item_pair[0] == 'TYPE': device_item.type = property_item_pair[1].strip('"') if property_item_pair[0] == 'FSTYPE': device_item.file_system = property_item_pair[1].strip('"') if property_item_pair[0] == 'MOUNTPOINT': device_item.mount_point = property_item_pair[1].strip('"') if property_item_pair[0] == 'LABEL': device_item.label = property_item_pair[1].strip('"') if property_item_pair[0] == 'UUID': device_item.uuid = property_item_pair[1].strip('"') if property_item_pair[0] == 'MODEL': device_item.model = property_item_pair[1].strip('"') if property_item_pair[0] == 'MAJ:MIN': device_item.majmin = property_item_pair[1].strip('"') 
device_item.device_id = self.get_device_id(self.get_device_path(device_item.name)) if device_item.type is None: device_item.type = '' if device_item.type.lower() == 'lvm': for lvm_item in lvm_items: majmin = lvm_item.lv_kernel_major + ':' + lvm_item.lv_kernel_minor if majmin == device_item.majmin: device_item.name = lvm_item.vg_name + '/' + lvm_item.lv_name device_item.azure_name = '' for symlink, target in self.get_azure_symlinks().items(): if device_item.name in target: device_item.azure_name = symlink device_items.append(device_item) return device_items def get_lvm_items(self): lvs_command = 'lvs --noheadings --nameprefixes --unquoted -o lv_name,vg_name,lv_kernel_major,lv_kernel_minor' proc_comm = ProcessCommunicator() if self.command_executor.Execute(lvs_command, communicator=proc_comm): return [] lvm_items = [] for line in proc_comm.stdout.splitlines(): if not line: continue lvm_item = LvmItem() for pair in line.strip().split(): if len(pair.split('=')) != 2: continue key, value = pair.split('=') if key == 'LVM2_LV_NAME': lvm_item.lv_name = value if key == 'LVM2_VG_NAME': lvm_item.vg_name = value if key == 'LVM2_LV_KERNEL_MAJOR': lvm_item.lv_kernel_major = value if key == 'LVM2_LV_KERNEL_MINOR': lvm_item.lv_kernel_minor = value lvm_items.append(lvm_item) return lvm_items def is_os_disk_lvm(self): if DiskUtil.os_disk_lvm is not None: return DiskUtil.os_disk_lvm device_items = self.get_device_items(None) if not any([item.type.lower() == 'lvm' for item in device_items]): DiskUtil.os_disk_lvm = False return False lvm_items = filter(lambda item: item.vg_name == "rootvg", self.get_lvm_items()) current_lv_names = set([item.lv_name for item in lvm_items]) DiskUtil.os_disk_lvm = False expected_lv_names = set(['homelv', 'optlv', 'rootlv', 'swaplv', 'tmplv', 'usrlv', 'varlv']) if expected_lv_names == current_lv_names: DiskUtil.os_disk_lvm = True expected_lv_names = set(['homelv', 'optlv', 'rootlv', 'tmplv', 'usrlv', 'varlv']) if expected_lv_names == current_lv_names: 
DiskUtil.os_disk_lvm = True return DiskUtil.os_disk_lvm def is_data_disk(self, device_item, azure_devices): # Root disk if device_item.device_id.startswith('00000000-0000'): self.logger.log(msg="skipping root disk", level=CommonVariables.WarningLevel) return False # Resource Disk. Not considered a "data disk" exactly (is not attached via portal and we have a separate code path for encrypting it) if device_item.device_id.startswith('00000000-0001'): self.logger.log(msg="skipping resource disk", level=CommonVariables.WarningLevel) return False for azure_blk_item in azure_devices: if azure_blk_item.name == device_item.name: self.logger.log(msg="the mountpoint is the azure disk root or resource, so skip it.") return False return True def should_skip_for_inplace_encryption(self, device_item, special_azure_devices_to_skip, encrypt_volume_type): """ TYPE="raid0" TYPE="part" TYPE="crypt" first check whether there's one file system on it. if the type is disk, then to check whether it have child-items, say the part, lvm or crypt luks. if the answer is yes, then skip it. 
""" if encrypt_volume_type.lower() == 'data' and not self.is_data_disk(device_item, special_azure_devices_to_skip): return True # Skip data disks if device_item.file_system is None or device_item.file_system == "": self.logger.log(msg=("there's no file system on this device: {0}, so skip it.").format(device_item)) return True else: if device_item.size < CommonVariables.min_filesystem_size_support: self.logger.log(msg="the device size is too small," + str(device_item.size) + " so skip it.", level=CommonVariables.WarningLevel) return True supported_device_type = ["disk","part","raid0","raid1","raid5","raid10","lvm"] if device_item.type not in supported_device_type: self.logger.log(msg="the device type: " + str(device_item.type) + " is not supported yet, so skip it.", level=CommonVariables.WarningLevel) return True if device_item.uuid is None or device_item.uuid == "": self.logger.log(msg="the device do not have the related uuid, so skip it.", level=CommonVariables.WarningLevel) return True sub_items = self.get_device_items("/dev/" + device_item.name) if len(sub_items) > 1: self.logger.log(msg=("there's sub items for the device:{0} , so skip it.".format(device_item.name)), level=CommonVariables.WarningLevel) return True if device_item.type == "crypt": self.logger.log(msg=("device_item.type is:{0}, so skip it.".format(device_item.type)), level=CommonVariables.WarningLevel) return True if device_item.mount_point == "/": self.logger.log(msg=("the mountpoint is root:{0}, so skip it.".format(device_item)), level=CommonVariables.WarningLevel) return True for azure_blk_item in special_azure_devices_to_skip: if azure_blk_item.name == device_item.name: self.logger.log(msg="the mountpoint is the azure disk root or resource, so skip it.") return True return False def get_azure_devices(self): ide_devices = self.get_ide_devices() blk_items = [] for ide_device in ide_devices: current_blk_items = self.get_device_items("/dev/" + ide_device) for current_blk_item in current_blk_items: 
blk_items.append(current_blk_item) return blk_items def get_ide_devices(self): """ this only return the device names of the ide. """ ide_devices = [] for vmbus in os.listdir(self.vmbus_sys_path): f = open('%s/%s/%s' % (self.vmbus_sys_path, vmbus, 'class_id'), 'r') class_id = f.read() f.close() if class_id.strip() == self.ide_class_id: device_sdx_path = self.find_block_sdx_path(vmbus) self.logger.log("found one ide with vmbus: {0} and the sdx path is: {1}".format(vmbus, device_sdx_path)) ide_devices.append(device_sdx_path) return ide_devices def find_block_sdx_path(self, vmbus): device = None for root, dirs, files in os.walk(os.path.join(self.vmbus_sys_path , vmbus)): if root.endswith("/block"): device = dirs[0] else : #older distros for d in dirs: if ':' in d and "block" == d.split(':')[0]: device = d.split(':')[1] break return device ================================================ FILE: VMEncryption/main/EncryptionConfig.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import datetime import os.path from Common import CommonVariables from ConfigParser import ConfigParser from ConfigUtil import ConfigUtil from ConfigUtil import ConfigKeyValuePair class EncryptionConfig(object): def __init__(self, encryption_environment, logger): self.encryption_environment = encryption_environment self.passphrase_file_name = None self.volume_type = None self.secret_id = None self.secret_seq_num = None self.encryption_config = ConfigUtil(encryption_environment.encryption_config_file_path, 'azure_crypt_config', logger) self.logger = logger def config_file_exists(self): return self.encryption_config.config_file_exists() def get_bek_filename(self): return self.encryption_config.get_config(CommonVariables.PassphraseFileNameKey) def get_volume_type(self): return self.encryption_config.get_config(CommonVariables.VolumeTypeKey) def get_secret_id(self): return self.encryption_config.get_config(CommonVariables.SecretUriKey) def get_secret_seq_num(self): return self.encryption_config.get_config(CommonVariables.SecretSeqNum) def commit(self): key_value_pairs = [] command = ConfigKeyValuePair(CommonVariables.PassphraseFileNameKey, self.passphrase_file_name) key_value_pairs.append(command) volume_type = ConfigKeyValuePair(CommonVariables.VolumeTypeKey, self.volume_type) key_value_pairs.append(volume_type) parameters = ConfigKeyValuePair(CommonVariables.SecretUriKey, self.secret_id) key_value_pairs.append(parameters) parameters = ConfigKeyValuePair(CommonVariables.SecretSeqNum, self.secret_seq_num) key_value_pairs.append(parameters) self.encryption_config.save_configs(key_value_pairs) def clear_config(self): try: if os.path.exists(self.encryption_environment.encryption_config_file_path): self.logger.log(msg="archiving the encryption config file: {0}".format(self.encryption_environment.encryption_config_file_path)) time_stamp = datetime.datetime.now() new_name = "{0}_{1}".format(self.encryption_environment.encryption_config_file_path, time_stamp) 
os.rename(self.encryption_environment.encryption_config_file_path, new_name) else: self.logger.log(msg=("the config file not exist: {0}".format(self.encryption_environment.encryption_config_file_path)), level = CommonVariables.WarningLevel) return True except OSError as e: self.logger.log("Failed to archive encryption config with error: {0}, stack trace: {1}".format(e, traceback.format_exc())) return False ================================================ FILE: VMEncryption/main/EncryptionEnvironment.py ================================================ #!/usr/bin/env python # # VMEncryption extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import os.path import subprocess from subprocess import * class EncryptionEnvironment(object): """description of class""" def __init__(self, patching, logger): self.patching = patching self.logger = logger self.encryption_config_path = '/var/lib/azure_disk_encryption_config/' # Lock file for daemon. 
self.daemon_lock_file_path = os.path.join(self.encryption_config_path, 'daemon_lock_file.lck') self.encryption_config_file_path = os.path.join(self.encryption_config_path, 'azure_crypt_config.ini') self.extension_parameter_file_path = os.path.join(self.encryption_config_path, 'azure_crypt_params.ini') self.azure_crypt_mount_config_path = os.path.join(self.encryption_config_path, 'azure_crypt_mount') self.azure_crypt_request_queue_path = os.path.join(self.encryption_config_path, 'azure_crypt_request_queue.ini') self.azure_decrypt_request_queue_path = os.path.join(self.encryption_config_path, 'azure_decrypt_request_queue.ini') self.azure_crypt_ongoing_item_config_path = os.path.join(self.encryption_config_path, 'azure_crypt_ongoing_item.ini') self.azure_crypt_current_transactional_copy_path = os.path.join(self.encryption_config_path, 'azure_crypt_copy_progress.ini') self.luks_header_base_path = os.path.join(self.encryption_config_path, 'azureluksheader') self.cleartext_key_base_path = os.path.join(self.encryption_config_path, 'cleartext_key') self.copy_header_slice_file_path = os.path.join(self.encryption_config_path, 'copy_header_slice_file') self.copy_slice_item_backup_file = os.path.join(self.encryption_config_path, 'copy_slice_item.bak') self.os_encryption_markers_path = os.path.join(self.encryption_config_path, 'os_encryption_markers') self.bek_backup_path = os.path.join(self.encryption_config_path, 'bek_backup') def get_se_linux(self): proc = Popen([self.patching.getenforce_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE) identity, err = proc.communicate() return identity.strip().lower() def disable_se_linux(self): self.logger.log("disabling se linux") proc = Popen([self.patching.setenforce_path,'0'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) return_code = proc.wait() return return_code def enable_se_linux(self): self.logger.log("enabling se linux") proc = Popen([self.patching.setenforce_path,'1'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 
return_code = proc.wait() return return_code ================================================ FILE: VMEncryption/main/EncryptionMarkConfig.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import os.path import traceback from ConfigUtil import * from Common import CommonVariables class EncryptionMarkConfig(object): def __init__(self, logger, encryption_environment): self.logger = logger self.encryption_environment = encryption_environment self.command = None self.volume_type = None self.diskFormatQuery = None self.encryption_mark_config = ConfigUtil(self.encryption_environment.azure_crypt_request_queue_path, 'encryption_request_queue', self.logger) def get_volume_type(self): return self.encryption_mark_config.get_config(CommonVariables.EncryptionVolumeTypeKey) def get_current_command(self): return self.encryption_mark_config.get_config(CommonVariables.EncryptionEncryptionOperationKey) def get_encryption_disk_format_query(self): return self.encryption_mark_config.get_config(CommonVariables.EncryptionDiskFormatQueryKey) def config_file_exists(self): """ we should compare the timestamp of the file with the current system time if not match (in 30 minutes, then should skip the file) """ return self.encryption_mark_config.config_file_exists() def commit(self): key_value_pairs = [] command = 
ConfigKeyValuePair(CommonVariables.EncryptionEncryptionOperationKey, self.command) key_value_pairs.append(command) volume_type = ConfigKeyValuePair(CommonVariables.EncryptionVolumeTypeKey, self.volume_type) key_value_pairs.append(volume_type) disk_format_query = ConfigKeyValuePair(CommonVariables.EncryptionDiskFormatQueryKey, self.diskFormatQuery) key_value_pairs.append(disk_format_query) self.encryption_mark_config.save_configs(key_value_pairs) def clear_config(self): try: if os.path.exists(self.encryption_environment.azure_crypt_request_queue_path): os.remove(self.encryption_environment.azure_crypt_request_queue_path) return True except OSError as e: self.logger.log("Failed to clear_queue with error: {0}, stack trace: {1}".format(e, traceback.format_exc())) return False ================================================ FILE: VMEncryption/main/ExtensionParameter.py ================================================ #!/usr/bin/env python # # VMEncryption extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#!/usr/bin/env python
#
# VMEncryption extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# BUGFIX: datetime and traceback are used by clear_config() but were never
# imported, so archiving the encryption config raised NameError.
import datetime
import hashlib
import os.path
import traceback
import xml.parsers.expat

from DiskUtil import DiskUtil
from BekUtil import BekUtil
from EncryptionConfig import EncryptionConfig
from Utils import HandlerUtil
from Common import *
from ConfigParser import ConfigParser
from ConfigUtil import ConfigUtil
from ConfigUtil import ConfigKeyValuePair

# parameter format should be like this:
# {"command":"enableencryption","query":[{"source_scsi_number":"[5:0:0:0]","target_scsi_number":"[5:0:0:2]"},{"source_scsi_number":"[5:0:0:1]","target_scsi_number":"[5:0:0:3]"}],
# "force":"true", "passphrase":"User@123"}


class ExtensionParameter(object):
    """Parses the extension's public/protected settings and persists them so a
    later run can detect whether the effective configuration has changed."""

    def __init__(self, hutil, logger, distro_patcher, encryption_environment, protected_settings, public_settings):
        """
        TODO: we should validate the parameter first
        """
        self.hutil = hutil
        self.logger = logger
        self.distro_patcher = distro_patcher
        self.encryption_environment = encryption_environment
        self.disk_util = DiskUtil(hutil=hutil, patching=distro_patcher, logger=logger, encryption_environment=encryption_environment)
        self.bek_util = BekUtil(self.disk_util, logger)
        self.encryption_config = EncryptionConfig(encryption_environment, logger)

        # Public (non-secret) settings.
        self.command = public_settings.get(CommonVariables.EncryptionEncryptionOperationKey)
        self.KeyEncryptionKeyURL = public_settings.get(CommonVariables.KeyEncryptionKeyURLKey)
        self.KeyVaultURL = public_settings.get(CommonVariables.KeyVaultURLKey)
        self.AADClientID = public_settings.get(CommonVariables.AADClientIDKey)
        self.AADClientCertThumbprint = public_settings.get(CommonVariables.AADClientCertThumbprintKey)
        keyEncryptionAlgorithm = public_settings.get(CommonVariables.KeyEncryptionAlgorithmKey)
        if keyEncryptionAlgorithm is not None and keyEncryptionAlgorithm != "":
            self.KeyEncryptionAlgorithm = keyEncryptionAlgorithm
        else:
            # Default wrap algorithm when none was specified.
            self.KeyEncryptionAlgorithm = 'RSA-OAEP'
        self.VolumeType = public_settings.get(CommonVariables.VolumeTypeKey)
        self.DiskFormatQuery = public_settings.get(CommonVariables.DiskFormatQuerykey)

        # Protected (secret) settings.
        self.AADClientSecret = protected_settings.get(CommonVariables.AADClientSecretKey)
        if self.AADClientSecret is None:
            self.AADClientSecret = ''
        self.passphrase = protected_settings.get(CommonVariables.PassphraseKey)

        self.DiskEncryptionKeyFileName = "LinuxPassPhraseFileName"
        # parse the query from the array
        self.params_config = ConfigUtil(encryption_environment.extension_parameter_file_path,
                                        'azure_extension_params',
                                        logger)

    def config_file_exists(self):
        return self.params_config.config_file_exists()

    def get_command(self):
        return self.params_config.get_config(CommonVariables.EncryptionEncryptionOperationKey)

    def get_kek_url(self):
        return self.params_config.get_config(CommonVariables.KeyEncryptionKeyURLKey)

    def get_keyvault_url(self):
        return self.params_config.get_config(CommonVariables.KeyVaultURLKey)

    def get_aad_client_id(self):
        return self.params_config.get_config(CommonVariables.AADClientIDKey)

    def get_aad_client_secret(self):
        # Note: this is a sha256 digest of the secret, never the secret itself.
        return self.params_config.get_config(CommonVariables.AADClientSecretKey)

    def get_aad_client_cert(self):
        return self.params_config.get_config(CommonVariables.AADClientCertThumbprintKey)

    def get_kek_algorithm(self):
        return self.params_config.get_config(CommonVariables.KeyEncryptionAlgorithmKey)

    def get_volume_type(self):
        return self.params_config.get_config(CommonVariables.VolumeTypeKey)

    def get_disk_format_query(self):
        return self.params_config.get_config(CommonVariables.DiskFormatQuerykey)

    def get_bek_filename(self):
        return self.DiskEncryptionKeyFileName

    def commit(self):
        """Persist the current settings. The AAD client secret is stored only
        as a sha256 digest so the plaintext never touches disk."""
        key_value_pairs = [
            ConfigKeyValuePair(CommonVariables.EncryptionEncryptionOperationKey, self.command),
            ConfigKeyValuePair(CommonVariables.KeyEncryptionKeyURLKey, self.KeyEncryptionKeyURL),
            ConfigKeyValuePair(CommonVariables.KeyVaultURLKey, self.KeyVaultURL),
            ConfigKeyValuePair(CommonVariables.AADClientIDKey, self.AADClientID),
            ConfigKeyValuePair(CommonVariables.AADClientSecretKey,
                               hashlib.sha256(self.AADClientSecret.encode("utf-8")).hexdigest()),
            ConfigKeyValuePair(CommonVariables.AADClientCertThumbprintKey, self.AADClientCertThumbprint),
            ConfigKeyValuePair(CommonVariables.KeyEncryptionAlgorithmKey, self.KeyEncryptionAlgorithm),
            ConfigKeyValuePair(CommonVariables.VolumeTypeKey, self.VolumeType),
            ConfigKeyValuePair(CommonVariables.DiskFormatQuerykey, self.DiskFormatQuery),
        ]
        self.params_config.save_configs(key_value_pairs)

    def clear_config(self):
        """Archive (rename with a timestamp suffix) the encryption config file.
        Return True on success, False when the rename fails.
        NOTE(review): this archives encryption_config_file_path, not the
        extension_parameter_file_path this class writes — presumably intentional,
        confirm against callers."""
        try:
            config_path = self.encryption_environment.encryption_config_file_path
            if os.path.exists(config_path):
                self.logger.log(msg="archiving the encryption config file: {0}".format(config_path))
                time_stamp = datetime.datetime.now()
                new_name = "{0}_{1}".format(config_path, time_stamp)
                os.rename(config_path, new_name)
            else:
                self.logger.log(msg=("the config file not exist: {0}".format(config_path)),
                                level=CommonVariables.WarningLevel)
            return True
        except OSError as e:
            self.logger.log("Failed to archive encryption config with error: {0}, stack trace: {1}".format(e, traceback.format_exc()))
            return False

    def _is_encrypt_command(self, command):
        """True if command is any of the three encryption-enable operations."""
        return command in [CommonVariables.EnableEncryption,
                           CommonVariables.EnableEncryptionFormat,
                           CommonVariables.EnableEncryptionFormatAll]

    def _setting_changed(self, name, current, effective):
        """Log and return True when a (current, effective) pair differs.
        A pair where both sides are empty/None counts as unchanged."""
        if (current or effective) and (current != effective):
            self.logger.log('Current config {0} {1} differs from effective config {0} {2}'.format(name, current, effective))
            return True
        return False

    def config_changed(self):
        """Compare the freshly-parsed settings against the persisted ones and
        return True as soon as any meaningful difference is found."""
        current_command = self.command
        effective_command = self.get_command()
        # Even if the commands are not exactly the same, if they're both
        # encrypt commands, don't consider this a change.
        if ((current_command or effective_command) and
                current_command != effective_command and
                not (self._is_encrypt_command(current_command) and self._is_encrypt_command(effective_command))):
            self.logger.log('Current config command {0} differs from effective config command {1}'.format(current_command, effective_command))
            return True

        if self._setting_changed('KeyEncryptionKeyURL', self.KeyEncryptionKeyURL, self.get_kek_url()):
            return True
        if self._setting_changed('KeyVaultURL', self.KeyVaultURL, self.get_keyvault_url()):
            return True
        if self._setting_changed('AADClientID', self.AADClientID, self.get_aad_client_id()):
            return True

        # The secret is compared by digest; the emptiness check must use the
        # raw secret (a digest of '' is still a non-empty string).
        if self.AADClientSecret or self.get_aad_client_secret():
            secret_hash = hashlib.sha256(self.AADClientSecret.encode("utf-8")).hexdigest()
            if secret_hash != self.get_aad_client_secret():
                self.logger.log('Current config AADClientSecret {0} differs from effective config AADClientSecret {1}'.format(secret_hash, self.get_aad_client_secret()))
                return True

        if self._setting_changed('AADClientCertThumbprint', self.AADClientCertThumbprint, self.get_aad_client_cert()):
            return True
        if self._setting_changed('KeyEncryptionAlgorithm', self.KeyEncryptionAlgorithm, self.get_kek_algorithm()):
            return True

        bek_passphrase_file_name = self.bek_util.get_bek_passphrase_file(self.encryption_config)
        bek_passphrase = None
        if bek_passphrase_file_name is not None and os.path.exists(bek_passphrase_file_name):
            # BUGFIX: use a context manager instead of the bare file() call,
            # which leaked the open handle.
            with open(bek_passphrase_file_name) as passphrase_file:
                bek_passphrase = passphrase_file.read()

        if (self.passphrase and bek_passphrase) and (self.passphrase != bek_passphrase):
            # Never log passphrase values.
            self.logger.log('Current config passphrase differs from effective config passphrase')
            return True

        self.logger.log('Current config is not different from effective config')
        return False
self.logger.log(errorMsg) Config = waagent.ConfigurationProvider() self.proxyHost = Config.get("HttpProxy.Host") self.proxyPort = Config.get("HttpProxy.Port") self.connection = None """ snapshot also called this. so we should not write the file/read the file in this method. """ def Call(self, method, http_uri, data, headers): try: uri_obj = urlparse.urlparse(http_uri) #parse the uri str here if self.proxyHost is None or self.proxyPort is None: self.connection = httplib.HTTPSConnection(uri_obj.hostname, timeout = 10) if uri_obj.query is not None: self.connection.request(method = method, url=(uri_obj.path +'?'+ uri_obj.query), body = data, headers = headers) else: self.connection.request(method = method, url=(uri_obj.path), body = data, headers = headers) resp = self.connection.getresponse() else: self.logger.log("proxyHost is not empty, so use the proxy to call the http.") self.connection = httplib.HTTPSConnection(self.proxyHost, self.proxyPort, timeout = 10) if uri_obj.scheme.lower() == "https": self.connection.set_tunnel(uri_obj.hostname, 443) else: self.connection.set_tunnel(uri_obj.hostname, 80) self.connection.request(method = method, url = (http_uri), body = data, headers = headers) resp = self.connection.getresponse() return resp except Exception as e: errorMsg = "Failed to call http with error: {0}, stack trace: {1}".format(e, traceback.format_exc()) self.logger.log(errorMsg) return None ================================================ FILE: VMEncryption/main/KeyVaultUtil.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import httplib
import urllib
import json
import uuid
import base64
import traceback
import re
import os
import subprocess

from tempfile import mkstemp
from HttpUtil import HttpUtil
from urlparse import urlparse


class KeyVaultUtil(object):
    """Key Vault REST helpers used by disk encryption: acquire an AAD token,
    wrap the passphrase with a KEK, and store it as a vault secret."""

    def __init__(self, logger):
        self.api_version = "2015-06-01"
        self.logger = logger

    def urljoin(self, *args):
        """
        Joins given arguments into a url. Trailing but not leading slashes are
        stripped for each argument.
        """
        return "/".join(map(lambda x: str(x).rstrip('/'), args))

    def create_kek_secret(self, Passphrase, KeyVaultURL, KeyEncryptionKeyURL, AADClientID, AADClientCertThumbprint, KeyEncryptionAlgorithm, AADClientSecret, DiskEncryptionKeyFileName):
        """
        The Passphrase is a plain encoded string. before the encryption it would
        be base64encoding. return the secret uri if creation successfully.
        """
        try:
            self.logger.log("start creating kek secret")
            passphrase_encoded = base64.standard_b64encode(Passphrase)
            keys_uri = self.urljoin(KeyVaultURL, "keys")
            http_util = HttpUtil(self.logger)
            headers = {}
            # Unauthenticated GET: only issued to harvest the www-authenticate
            # challenge, which carries the AAD authority to authenticate against.
            result = http_util.Call(method='GET', http_uri=keys_uri, data=None, headers=headers)
            http_util.connection.close()

            self.logger.log("getting the access token.")
            bearerHeader = result.getheader("www-authenticate")
            authorize_uri = self.get_authorize_uri(bearerHeader)
            if authorize_uri is None:
                self.logger.log("the authorize uri is None")
                return None
            # Derive the Key Vault resource name (scheme + vault domain) from
            # the vault URL; assumes the netloc contains "vault" (e.g.
            # myvault.vault.azure.net) — raises IndexError otherwise.
            parsed_url = urlparse(KeyVaultURL)
            vault_domain = re.findall(r".*(vault.*)", parsed_url.netloc)[0]
            kv_resource_name = parsed_url.scheme + '://' + vault_domain
            access_token = self.get_access_token(kv_resource_name, authorize_uri, AADClientID, AADClientCertThumbprint, AADClientSecret)
            if access_token is None:
                self.logger.log("the access token is None")
                return None

            # we should skip encrypting the passphrase if the KeyVaultURL and
            # KeyEncryptionKeyURL is empty
            if KeyEncryptionKeyURL is None or KeyEncryptionKeyURL == "":
                secret_value = passphrase_encoded
            else:
                secret_value = self.encrypt_passphrase(access_token, passphrase_encoded, KeyVaultURL, KeyEncryptionKeyURL, AADClientID, KeyEncryptionAlgorithm, AADClientSecret)
            if secret_value is None:
                self.logger.log("secret value is None")
                return None
            secret_id = self.create_secret(access_token, KeyVaultURL, secret_value, KeyEncryptionAlgorithm, DiskEncryptionKeyFileName)
            return secret_id
        except Exception as e:
            self.logger.log("Failed to create_kek_secret with error: {0}, stack trace: {1}".format(e, traceback.format_exc()))
            raise

    def is_adal_available(self):
        """True when the adal module can be imported directly."""
        try:
            import adal
            self.logger.log('Python ADAL library is natively available on the system')
            return True
        except Exception:
            self.logger.log('Python ADAL library is not natively available on the system')
            return False

    def is_scl_adal_available(self):
        """True when adal is importable via the python27 software collection."""
        try:
            subprocess.check_call(['scl', 'enable', 'python27', "python -c 'import adal'"])
            self.logger.log('Python ADAL library is available on the system via SCL')
            return True
        except Exception:
            self.logger.log('Python ADAL library is not available on the system via SCL')
            return False

    def get_access_token_with_certificate(self, KeyVaultResourceName, AuthorizeUri, AADClientID, AADClientCertThumbprint):
        """Acquire an AAD token using the client certificate whose private key
        waagent stores as <THUMBPRINT>.prv in its lib directory."""
        # construct path to the private key file which is stored and managed
        # by waagent inside of the lib directory
        import waagent
        prv_path = os.path.join(waagent.LibDir, AADClientCertThumbprint.upper() + '.prv')
        if self.is_adal_available():
            import adal
            prv_data = waagent.GetFileContents(prv_path)
            context = adal.AuthenticationContext(AuthorizeUri)
            result_json = context.acquire_token_with_client_certificate(KeyVaultResourceName, AADClientID, prv_data, AADClientCertThumbprint)
            access_token = result_json["accessToken"]
            return access_token
        elif self.is_scl_adal_available():
            # On RHEL, support for python-pip and the adal library are made
            # available outside of default python via SCL; hand the request to
            # TokenUtil.py through a temp JSON file.
            tmp_data = {
                "auth": AuthorizeUri,
                "resource": KeyVaultResourceName,
                "client": AADClientID,
                "certificate": prv_path,
                "thumbprint": AADClientCertThumbprint}
            tmp_fd, tmp_path = mkstemp()
            with open(tmp_path, 'w') as tmp_file:
                json.dump(tmp_data, tmp_file)
            os.close(tmp_fd)
            tok_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'TokenUtil.py')
            scl_args = 'python ' + tok_path + ' ' + tmp_path
            access_token = subprocess.check_output(['scl', 'enable', 'python27', scl_args]).rstrip()
            if os.path.isfile(tmp_path):
                os.remove(tmp_path)
            return access_token
        else:
            raise Exception('Python ADAL library required for client certificate authentication was not found')

    def get_access_token(self, KeyVaultResourceName, AuthorizeUri, AADClientID, AADClientCertThumbprint, AADClientSecret):
        """Acquire an AAD access token, via certificate or client secret.
        Exactly one of AADClientSecret / AADClientCertThumbprint is required."""
        if not AADClientSecret and not AADClientCertThumbprint:
            raise ValueError("Missing Credentials. Either AADClientSecret or AADClientCertThumbprint must be specified")
        if AADClientSecret and AADClientCertThumbprint:
            raise ValueError("Both AADClientSecret and AADClientCertThumbprint were supplied, when only one of these was expected.")
        if AADClientCertThumbprint:
            return self.get_access_token_with_certificate(KeyVaultResourceName, AuthorizeUri, AADClientID, AADClientCertThumbprint)
        else:
            # retrieve access token directly, adal library not required
            token_uri = AuthorizeUri + "/oauth2/token"
            request_content = "resource=" + urllib.quote(KeyVaultResourceName) + "&client_id=" + AADClientID + "&client_secret=" + urllib.quote(AADClientSecret) + "&grant_type=client_credentials"
            headers = {}
            http_util = HttpUtil(self.logger)
            result = http_util.Call(method='POST', http_uri=token_uri, data=request_content, headers=headers)
            self.logger.log("{0} {1}".format(result.status, result.getheaders()))
            result_content = result.read()
            if result.status != httplib.OK and result.status != httplib.ACCEPTED:
                self.logger.log(str(result_content))
                return None
            http_util.connection.close()
            result_json = json.loads(result_content)
            access_token = result_json["access_token"]
            return access_token

    """
    return the encrypted secret uri if success. else return None
    """
    def encrypt_passphrase(self, AccessToken, Passphrase, KeyVaultURL, KeyEncryptionKeyURL, AADClientID, KeyEncryptionAlgorithm, AADClientSecret):
        try:
            # wrap our passphrase using the encryption key
            # api ref for wrapkey: https://msdn.microsoft.com/en-us/library/azure/dn878066.aspx
            self.logger.log("encrypting the secret using key: " + KeyEncryptionKeyURL)
            request_content = '{"alg":"' + str(KeyEncryptionAlgorithm) + '","value":"' + str(Passphrase) + '"}'
            headers = {}
            headers["Content-Type"] = "application/json"
            headers["Authorization"] = "Bearer " + str(AccessToken)
            relative_path = KeyEncryptionKeyURL + "/wrapkey" + '?api-version=' + self.api_version
            http_util = HttpUtil(self.logger)
            result = http_util.Call(method='POST', http_uri=relative_path, data=request_content, headers=headers)
            result_content = result.read()
            self.logger.log("result_content is: {0}".format(result_content))
            self.logger.log("{0} {1}".format(result.status, result.getheaders()))
            # BUGFIX: close the connection before the early return so the
            # error path no longer leaks it.
            http_util.connection.close()
            if result.status != httplib.OK and result.status != httplib.ACCEPTED:
                return None
            result_json = json.loads(result_content)
            secret_value = result_json[u'value']
            return secret_value
        except Exception as e:
            # BUGFIX: the old message mixed a literal %s with str.format, so
            # the stack trace was never rendered into the log line.
            self.logger.log("Failed to encrypt_passphrase with error: {0}, stack trace: {1}".format(e, traceback.format_exc()))
            return None

    def create_secret(self, AccessToken, KeyVaultURL, secret_value, KeyEncryptionAlgorithm, DiskEncryptionKeyFileName):
        """
        create secret api
        https://msdn.microsoft.com/en-us/library/azure/dn903618.aspx
        https://mykeyvault.vault.azure.net/secrets/{secret-name}?api-version={api-version}
        """
        try:
            secret_name = str(uuid.uuid4())
            secret_keyvault_uri = self.urljoin(KeyVaultURL, "secrets", secret_name)
            self.logger.log("secret_keyvault_uri is: {0} and keyvault_uri is:{1}".format(secret_keyvault_uri, KeyVaultURL))
            if KeyEncryptionAlgorithm is None:
                request_content = '{{"value":"{0}","attributes":{{"enabled":"true"}},"tags":{{"DiskEncryptionKeyFileName":"{1}"}}}}'\
                    .format(str(secret_value), DiskEncryptionKeyFileName)
            else:
                request_content = '{{"value":"{0}","attributes":{{"enabled":"true"}},"tags":{{"DiskEncryptionKeyEncryptionAlgorithm":"{1}","DiskEncryptionKeyFileName":"{2}"}}}}'\
                    .format(str(secret_value), KeyEncryptionAlgorithm, DiskEncryptionKeyFileName)
            http_util = HttpUtil(self.logger)
            headers = {}
            headers["Content-Type"] = "application/json"
            headers["Authorization"] = "Bearer " + AccessToken
            result = http_util.Call(method='PUT',
                                    http_uri=secret_keyvault_uri + '?api-version=' + self.api_version,
                                    data=request_content,
                                    headers=headers)
            self.logger.log("{0} {1}".format(result.status, result.getheaders()))
            result_content = result.read()
            http_util.connection.close()
            # BUGFIX: check the HTTP status before parsing the body; error
            # responses previously hit json/KeyError instead of this path.
            if result.status != httplib.OK and result.status != httplib.ACCEPTED:
                self.logger.log("the result status failed.")
                return None
            # Do NOT log the result_content. It contains the uploaded secret
            # and we don't want that in the logs.
            result_json = json.loads(result_content)
            secret_id = result_json["id"]
            return secret_id
        except Exception as e:
            self.logger.log("Failed to create_secret with error: {0}, stack trace: {1}".format(e, traceback.format_exc()))
            return None

    def get_authorize_uri(self, bearerHeader):
        """
        Bearer authorization="https://login.windows.net/72f988bf-86f1-41af-91ab-2d7cd011db47", resource="https://vault.azure.net"
        """
        try:
            self.logger.log("trying to get the authorize uri from: " + str(bearerHeader))
            bearerString = str(bearerHeader)
            authorization_key = 'authorization="'
            authorization_index = bearerString.index(authorization_key)
            bearerString = bearerString[(authorization_index + len(authorization_key)):]
            bearerString = bearerString[0:bearerString.index('"')]
            return bearerString
        except Exception as e:
            self.logger.log("Failed to get_authorize_uri with error: {0}, stack trace: {1}".format(e, traceback.format_exc()))
            return None
================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import subprocess import xml import xml.dom.minidom class MachineIdentity: def __init__(self): self.store_identity_file = './machine_identity_FD76C85E-406F-4CFA-8EB0-CF18B123365C' def current_identity(self): with open("/var/lib/waagent/HostingEnvironmentConfig.xml",'r') as file: xmlText = file.read() dom = xml.dom.minidom.parseString(xmlText) deployment = dom.getElementsByTagName("Role") identity = deployment[0].getAttribute("guid") return identity def save_identity(self): with open(self.store_identity_file,'w') as file: machine_identity = self.current_identity() file.write(machine_identity) def stored_identity(self): identity_stored = None if os.path.exists(self.store_identity_file): with open(self.store_identity_file,'r') as file: identity_stored = file.read() return identity_stored ================================================ FILE: VMEncryption/main/OnGoingItemConfig.py ================================================ #!/usr/bin/env python # # VMEncryption extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
#!/usr/bin/env python
#
# VMEncryption extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path
import uuid
import time
import datetime
# BUGFIX: traceback is used by clear_config() but was never imported,
# turning any archive failure into a NameError.
import traceback

from Common import CommonVariables
from ConfigParser import ConfigParser
from ConfigUtil import ConfigUtil
from ConfigUtil import ConfigKeyValuePair


class OnGoingItemConfig(object):
    """Persists the progress state of the disk item currently being encrypted
    so an interrupted copy/encrypt operation can be resumed."""

    def __init__(self, encryption_environment, logger):
        self.encryption_environment = encryption_environment
        self.logger = logger
        # In-memory mirror of the on-disk config; populated by
        # load_value_from_file() and flushed by commit().
        self.original_dev_name_path = None
        self.original_dev_path = None
        self.mapper_name = None
        self.luks_header_file_path = None
        self.phase = None
        self.file_system = None
        self.mount_point = None
        self.device_size = None
        self.from_end = None
        self.header_slice_file_path = None
        self.current_block_size = None
        self.current_source_path = None
        self.current_total_copy_size = None
        self.current_slice_index = None
        self.current_destination = None
        self.ongoing_item_config = ConfigUtil(encryption_environment.azure_crypt_ongoing_item_config_path,
                                              'azure_crypt_ongoing_item_config',
                                              logger)

    def config_file_exists(self):
        return self.ongoing_item_config.config_file_exists()

    def _get_long_config(self, key):
        """Return the config value for key as a long, or None when absent/empty."""
        value = self.ongoing_item_config.get_config(key)
        if value is None or value == "":
            return None
        return long(value)

    def get_original_dev_name_path(self):
        return self.ongoing_item_config.get_config(CommonVariables.OngoingItemOriginalDevNamePathKey)

    def get_original_dev_path(self):
        return self.ongoing_item_config.get_config(CommonVariables.OngoingItemOriginalDevPathKey)

    def get_mapper_name(self):
        return self.ongoing_item_config.get_config(CommonVariables.OngoingItemMapperNameKey)

    def get_header_file_path(self):
        return self.ongoing_item_config.get_config(CommonVariables.OngoingItemHeaderFilePathKey)

    def get_phase(self):
        return self.ongoing_item_config.get_config(CommonVariables.OngoingItemPhaseKey)

    def get_header_slice_file_path(self):
        return self.ongoing_item_config.get_config(CommonVariables.OngoingItemHeaderSliceFilePathKey)

    def get_file_system(self):
        return self.ongoing_item_config.get_config(CommonVariables.OngoingItemFileSystemKey)

    def get_mount_point(self):
        return self.ongoing_item_config.get_config(CommonVariables.OngoingItemMountPointKey)

    def get_device_size(self):
        return self._get_long_config(CommonVariables.OngoingItemDeviceSizeKey)

    def get_current_slice_index(self):
        return self._get_long_config(CommonVariables.OngoingItemCurrentSliceIndexKey)

    def get_from_end(self):
        return self.ongoing_item_config.get_config(CommonVariables.OngoingItemFromEndKey)

    def get_current_block_size(self):
        return self._get_long_config(CommonVariables.OngoingItemCurrentBlockSizeKey)

    def get_current_source_path(self):
        return self.ongoing_item_config.get_config(CommonVariables.OngoingItemCurrentSourcePathKey)

    def get_current_destination(self):
        return self.ongoing_item_config.get_config(CommonVariables.OngoingItemCurrentDestinationKey)

    def get_current_total_copy_size(self):
        return self._get_long_config(CommonVariables.OngoingItemCurrentTotalCopySizeKey)

    def get_luks_header_file_path(self):
        return self.ongoing_item_config.get_config(CommonVariables.OngoingItemCurrentLuksHeaderFilePathKey)

    def load_value_from_file(self):
        """Populate the in-memory fields from the on-disk config file."""
        self.original_dev_name_path = self.get_original_dev_name_path()
        self.original_dev_path = self.get_original_dev_path()
        self.mapper_name = self.get_mapper_name()
        self.luks_header_file_path = self.get_luks_header_file_path()
        self.phase = self.get_phase()
        self.file_system = self.get_file_system()
        self.mount_point = self.get_mount_point()
        self.device_size = self.get_device_size()
        self.from_end = self.get_from_end()
        self.header_slice_file_path = self.get_header_slice_file_path()
        self.current_block_size = self.get_current_block_size()
        self.current_source_path = self.get_current_source_path()
        self.current_total_copy_size = self.get_current_total_copy_size()
        self.current_slice_index = self.get_current_slice_index()
        self.current_destination = self.get_current_destination()

    def commit(self):
        """Flush all in-memory fields to the on-disk config file."""
        key_value_pairs = [
            ConfigKeyValuePair(CommonVariables.OngoingItemOriginalDevNamePathKey, self.original_dev_name_path),
            ConfigKeyValuePair(CommonVariables.OngoingItemOriginalDevPathKey, self.original_dev_path),
            ConfigKeyValuePair(CommonVariables.OngoingItemMapperNameKey, self.mapper_name),
            ConfigKeyValuePair(CommonVariables.OngoingItemHeaderFilePathKey, self.luks_header_file_path),
            ConfigKeyValuePair(CommonVariables.OngoingItemPhaseKey, self.phase),
            ConfigKeyValuePair(CommonVariables.OngoingItemHeaderSliceFilePathKey, self.header_slice_file_path),
            ConfigKeyValuePair(CommonVariables.OngoingItemFileSystemKey, self.file_system),
            ConfigKeyValuePair(CommonVariables.OngoingItemMountPointKey, self.mount_point),
            ConfigKeyValuePair(CommonVariables.OngoingItemDeviceSizeKey, self.device_size),
            ConfigKeyValuePair(CommonVariables.OngoingItemCurrentSliceIndexKey, self.current_slice_index),
            ConfigKeyValuePair(CommonVariables.OngoingItemFromEndKey, self.from_end),
            ConfigKeyValuePair(CommonVariables.OngoingItemCurrentSourcePathKey, self.current_source_path),
            ConfigKeyValuePair(CommonVariables.OngoingItemCurrentDestinationKey, self.current_destination),
            ConfigKeyValuePair(CommonVariables.OngoingItemCurrentTotalCopySizeKey, self.current_total_copy_size),
            ConfigKeyValuePair(CommonVariables.OngoingItemCurrentBlockSizeKey, self.current_block_size),
        ]
        self.ongoing_item_config.save_configs(key_value_pairs)

    def clear_config(self):
        """Archive (rename with a timestamp suffix) the on-disk config file.
        Return True on success, False when the rename fails."""
        try:
            config_path = self.encryption_environment.azure_crypt_ongoing_item_config_path
            if os.path.exists(config_path):
                self.logger.log(msg="archive the config file: {0}".format(config_path))
                time_stamp = datetime.datetime.now()
                new_name = "{0}_{1}".format(config_path, time_stamp)
                os.rename(config_path, new_name)
            else:
                self.logger.log(msg=("the config file not exist: {0}".format(config_path)),
                                level=CommonVariables.WarningLevel)
            return True
        except OSError as e:
            self.logger.log("Failed to archive_backup_config with error: {0}, stack trace: {1}".format(e, traceback.format_exc()))
            return False

    def __str__(self):
        return "dev_uuid_path is {0}, mapper_name is {1}, luks_header_file_path is {2}, phase is {3}, header_slice_file_path is {4}, file system is {5}, mount_point is {6}, device size is {7}"\
            .format(self.original_dev_path, self.mapper_name, self.luks_header_file_path,
                    self.phase, self.header_slice_file_path, self.file_system,
                    self.mount_point, self.device_size)
traceback.format_exc())) return False def __str__(self): return "dev_uuid_path is {0}, mapper_name is {1}, luks_header_file_path is {2}, phase is {3}, header_slice_file_path is {4}, file system is {5}, mount_point is {6}, device size is {7}"\ .format(self.original_dev_path, self.mapper_name, self.luks_header_file_path, self.phase, self.header_slice_file_path, self.file_system, self.mount_point, self.device_size) ================================================ FILE: VMEncryption/main/ProcessLock.py ================================================ #!/usr/bin/env python # # VMEncryption extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os.path import fcntl from Common import CommonVariables class ProcessLock(object): def __init__(self, logger, lock_file_path): self.logger = logger self.lock_file_path = lock_file_path self.fd = None def try_lock(self): try: self.fd = open(self.lock_file_path, "w") fcntl.flock(self.fd, fcntl.LOCK_EX) return True except Exception as e: self.logger.log("could not acquire a lock, error: {0}".format(str(e))) return False def release_lock(self): fcntl.flock(self.fd, fcntl.LOCK_UN) self.fd.close() ================================================ FILE: VMEncryption/main/ResourceDiskUtil.py ================================================ #!/usr/bin/env python # # ********************************************************* # Copyright (c) Microsoft. All rights reserved. 
#
# Apache 2.0 License
#
# You may obtain a copy of the License at
# http:#www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
#
# *********************************************************
""" Functionality to encrypt the Azure resource disk"""

import time
import os

from CommandExecutor import CommandExecutor
from Common import CommonVariables, CryptItem


class ResourceDiskUtil(object):
    """ Resource Disk Encryption Utilities

    Encrypts, formats, and mounts the ephemeral Azure resource disk by
    shelling out to system tools (cryptsetup, parted, dd, sed, mount)
    through CommandExecutor and the injected disk_util helper.
    """

    # Standard mount point for the resource disk file system.
    RD_MOUNT_POINT = '/mnt/resource'
    # Whole-disk udev symlink and its first-partition symlink.
    RD_BASE_DEV_PATH = os.path.join(CommonVariables.azure_symlinks_dir, 'resource')
    RD_DEV_PATH = os.path.join(CommonVariables.azure_symlinks_dir, 'resource-part1')
    # Prefix used to recognize device-mapper block devices in mount tables.
    DEV_DM_PREFIX = '/dev/dm-'
    # todo: consolidate this and other key file path references
    # (BekUtil.py, ExtensionParameter.py, and dracut patches)
    RD_MAPPER_NAME = 'resourceencrypt'
    RD_MAPPER_PATH = os.path.join(CommonVariables.dev_mapper_root, RD_MAPPER_NAME)

    def __init__(self, logger, disk_util, passphrase_filename, public_settings, distro_info):
        """Capture collaborators; no system state is touched here.

        distro_info is expected to be a (name, version, ...) sequence —
        see add_to_fstab's Ubuntu 14 check.
        """
        self.logger = logger
        self.executor = CommandExecutor(self.logger)
        self.disk_util = disk_util
        # WARNING: This may be null, in which case we mount the resource disk
        # if its unencrypted and do nothing if it is.
        self.passphrase_filename = passphrase_filename
        self.public_settings = public_settings
        self.distro_info = distro_info

    def _is_encrypt_format_all(self):
        """ return true if current encryption operation is EncryptFormatAll """
        encryption_operation = self.public_settings.get(CommonVariables.EncryptionEncryptionOperationKey)
        if encryption_operation in [CommonVariables.EnableEncryptionFormatAll]:
            return True
        self.logger.log("Current encryption operation is not EnableEncryptionFormatAll")
        return False

    def _is_luks_device(self):
        """ checks if the device is set up with a luks header """
        if not self._resource_disk_partition_exists():
            return False
        # cryptsetup isLuks exits 0 only when a LUKS header is present.
        cmd = 'cryptsetup isLuks ' + self.RD_DEV_PATH
        return (int)(self.executor.Execute(cmd, suppress_logging=True)) == CommonVariables.process_success

    def _resource_disk_partition_exists(self):
        """ true if udev name for resource disk partition exists """
        # 'test -b' checks that the path exists and is a block device.
        cmd = 'test -b ' + self.RD_DEV_PATH
        return (int)(self.executor.Execute(cmd, suppress_logging=True)) == CommonVariables.process_success

    def _encrypt(self):
        """ use disk util with the appropriate device mapper """
        # header_file=None: LUKS header lives on the partition itself.
        return (int)(self.disk_util.encrypt_disk(dev_path=self.RD_DEV_PATH,
                                                 passphrase_file=self.passphrase_filename,
                                                 mapper_name=self.RD_MAPPER_NAME,
                                                 header_file=None)) == CommonVariables.process_success

    def _format_encrypted_partition(self):
        """ make a default file system on top of the crypt layer """
        make_result = self.disk_util.format_disk(dev_path=self.RD_MAPPER_PATH,
                                                 file_system=CommonVariables.default_file_system)
        if make_result != CommonVariables.process_success:
            self.logger.log(msg="Failed to make file system on ephemeral disk",
                            level=CommonVariables.ErrorLevel)
            return False
        # todo - drop DATALOSS_WARNING_README.txt file to disk
        return True

    def _mount_resource_disk(self, dev_path):
        """ mount the file system previously made on top of the crypt layer """
        # ensure that resource disk mount point directory has been created
        cmd = 'mkdir -p ' + self.RD_MOUNT_POINT
        if self.executor.Execute(cmd, suppress_logging=True) != CommonVariables.process_success:
            self.logger.log(msg='Failed to precreate mount point directory: ' + cmd,
                            level=CommonVariables.ErrorLevel)
            return False
        # mount to mount point directory
        mount_result = self.disk_util.mount_filesystem(dev_path=dev_path,
                                                       mount_point=self.RD_MOUNT_POINT)
        if mount_result != CommonVariables.process_success:
            self.logger.log(msg="Failed to mount file system on resource disk",
                            level=CommonVariables.ErrorLevel)
            return False
        return True

    def _configure_waagent(self):
        """ turn off waagent.conf resource disk management """
        # set ResourceDisk.MountPoint to standard mount point
        # (sed -i.rdbak1/-i.rdbak2 keep backup copies of waagent.conf)
        cmd = "sed -i.rdbak1 's|ResourceDisk.MountPoint=.*|ResourceDisk.MountPoint=" + self.RD_MOUNT_POINT + "|' /etc/waagent.conf"
        if self.executor.ExecuteInBash(cmd) != CommonVariables.process_success:
            self.logger.log(msg="Failed to change ResourceDisk.MountPoint in /etc/waagent.conf",
                            level=CommonVariables.WarningLevel)
            return False
        # set ResourceDiskFormat=n to ensure waagent does not attempt a simultaneous format
        cmd = "sed -i.rdbak2 's|ResourceDisk.Format=y|ResourceDisk.Format=n|' /etc/waagent.conf"
        if self.executor.ExecuteInBash(cmd) != CommonVariables.process_success:
            self.logger.log(msg="Failed to set ResourceDiskFormat in /etc/waagent.conf",
                            level=CommonVariables.WarningLevel)
            return False
        # todo: restart waagent if necessary to ensure changes are picked up?
        return True

    def _configure_fstab(self):
        """ remove resource disk from /etc/fstab if present """
        cmd = "sed -i.bak '/azure_resource-part1/d' /etc/fstab"
        if self.executor.ExecuteInBash(cmd) != CommonVariables.process_success:
            self.logger.log(msg="Failed to configure resource disk entry of /etc/fstab",
                            level=CommonVariables.WarningLevel)
            return False
        return True

    def _unmount_resource_disk(self):
        """ unmount resource disk """
        # Unmount everything that may hold the resource disk busy, then
        # re-establish the BEK key volume mount that was torn down with it.
        self.disk_util.umount(self.RD_MOUNT_POINT)
        self.disk_util.umount(CommonVariables.encryption_key_mount_point)
        self.disk_util.umount('/mnt')
        self.disk_util.make_sure_path_exists(CommonVariables.encryption_key_mount_point)
        self.disk_util.mount_bek_volume("BEK VOLUME", CommonVariables.encryption_key_mount_point, "fmask=077")

    def _is_plain_mounted(self):
        """ return true if mount point is mounted from a non-crypt layer """
        mount_items = self.disk_util.get_mount_items()
        for mount_item in mount_items:
            if mount_item["dest"] == self.RD_MOUNT_POINT and \
               not (mount_item["src"].startswith(CommonVariables.dev_mapper_root) or
                    mount_item["src"].startswith(self.DEV_DM_PREFIX)):
                return True
        return False

    def _is_crypt_mounted(self):
        """ return true if mount point is already on a crypt layer """
        mount_items = self.disk_util.get_mount_items()
        for mount_item in mount_items:
            if mount_item["dest"] == self.RD_MOUNT_POINT and \
               (mount_item["src"].startswith(CommonVariables.dev_mapper_root) or
                mount_item["src"].startswith(self.DEV_DM_PREFIX)):
                return True
        return False

    def _get_rd_device_mappers(self):
        """ Retreive any device mapper device on the resource disk (e.g. /dev/dm-0).
        Can't imagine why there would be multiple device mappers here, but doesn't hurt to handle the case """
        device_items = self.disk_util.get_device_items(self.RD_DEV_PATH)
        device_mappers = []
        mapper_device_types = ["raid0", "raid1", "raid5", "raid10", "lvm", "crypt"]
        for device_item in device_items:
            # fstype should be crypto_LUKS
            dev_path = self.disk_util.get_device_path(device_item.name)
            if device_item.type in mapper_device_types:
                device_mappers.append(device_item)
                self.logger.log('Found device mapper: ' + dev_path, level='Info')
        return device_mappers

    def _remove_device_mappers(self):
        """
        Use dmsetup to remove the resource disk device mapper if it exists.
        This is to allow us to make sure that the resource disk is not being used by anything and we can safely luksFormat it.
        """
        # There could be a dependency between the mapper devices themselves.
        something_closed = True
        while something_closed is True:
            # The mappers might be dependant on each other, like a crypt on an LVM.
            # Instead of trying to figure out the dependency tree we will try to close anything we can
            # and if anything does get closed we will refresh the list of devices and try to close everything again.
            # In effect we repeat until we either close everything or we reach a point where we can't close anything.
            dm_items = self._get_rd_device_mappers()
            something_closed = False
            if len(dm_items) == 0:
                self.logger.log('no resource disk device mapper found')
            for dm_item in dm_items:
                # try luksClose
                cmd = 'cryptsetup luksClose ' + dm_item.name
                if self.executor.Execute(cmd) == CommonVariables.process_success:
                    self.logger.log('Successfully closed cryptlayer: ' + dm_item.name)
                    something_closed = True
                else:
                    # try a dmsetup remove, in case its non-crypt device mapper (lvm, raid, something we don't know)
                    cmd = 'dmsetup remove ' + self.disk_util.get_device_path(dm_item.name)
                    if self.executor.Execute(cmd) == CommonVariables.process_success:
                        something_closed = True
                    else:
                        self.logger.log('failed to remove ' + dm_item.name)

    def _prepare_partition(self):
        """ create partition on resource disk if missing """
        if self._resource_disk_partition_exists():
            return True
        self.logger.log("resource disk partition does not exist", level='Info')
        cmd = 'parted ' + self.RD_BASE_DEV_PATH + ' mkpart primary ext4 0% 100%'
        if self.executor.ExecuteInBash(cmd) == CommonVariables.process_success:
            # wait for the corresponding udev name to become available
            # (back-off: sleeps 0,1,2,...9 seconds between checks)
            for i in range(0, 10):
                time.sleep(i)
                if self._resource_disk_partition_exists():
                    return True
        self.logger.log('unable to make resource disk partition')
        return False

    def _wipe_partition_header(self):
        """ clear any possible header (luke or filesystem) by overwriting with 10MB of entropy """
        if not self._resource_disk_partition_exists():
            self.logger.log("resource partition does not exist, no header to clear")
            return True
        # 512 bytes * 20480 blocks = 10 MiB overwritten at the partition start.
        cmd = 'dd if=/dev/urandom of=' + self.RD_DEV_PATH + ' bs=512 count=20480'
        return self.executor.Execute(cmd) == CommonVariables.process_success

    def try_remount(self):
        """
        Mount the resource disk if not already mounted
        Returns true if the resource disk is mounted, false otherwise
        Throws an exception if anything goes wrong
        """
        self.logger.log("In try_remount")
        if self.passphrase_filename:
            self.logger.log("passphrase_filename(value={0}) is not null, so trying to mount encrypted Resource Disk".format(self.passphrase_filename))
            if self._is_crypt_mounted():
                self.logger.log("Resource disk already encrypted and mounted")
                # Add resource disk to crypttab if crypt mount is used
                # Scenario: RD is already crypt mounted and crypt mount to crypttab migration is initiated
                if not self.disk_util.should_use_azure_crypt_mount():
                    self.add_resource_disk_to_crypttab()
                return True
            if self._resource_disk_partition_exists() and self._is_luks_device():
                self.disk_util.luks_open(passphrase_file=self.passphrase_filename,
                                         dev_path=self.RD_DEV_PATH,
                                         mapper_name=self.RD_MAPPER_NAME,
                                         header_file=None,
                                         uses_cleartext_key=False)
                self.logger.log("Trying to mount resource disk.")
                mount_retval = self._mount_resource_disk(self.RD_MAPPER_PATH)
                if mount_retval:
                    # We successfully mounted the RD but
                    # the RD was not auto-mounted, so trying to enable auto-unlock for RD
                    self.add_resource_disk_to_crypttab()
                return mount_retval
        else:
            self.logger.log("passphrase_filename(value={0}) is null, so trying to mount plain Resource Disk".format(self.passphrase_filename))
            if self._is_plain_mounted():
                # NOTE(review): this log message says "encrypted" but this is
                # the plain (unencrypted) mount path — message looks wrong.
                self.logger.log("Resource disk already encrypted and mounted")
                return True
            return self._mount_resource_disk(self.RD_DEV_PATH)
        # conditions required to re-mount were not met
        return False

    def prepare(self):
        """ prepare a non-encrypted resource disk to be encrypted """
        self._configure_waagent()
        self._configure_fstab()
        if self._resource_disk_partition_exists():
            # Tear down all consumers of the existing partition before wiping it.
            self.disk_util.swapoff()
            self._unmount_resource_disk()
            self._remove_device_mappers()
            self._wipe_partition_header()
        self._prepare_partition()
        return True

    def add_to_fstab(self):
        """Ensure BEK volume and encrypted resource disk entries exist in /etc/fstab."""
        with open("/etc/fstab") as f:
            lines = f.readlines()
        if not self.disk_util.is_bek_in_fstab_file(lines):
            lines.append(self.disk_util.get_fstab_bek_line())
            self.disk_util.add_bek_to_default_cryptdisks()
        if not any([line.startswith(self.RD_MAPPER_PATH) for line in lines]):
            # Ubuntu 14.x does not support 'nofail'; it uses 'nobootwait' instead.
            if self.distro_info[0].lower() == 'ubuntu' and self.distro_info[1].startswith('14'):
                lines.append('{0} {1} auto defaults,discard,nobootwait 0 0\n'.format(self.RD_MAPPER_PATH, self.RD_MOUNT_POINT))
            else:
                lines.append('{0} {1} auto defaults,discard,nofail 0 0\n'.format(self.RD_MAPPER_PATH, self.RD_MOUNT_POINT))
        with open('/etc/fstab', 'w') as f:
            f.writelines(lines)

    def encrypt_format_mount(self):
        """Full EncryptFormatAll path: prepare, encrypt, format, mount, persist."""
        if not self.prepare():
            self.logger.log("Failed to prepare VM for Resource Disk Encryption", CommonVariables.ErrorLevel)
            return False
        if not self._encrypt():
            self.logger.log("Failed to encrypt Resource Disk Encryption", CommonVariables.ErrorLevel)
            return False
        if not self._format_encrypted_partition():
            self.logger.log("Failed to format the encrypted Resource Disk Encryption", CommonVariables.ErrorLevel)
            return False
        if not self._mount_resource_disk(self.RD_MAPPER_PATH):
            self.logger.log("Failed to mount after formatting and encrypting the Resource Disk Encryption", CommonVariables.ErrorLevel)
            return False
        # We haven't failed so far, lets just add the RD to crypttab
        self.add_resource_disk_to_crypttab()
        return True

    def add_resource_disk_to_crypttab(self):
        """Register the resource disk crypt mapping for auto-unlock at boot."""
        self.logger.log("Adding resource disk to the crypttab file")
        crypt_item = CryptItem()
        crypt_item.dev_path = self.RD_DEV_PATH
        crypt_item.mapper_name = self.RD_MAPPER_NAME
        crypt_item.uses_cleartext_key = False
        self.disk_util.remove_crypt_item(crypt_item)  # Remove old item in case it was already there
        self.disk_util.add_crypt_item_to_crypttab(crypt_item, self.passphrase_filename)
        self.add_to_fstab()

    def automount(self):
        """
        Mount the resource disk (encrypted or not) or encrypt the resource disk and mount it if enable was called with EFA
        If False is returned, the resource disk is not mounted.
        """
        # try to remount if the disk was previously encrypted and is still valid
        if self.try_remount():
            return True
        # unencrypted or unusable
        elif self._is_encrypt_format_all():
            return self.encrypt_format_mount()
        else:
            self.logger.log('EncryptionFormatAll not in use, resource disk will not be automatically formatted and encrypted.')
            return self._is_crypt_mounted() or self._is_plain_mounted()
class TransactionalCopyTask(object):
    """
    copy_total_size is in byte, skip_target_size is also in byte
    slice_size is in byte 50M

    Resumable, slice-by-slice dd copy between two block devices. Each slice
    is staged through a tmpfs file and a backup file so an interrupted copy
    can be resumed from persisted state in ongoing_item_config.

    NOTE(review): this class relies on Python 2 integer division ('/' on ints
    truncates); under Python 3 these expressions would produce floats.
    """

    def __init__(self, logger, hutil, disk_util, ongoing_item_config, patching, encryption_environment, status_prefix=''):
        """
        copy_total_size is in bytes.
        """
        self.command_executer = CommandExecutor(logger)
        # Persisted per-item state: total size, block size, paths, and the
        # slice index reached so far — all survive a process restart.
        self.ongoing_item_config = ongoing_item_config
        self.total_size = self.ongoing_item_config.get_current_total_copy_size()
        self.block_size = self.ongoing_item_config.get_current_block_size()
        self.source_dev_full_path = self.ongoing_item_config.get_current_source_path()
        self.destination = self.ongoing_item_config.get_current_destination()
        self.current_slice_index = self.ongoing_item_config.get_current_slice_index()
        # String flag ('true'/'false'): copy slices from the end of the device
        # backwards instead of front-to-back.
        self.from_end = self.ongoing_item_config.get_from_end()
        # Size of the final partial slice (may be 0 when total is an exact multiple).
        self.last_slice_size = self.total_size % self.block_size
        # we add 1 even the last_slice_size is zero.
        self.total_slice_size = ((self.total_size - self.last_slice_size) / self.block_size) + 1
        self.status_prefix = status_prefix
        self.encryption_environment = encryption_environment
        self.logger = logger
        self.patching = patching
        self.disk_util = disk_util
        self.hutil = hutil
        # tmpfs staging area: slices are cached in RAM between the two dd passes.
        self.tmpfs_mount_point = "/mnt/azure_encrypt_tmpfs"
        self.slice_file_path = self.tmpfs_mount_point + "/slice_file"
        self.copy_command = self.patching.dd_path

    def resume_copy_internal(self, copy_slice_item_backup_file_size, skip_block, original_total_copy_size):
        """Finish a half-copied slice using the backup file left by copy_internal.

        copy_slice_item_backup_file_size: bytes already captured in the backup file.
        skip_block: slice index (in block_size units) into the source device.
        original_total_copy_size: full byte size the interrupted slice should have.
        Returns a process return code (CommonVariables.process_success on success).
        """
        # All resume math is done in 512-byte sectors.
        block_size_of_slice_item_backup = 512
        # copy the left slice
        if copy_slice_item_backup_file_size <= original_total_copy_size:
            skip_of_slice_item_backup_file = copy_slice_item_backup_file_size / block_size_of_slice_item_backup
            left_count = ((original_total_copy_size - copy_slice_item_backup_file_size) / block_size_of_slice_item_backup)
            total_count = original_total_copy_size / block_size_of_slice_item_backup
            original_device_skip_count = (self.block_size * skip_block) / block_size_of_slice_item_backup
            if left_count != 0:
                # First, top up the backup file with the remainder of the slice
                # read from the source device.
                dd_cmd = str(self.copy_command) \
                         + ' if=' + self.source_dev_full_path \
                         + ' of=' + self.encryption_environment.copy_slice_item_backup_file \
                         + ' bs=' + str(block_size_of_slice_item_backup) \
                         + ' skip=' + str(original_device_skip_count + skip_of_slice_item_backup_file) \
                         + ' seek=' + str(skip_of_slice_item_backup_file) \
                         + ' count=' + str(left_count)
                return_code = self.command_executer.Execute(dd_cmd)
                if return_code != CommonVariables.process_success:
                    return return_code
            # Then write the completed backup slice out to the destination.
            dd_cmd = str(self.copy_command) \
                     + ' if=' + self.encryption_environment.copy_slice_item_backup_file \
                     + ' of=' + self.destination \
                     + ' bs=' + str(block_size_of_slice_item_backup) \
                     + ' seek=' + str(original_device_skip_count) \
                     + ' count=' + str(total_count)
            return_code = self.command_executer.Execute(dd_cmd)
            if return_code != CommonVariables.process_success:
                return return_code
            else:
                # Slice completed: advance and persist the index, then drop the backup.
                self.current_slice_index += 1
                self.ongoing_item_config.current_slice_index = self.current_slice_index
                self.ongoing_item_config.commit()
                if os.path.exists(self.encryption_environment.copy_slice_item_backup_file):
                    os.remove(self.encryption_environment.copy_slice_item_backup_file)
            return return_code
        else:
            self.logger.log(msg="copy_slice_item_backup_file_size is bigger than original_total_copy_size",
                            level=CommonVariables.ErrorLevel)
            return CommonVariables.backup_slice_file_error

    def resume_copy(self):
        """Resume an interrupted slice copy, if a backup file exists.

        Determines which physical slice the persisted index refers to
        (accounting for from_end ordering) and delegates to
        resume_copy_internal. No-op success if there is nothing to resume.
        """
        if self.from_end.lower() == 'true':
            # When copying from the end, index 0 maps to the last physical slice.
            skip_block = (self.total_slice_size - self.current_slice_index - 1)
        else:
            skip_block = self.current_slice_index
        return_code = CommonVariables.process_success
        if self.current_slice_index == 0:
            if self.last_slice_size > 0:
                if os.path.exists(self.encryption_environment.copy_slice_item_backup_file):
                    copy_slice_item_backup_file_size = os.path.getsize(self.encryption_environment.copy_slice_item_backup_file)
                    return_code = self.resume_copy_internal(copy_slice_item_backup_file_size=copy_slice_item_backup_file_size,
                                                            skip_block=skip_block,
                                                            original_total_copy_size=self.last_slice_size)
                else:
                    self.logger.log(msg="1. the slice item backup file not exists.", level=CommonVariables.WarningLevel)
            else:
                self.logger.log(msg="the last slice", level=CommonVariables.WarningLevel)
        else:
            if os.path.exists(self.encryption_environment.copy_slice_item_backup_file):
                copy_slice_item_backup_file_size = os.path.getsize(self.encryption_environment.copy_slice_item_backup_file)
                return_code = self.resume_copy_internal(copy_slice_item_backup_file_size,
                                                        skip_block=skip_block,
                                                        original_total_copy_size=self.block_size)
            else:
                self.logger.log(msg="2. unfortunately the slice item backup file not exists.", level=CommonVariables.WarningLevel)
        return return_code

    def copy_last_slice(self, skip_block):
        """Copy the final partial slice (last_slice_size bytes) in 512-byte sectors."""
        block_size_of_last_slice = 512
        skip_of_last_slice = (skip_block * self.block_size) / block_size_of_last_slice
        count_of_last_slice = self.last_slice_size / block_size_of_last_slice
        copy_result = self.copy_internal(from_device=self.source_dev_full_path,
                                         to_device=self.destination,
                                         skip=skip_of_last_slice,
                                         seek=skip_of_last_slice,
                                         block_size=block_size_of_last_slice,
                                         count=count_of_last_slice)
        return copy_result

    def begin_copy(self):
        """
        check the device_item size first, cut it

        Main entry: resume any half-done slice, then loop over remaining
        slices (from the end or from the front), persisting progress and
        reporting percent-complete status after each slice.
        """
        self.resume_copy()
        if self.from_end.lower() == 'true':
            while self.current_slice_index < self.total_slice_size:
                skip_block = (self.total_slice_size - self.current_slice_index - 1)
                if self.current_slice_index == 0:
                    # Index 0 in from_end order is the (possibly partial) last slice.
                    if self.last_slice_size > 0:
                        copy_result = self.copy_last_slice(skip_block)
                        if copy_result != CommonVariables.process_success:
                            return copy_result
                    else:
                        self.logger.log(msg="the last slice size is zero, so skip the 0 index.")
                else:
                    copy_result = self.copy_internal(from_device=self.source_dev_full_path,
                                                     to_device=self.destination,
                                                     skip=skip_block,
                                                     seek=skip_block,
                                                     block_size=self.block_size)
                    if copy_result != CommonVariables.process_success:
                        return copy_result
                self.current_slice_index += 1
                if self.status_prefix:
                    msg = self.status_prefix + ': ' \
                          + str(int(self.current_slice_index / (float)(self.total_slice_size) * 100.0)) \
                          + '%'
                    self.hutil.do_status_report(operation='DataCopy',
                                                status=CommonVariables.extension_success_status,
                                                status_code=str(CommonVariables.success),
                                                message=msg)
                # Persist progress after every slice so a crash resumes here.
                self.ongoing_item_config.current_slice_index = self.current_slice_index
                self.ongoing_item_config.commit()
            return CommonVariables.process_success
        else:
            while self.current_slice_index < self.total_slice_size:
                skip_block = self.current_slice_index
                if self.current_slice_index == (self.total_slice_size - 1):
                    # Front-to-back order: the last index is the partial slice.
                    if self.last_slice_size > 0:
                        copy_result = self.copy_last_slice(skip_block)
                        if copy_result != CommonVariables.process_success:
                            return copy_result
                    else:
                        self.logger.log(msg="the last slice size is zero, so skip the last slice index.")
                else:
                    copy_result = self.copy_internal(from_device=self.source_dev_full_path,
                                                     to_device=self.destination,
                                                     skip=skip_block,
                                                     seek=skip_block,
                                                     block_size=self.block_size)
                    if copy_result != CommonVariables.process_success:
                        return copy_result
                self.current_slice_index += 1
                if self.status_prefix:
                    msg = self.status_prefix + ': ' \
                          + str(int(self.current_slice_index / (float)(self.total_slice_size) * 100.0)) \
                          + '%'
                    self.hutil.do_status_report(operation='DataCopy',
                                                status=CommonVariables.extension_success_status,
                                                status_code=str(CommonVariables.success),
                                                message=msg)
                    # NOTE(review): this second, identical status report looks
                    # like accidental duplication — confirm before removing.
                    self.hutil.do_status_report(operation='DataCopy',
                                                status=CommonVariables.extension_success_status,
                                                status_code=str(CommonVariables.success),
                                                message=msg)
                self.ongoing_item_config.current_slice_index = self.current_slice_index
                self.ongoing_item_config.commit()
            return CommonVariables.process_success

    """
    TODO: if the copy failed?
    """
    def copy_internal(self, from_device, to_device, block_size, skip=0, seek=0, count=1):
        """
        first, copy the data to the middle cache

        Copy one slice: source -> tmpfs cache -> (async backup file) and
        cache -> destination. The backup file is written by a background
        Popen so it can be used by resume_copy if the foreground copy dies;
        it is killed and the files removed once the destination write succeeds.
        """
        dd_cmd = str(self.copy_command) \
                 + ' if=' + from_device \
                 + ' of=' + self.slice_file_path \
                 + ' bs=' + str(block_size) \
                 + ' skip=' + str(skip) \
                 + ' count=' + str(count)
        return_code = self.command_executer.Execute(dd_cmd)
        if return_code != CommonVariables.process_success:
            self.logger.log(msg="{0} is {1}".format(dd_cmd, return_code), level=CommonVariables.ErrorLevel)
            return return_code
        else:
            slice_file_size = os.path.getsize(self.slice_file_path)
            self.logger.log(msg=("slice_file_size is: {0}".format(slice_file_size)))
            """
            second, copy the data in the middle cache to the backup slice.
            """
            backup_slice_item_cmd = str(self.copy_command) \
                                    + ' if=' + self.slice_file_path \
                                    + ' of=' + self.encryption_environment.copy_slice_item_backup_file \
                                    + ' bs=' + str(block_size) \
                                    + ' count=' + str(count)
            backup_slice_args = shlex.split(backup_slice_item_cmd)
            # Fire-and-forget: backup dd runs concurrently with the destination write.
            backup_process = Popen(backup_slice_args)
            self.logger.log("backup_slice_item_cmd is:{0}".format(backup_slice_item_cmd))
            """
            third, copy the data in the middle cache to the target device.
            """
            dd_cmd = str(self.copy_command) + ' if=' + self.slice_file_path + ' of=' + to_device + ' bs=' + str(block_size) + ' seek=' + str(seek) + ' count=' + str(count)
            return_code = self.command_executer.Execute(dd_cmd)
            if return_code != CommonVariables.process_success:
                self.logger.log(msg=("{0} is: {1}".format(dd_cmd, return_code)), level=CommonVariables.ErrorLevel)
            else:
                # the copy done correctly, so clear the backup slice file item.
                backup_process.kill()
                if os.path.exists(self.encryption_environment.copy_slice_item_backup_file):
                    self.logger.log(msg="clean up the backup file")
                    os.remove(self.encryption_environment.copy_slice_item_backup_file)
                if os.path.exists(self.slice_file_path):
                    self.logger.log(msg="clean up the slice file")
                    os.remove(self.slice_file_path)
        return return_code

    def prepare_mem_fs(self):
        """Mount a tmpfs (one block plus 1KB headroom) to stage slice data in RAM."""
        self.disk_util.make_sure_path_exists(self.tmpfs_mount_point)
        commandToExecute = self.patching.mount_path + " -t tmpfs -o size=" + str(self.block_size + 1024) + " tmpfs " + self.tmpfs_mount_point
        self.logger.log("prepare mem fs script is: {0}".format(commandToExecute))
        return_code = self.command_executer.Execute(commandToExecute)
        return return_code

    def clear_mem_fs(self):
        """Unmount the tmpfs staging area created by prepare_mem_fs."""
        commandToExecute = self.patching.umount_path + " " + self.tmpfs_mount_point
        return_code = self.command_executer.Execute(commandToExecute)
        return return_code
""" JSON def: HandlerEnvironment.json [{ "name": "ExampleHandlerLinux", "seqNo": "seqNo", "version": "1.0", "handlerEnvironment": { "logFolder": "", "configFolder": "", "statusFolder": "", "heartbeatFile": "", } }] Example ./config/1.settings "{"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"1BE9A13AA1321C7C515EF109746998BAB6D86FD1","protectedSettings": "MIIByAYJKoZIhvcNAQcDoIIBuTCCAbUCAQAxggFxMIIBbQIBADBVMEExPzA9BgoJkiaJk/IsZAEZFi9XaW5kb3dzIEF6dXJlIFNlcnZpY2UgTWFuYWdlbWVudCBmb3IgR+nhc6VHQTQpCiiV2zANBgkqhkiG9w0BAQEFAASCAQCKr09QKMGhwYe+O4/a8td+vpB4eTR+BQso84cV5KCAnD6iUIMcSYTrn9aveY6v6ykRLEw8GRKfri2d6tvVDggUrBqDwIgzejGTlCstcMJItWa8Je8gHZVSDfoN80AEOTws9Fp+wNXAbSuMJNb8EnpkpvigAWU2v6pGLEFvSKC0MCjDTkjpjqciGMcbe/r85RG3Zo21HLl0xNOpjDs/qqikc/ri43Y76E/Xv1vBSHEGMFprPy/Hwo3PqZCnulcbVzNnaXN3qi/kxV897xGMPPC3IrO7Nc++AT9qRLFI0841JLcLTlnoVG1okPzK9w6ttksDQmKBSHt3mfYV+skqs+EOMDsGCSqGSIb3DQEHATAUBggqhkiG9w0DBwQITgu0Nu3iFPuAGD6/QzKdtrnCI5425fIUy7LtpXJGmpWDUA==","publicSettings":{"port":"3000"}}}]}" Example HeartBeat { "version": 1.0, "heartbeat" : { "status": "ready", "code": 0, "Message": "Sample Handler running. Waiting for a new configuration from user." 
class HandlerContext:
    """Lightweight container for the state of one handler invocation.

    Holds the handler's name and version (plus sequence-number and path
    attributes that HandlerUtility attaches later).
    """

    def __init__(self, name):
        # Version starts at a placeholder; the real value is filled in
        # when the handler environment is parsed.
        self._name, self._version = name, '0.0'
return int(seq) return -1 def get_latest_seq(self): settings_files = glob.glob(os.path.join(self._context._config_dir, '*.settings')) settings_files = [os.path.basename(f) for f in settings_files] seq_nums = [int(re.findall(r'(\d+)\.settings', f)[0]) for f in settings_files] if seq_nums: return max(seq_nums) else: # guest agent is expected to provide at least one settings file to the extension self.log("unable to get latest sequence number from config folder") return -1 def get_current_seq(self): return int(self._context._seq_no) def same_seq_as_last_run(self): return self.get_current_seq() == self.get_last_seq() def exit_if_same_seq(self, exit_status=None): current_seq = int(self._context._seq_no) last_seq = self.get_last_seq() if current_seq == last_seq: self.log("the sequence numbers are same, so skipping daemon"+ ", current=" + str(current_seq) + ", last=" + str(last_seq)) if exit_status: self.do_status_report(exit_status['operation'], exit_status['status'], exit_status['status_code'], exit_status['message']) sys.exit(0) def log(self, message): # write message to stderr for inclusion in QOS telemetry sys.stderr.write(message) self._log(self._get_log_prefix() + ': ' + message) def error(self, message): # write message to stderr for inclusion in QOS telemetry sys.stderr.write(message) self._error(self._get_log_prefix() + ': ' + message) def _parse_config(self, config_txt): # pre : config_txt is a text string containing JSON configuration settings # post: handlerSettings is initialized with these settings and the config # object is returned. If an error occurs, None is returned. 
if not config_txt: self.error('empty config, nothing to parse') return None config = None try: config = json.loads(config_txt) except: self.error('invalid config, could not parse: ' + str(config_txt)) if config: handlerSettings = config['runtimeSettings'][0]['handlerSettings'] # skip unnecessary decryption of protected settings for query status # operations, to avoid timeouts in case of multiple settings files if handlerSettings.has_key('publicSettings'): ps = handlerSettings.get('publicSettings') op = ps.get(CommonVariables.EncryptionEncryptionOperationKey) if op == CommonVariables.QueryEncryptionStatus: return config if handlerSettings.has_key('protectedSettings') and \ handlerSettings.has_key("protectedSettingsCertThumbprint") and \ handlerSettings['protectedSettings'] is not None and \ handlerSettings["protectedSettingsCertThumbprint"] is not None: thumb = handlerSettings['protectedSettingsCertThumbprint'] cert = waagent.LibDir + '/' + thumb + '.crt' pkey = waagent.LibDir + '/' + thumb + '.prv' f = tempfile.NamedTemporaryFile(delete=False) f.close() waagent.SetFileContents(f.name, config['runtimeSettings'][0]['handlerSettings']['protectedSettings']) cleartxt = None cleartxt = waagent.RunGetOutput(self.patching.base64_path + " -d " + f.name + " | " + self.patching.openssl_path + " smime -inform DER -decrypt -recip " + cert + " -inkey " + pkey)[1] if cleartxt == None: self.error("OpenSSh decode error using thumbprint " + thumb) self.do_exit(1, self.operation,'error','1', self.operation + ' Failed') jctxt = '' try: jctxt = json.loads(cleartxt) except: self.error('JSON exception loading protected settings') handlerSettings['protectedSettings'] = jctxt return config def do_parse_context(self, operation): self.operation = operation _context = self.try_parse_context() if not _context: self.log("no settings file found") self.do_exit(0, 'QueryEncryptionStatus', CommonVariables.extension_success_status, str(CommonVariables.success), 'No operation found, 
find_last_nonquery_operation={0}'.format(self.find_last_nonquery_operation)) return _context def is_valid_nonquery(self, settings_file_path): # note: the nonquery operations list includes update and disable nonquery_ops = [ CommonVariables.EnableEncryption, CommonVariables.EnableEncryptionFormat, CommonVariables.EnableEncryptionFormatAll, CommonVariables.UpdateEncryptionSettings, CommonVariables.DisableEncryption ] if settings_file_path and os.path.exists(settings_file_path): # open file and look for presence of nonquery operation config_txt = waagent.GetFileContents(settings_file_path) config_obj = self._parse_config(config_txt) public_settings_str = config_obj['runtimeSettings'][0]['handlerSettings'].get('publicSettings') # if not json already, load string as json if isinstance(public_settings_str, basestring): public_settings = json.loads(public_settings_str) else: public_settings = public_settings_str operation = public_settings.get(CommonVariables.EncryptionEncryptionOperationKey) if operation and (operation in nonquery_ops): return True # invalid input, or not recognized as a valid nonquery operation return False def get_last_nonquery_config_path(self): # pre: internal self._context._config_dir and _seq_no, _settings_file must be set prior to call # post: returns path to last nonquery settings file in current config, archived folder, or None # validate that internal preconditions are satisfied and internal variables are initialized if self._context._seq_no < 0: self.error("current context sequence number must be initialized and non-negative") if not self._context._config_dir or not os.path.isdir(self._context._config_dir): self.error("current context config dir must be initialized and point to a path that exists") if not self._context._settings_file or not os.path.exists(self._context._settings_file): self.error("current context settings file variable must be initialized and point to a file that exists") # check timestamp of pointer to last archived settings 
file curr_path = self._context._settings_file last_path = os.path.join(self.config_archive_folder, "lnq.settings") # if an archived nonquery settings file exists, use it if no current settings file exists, or it is newer than current settings if os.path.exists(last_path) and ((not os.path.exists(curr_path)) or (os.path.exists(curr_path) and (os.stat(last_path).st_mtime > os.stat(curr_path).st_mtime))): return last_path else: # reverse iterate through numbered settings files in config dir # and return path to the first nonquery settings file found for i in range(self._context._seq_no,-1,-1): curr_path = os.path.join(self._context._config_dir, str(i) + '.settings') if self.is_valid_nonquery(curr_path): return curr_path # nothing was found in the current config settings, check the archived settings if os.path.exists(last_path): return last_path else: if os.path.exists(self.config_archive_folder): # walk through any archived [n].settings files found in archived settings folder # sorted by reverse timestamp (processing newest to oldest) until a nonquery settings file found files = sorted(os.listdir(self.config_archive_folder), key=os.path.getctime, reverse=True) for f in files: curr_path = os.path.join(self._context._config_dir, f) # TODO: check that file name matches the [n].settings format if self.is_valid_nonquery(curr_path): # found, copy to last_nonquery_settings in archived settings return curr_path # unable to find any nonquery settings file return None def get_last_config(self, nonquery): # precondition: self._context._config_dir, self._context._seq_no are already set and valid # postcondition: a configuration object from the last configuration settings file is returned # if nonquery flag is true, search for the last settings file that was not a query status operation # if nonquery is false, return the current settings file if nonquery: last_config_path = self.get_last_nonquery_config_path() else: # retrieve the settings file corresponding to the current 
sequence number last_config_path = os.path.join(self._context._config_dir, str(self._context._seq_no) + '.settings') # if not found, attempt to fall back to an archived settings file if not os.path.isfile(last_config_path): self.log('settings file not found, checking for archived settings') last_config_path = os.path.join(self.config_archive_folder, "lnq.settings") if not os.path.isfile(last_config_path): self.error('archived settings file not found, unable to get last config') return None # settings file was found, parse config and return config object config_txt = waagent.GetFileContents(last_config_path) if not config_txt: self.error('configuration settings empty, unable to get last config') return None config_obj = self._parse_config(config_txt) if not config_obj: self.error('failed to parse configuration settings, unable to get last config') return None else: return config_obj def get_handler_env(self): # load environment variables from HandlerEnvironment.json # according to spec, it is always in the ./ directory #self.log('cwd is ' + os.path.realpath(os.path.curdir)) handler_env_file = './HandlerEnvironment.json' if not os.path.isfile(handler_env_file): self.error("Unable to locate " + handler_env_file) return None handler_env_json_str = waagent.GetFileContents(handler_env_file) if handler_env_json_str == None : self.error("Unable to read " + handler_env_file) try: handler_env = json.loads(handler_env_json_str) except: pass if handler_env == None : # TODO - treat this as a telemetry error indicating an agent bug, as this file should always be available and readable self.log("JSON error processing " + str(handler_env_file)) return None if type(handler_env) == list: handler_env = handler_env[0] return handler_env def try_parse_context(self): # precondition: agent is in a properly running state with at least one settings file in config folder # any archived settings from prior instances of the extension were saved to archive folder # postcondition: context 
variables initialized to reflect current handler environment and prior call history # initialize handler environment context variables handler_env = self.get_handler_env() self._context._name = handler_env['name'] self._context._version = str(handler_env['version']) self._context._config_dir = handler_env['handlerEnvironment']['configFolder'] self._context._log_dir = handler_env['handlerEnvironment']['logFolder'] self._context._log_file = os.path.join(handler_env['handlerEnvironment']['logFolder'],'extension.log') self._change_log_file() self._context._status_dir = handler_env['handlerEnvironment']['statusFolder'] self._context._heartbeat_file = handler_env['handlerEnvironment']['heartbeatFile'] # initialize the current sequence number corresponding to settings files in config folder self._context._seq_no = self._get_current_seq_no(self._context._config_dir) self._context._settings_file = os.path.join(self._context._config_dir, str(self._context._seq_no) + '.settings') # get a config object corresponding to the last settings file, skipping QueryEncryptionStatus settings # files when find_last_nonquery_operation is True, falling back to archived settings if necessary # note - in the case of nonquery settings file retrieval, when preceded by one or more query settings # file that are more recent, the config object will not match the active settings file or sequence number self._context._config = self.get_last_config(self.find_last_nonquery_operation) return self._context def _change_log_file(self): #self.log("Change log file to " + self._context._log_file) LoggerInit(self._context._log_file,'/dev/stdout') self._log = waagent.Log self._error = waagent.Error def save_seq(self): self.set_last_seq(self._context._seq_no) self.log("set most recent sequence number to " + str(self._context._seq_no)) def set_last_seq(self, seq): waagent.SetFileContents('mrseq', str(seq)) def redo_last_status(self): latest_sequence_num = self.get_latest_seq() if (latest_sequence_num > 0): 
latest_seq = str(latest_sequence_num) self._context._status_file = os.path.join(self._context._status_dir, latest_seq + '.status') previous_seq = str(latest_sequence_num - 1) previous_status_file = os.path.join(self._context._status_dir, previous_seq + '.status') shutil.copy2(previous_status_file, self._context._status_file) self.log("[StatusReport ({0})] Copied {1} to {2}".format(latest_seq, previous_status_file, self._context._status_file)) else: self.log("unable to redo last status, no prior status found") def redo_current_status(self): stat_rept = waagent.GetFileContents(self._context._status_file) stat = json.loads(stat_rept) self.do_status_report(stat[0]["status"]["operation"], stat[0]["status"]["status"], stat[0]["status"]["code"], stat[0]["status"]["formattedMessage"]["message"]) def do_status_report(self, operation, status, status_code, message): latest_seq_num = self.get_latest_seq() if (latest_seq_num >= 0): latest_seq = str(self.get_latest_seq()) else: self.log("sequence number could not be derived from settings files, using 0.status") latest_seq = "0" self._context._status_file = os.path.join(self._context._status_dir, latest_seq + '.status') if message is None: message = "" message = filter(lambda c: c in string.printable, message) message = message.encode('ascii', 'ignore') self.log("[StatusReport ({0})] op: {1}".format(latest_seq, operation)) self.log("[StatusReport ({0})] status: {1}".format(latest_seq, status)) self.log("[StatusReport ({0})] code: {1}".format(latest_seq, status_code)) self.log("[StatusReport ({0})] msg: {1}".format(latest_seq, message)) tstamp = time.strftime(DateTimeFormat, time.gmtime()) stat = [{ "version" : self._context._version, "timestampUTC" : tstamp, "status" : { "name" : self._context._name, "operation" : operation, "status" : status, "code" : status_code, "formattedMessage" : { "lang" : "en-US", "message" : message } } }] if self.disk_util: encryption_status = self.disk_util.get_encryption_status() 
encryption_status_dict = json.loads(encryption_status) self.log("[StatusReport ({0})] substatus : OS : {1} Data : {2}".format(latest_seq, encryption_status_dict['os'], encryption_status_dict['data'])) substat = [{ "name" : self._context._name, "operation" : operation, "status" : status, "code" : status_code, "formattedMessage" : { "lang" : "en-US", "message" : encryption_status } }] stat[0]["status"]["substatus"] = substat if "VMRestartPending" in encryption_status: stat[0]["status"]["formattedMessage"]["message"] = "OS disk successfully encrypted, please reboot the VM" stat_rept = json.dumps(stat) # rename all other status files, or the WALA would report the wrong # status file. # because the wala choose the status file with the highest sequence # number to report. if self._context._status_file: with open(self._context._status_file,'w+') as f: f.write(stat_rept) def backup_settings_status_file(self, _seq_no): self.log("current seq no is " + _seq_no) for subdir, dirs, files in os.walk(self._context._config_dir): for file in files: try: if file.endswith('.settings') and file != (_seq_no + ".settings"): new_file_name = file.replace(".","_") os.rename(join(self._context._config_dir, file), join(self._context._config_dir, new_file_name)) except: self.log("failed to rename the settings file.") def do_exit(self, exit_code, operation, status, code, message): try: self.do_status_report(operation, status, code, message) except Exception as e: self.log("Can't update status: " + str(e)) if message: # Remove newline character so that msg is printed in one line strip_msg = message.replace('\n', ' ') self.log("Exited with message {0}".format(strip_msg)) sys.exit(exit_code) def get_handler_settings(self): return self._context._config['runtimeSettings'][0]['handlerSettings'] def get_protected_settings(self): return self.get_handler_settings().get('protectedSettings') def get_public_settings(self): return self.get_handler_settings().get('publicSettings') def 
archive_old_configs(self): if not os.path.exists(self.config_archive_folder): os.makedirs(self.config_archive_folder) # only persist latest nonquery settings file to archived settings # and prevent the accumulation of large numbers of obsolete files src = self.get_last_nonquery_config_path() if src: dest = os.path.join(self.config_archive_folder, 'lnq.settings') if src != dest: shutil.copy2(src,dest) ================================================ FILE: VMEncryption/main/Utils/WAAgentUtil.py ================================================ # Wrapper module for waagent # # waagent is not written as a module. This wrapper module is created # to use the waagent code as a module. # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import imp import os import os.path # # The following code will search and load waagent code and expose # it as a submodule of current module # def searchWAAgent(): agentPath = '/usr/sbin/waagent' if os.path.isfile(agentPath): return agentPath user_paths = os.environ['PYTHONPATH'].split(os.pathsep) for user_path in user_paths: agentPath = os.path.join(user_path, 'waagent') if os.path.isfile(agentPath): return agentPath return None agentPath = searchWAAgent() if agentPath: waagent = imp.load_source('waagent', agentPath) else: raise Exception("Can't load waagent.") if not hasattr(waagent, "AddExtensionEvent"): """ If AddExtensionEvent is not defined, provide a dummy impl. 
""" def _AddExtensionEvent(*args, **kwargs): pass waagent.AddExtensionEvent = _AddExtensionEvent if not hasattr(waagent, "WALAEventOperation"): class _WALAEventOperation: HeartBeat="HeartBeat" Provision = "Provision" Install = "Install" UnIsntall = "UnInstall" Disable = "Disable" Enable = "Enable" Download = "Download" Upgrade = "Upgrade" Update = "Update" waagent.WALAEventOperation = _WALAEventOperation ================================================ FILE: VMEncryption/main/Utils/__init__.py ================================================ # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: VMEncryption/main/__init__.py ================================================ # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
================================================ FILE: VMEncryption/main/check_util.py ================================================ #!/usr/bin/env python # # ********************************************************* # Copyright (c) Microsoft. All rights reserved. # # Apache 2.0 License # # You may obtain a copy of the License at # http:#www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. # # ********************************************************* """This module checks validity of the environment prior to disk encryption""" import os import os.path import urlparse import re import json from Common import CommonVariables from CommandExecutor import CommandExecutor from distutils.version import LooseVersion class CheckUtil(object): """Checks compatibility for disk encryption""" def __init__(self, logger): self.logger = logger def is_app_compat_issue_detected(self): """check for the existence of applications that enable is not yet compatible with""" detected = False dirs = ['./usr/sap'] files = ['/etc/init.d/mongodb', '/etc/init.d/cassandra', '/etc/init.d/docker', '/opt/Symantec/symantec_antivirus'] for testdir in dirs: if os.path.isdir(testdir): self.logger.log('WARNING: likely app compat issue [' + testdir + ']') detected = True for testfile in files: if os.path.isfile(testfile): self.logger.log('WARNING: likely app compat issue [' + testfile + ']') detected = True return detected def is_insufficient_memory(self): """check if memory total is greater than or equal to the recommended minimum size""" minsize = 7000000 memtotal = int(os.popen("grep MemTotal /proc/meminfo | grep -o -E [0-9]+").read()) if memtotal < minsize: self.logger.log('WARNING: total memory [' + 
str(memtotal) + 'kb] is less than 7GB') return True return False def is_unsupported_mount_scheme(self): """ check for data disks mounted under /mnt and for recursively mounted data disks such as /mnt/data1, /mnt/data2, or /data3 + /data3/data4 """ detected = False ignorelist = ['/', '/dev', '/proc', '/run', '/sys', '/sys/fs/cgroup'] mounts = [] with open('/proc/mounts') as infile: for line in infile: mountpoint = line.split()[1] if mountpoint not in ignorelist: mounts.append(line.split()[1]) for mnt1 in mounts: for mnt2 in mounts: if (mnt1 != mnt2) and (mnt2.startswith(mnt1)): self.logger.log('WARNING: unsupported mount scheme [' + mnt1 + ' ' + mnt2 + ']') detected = True return detected def check_kv_url(self, test_url, message): """basic sanity check of the key vault url""" if test_url is None: raise Exception(message + '\nNo URL supplied') try: parse_result = urlparse.urlparse(test_url) except: raise Exception(message + '\nMalformed URL: ' + test_url) if not parse_result.scheme.lower() == "https" : raise Exception('\n' + message + '\n URL should be https: ' + test_url + "\n") if not parse_result.netloc: raise Exception(message + '\nMalformed URL: ' + test_url) # Don't bother with explicit dns check, the host already does and should start returning better error messages. # dns_suffix_list = ["vault.azure.net", "vault.azure.cn", "vault.usgovcloudapi.net", "vault.microsoftazure.de"] # Add new suffixes here when a new national cloud is introduced. # Relevant link: https://docs.microsoft.com/en-us/azure/key-vault/key-vault-access-behind-firewall#key-vault-operations # dns_match = False # for dns_suffix in dns_suffix_list: # escaped_dns_suffix = dns_suffix.replace(".","\.") # if re.match('[a-zA-Z0-9\-]+\.' 
+ escaped_dns_suffix + '(:443)?$', parse_result.netloc): # # matched a valid dns, set matched to true # dns_match = True # if not dns_match: # raise Exception('\n' + message + '\nProvided URL does not match known valid URL formats: ' + \ # "\n\tProvided URL: " + test_url + \ # "\n\tKnown valid formats:\n\t\t" + \ # "\n\t\t".join(["https://." + dns_suffix + "/" for dns_suffix in dns_suffix_list]) ) return def validate_key_vault_params(self, public_settings): encryption_operation = public_settings.get(CommonVariables.EncryptionEncryptionOperationKey) if encryption_operation not in [CommonVariables.EnableEncryption, CommonVariables.EnableEncryptionFormat, CommonVariables.EnableEncryptionFormatAll]: # No need to check the KV urls if its not an encryption operation return kek_url = public_settings.get(CommonVariables.KeyEncryptionKeyURLKey) kv_url = public_settings.get(CommonVariables.KeyVaultURLKey) kek_algorithm = public_settings.get(CommonVariables.KeyEncryptionAlgorithmKey) self.check_kv_url(kv_url, "Encountered an error while checking the Key Vault URL") if kek_url: self.check_kv_url(kek_url, "A KEK URL was specified, but was invalid") if kek_algorithm is None or kek_algorithm.lower() not in [algo.lower() for algo in CommonVariables.encryption_algorithms]: if kek_algorithm: raise Exception( "The KEK encryption algorithm requested was not recognized") else: self.logger.log( "No KEK algorithm specified will default to {0}".format( CommonVariables.default_encryption_algorithm)) def validate_volume_type(self, public_settings): encryption_operation = public_settings.get(CommonVariables.EncryptionEncryptionOperationKey) if encryption_operation in [CommonVariables.QueryEncryptionStatus]: # No need to validate volume type for Query Encryption Status operation self.logger.log( "Ignore validating volume type for {0}".format( CommonVariables.QueryEncryptionStatus)) return volume_type = public_settings.get(CommonVariables.VolumeTypeKey) supported_types = 
CommonVariables.SupportedVolumeTypes if not volume_type.lower() in map(lambda x: x.lower(), supported_types) : raise Exception("Unknown Volume Type: {0}, has to be one of {1}".format(volume_type, supported_types)) def validate_lvm_os(self, public_settings): encryption_operation = public_settings.get(CommonVariables.EncryptionEncryptionOperationKey) if not encryption_operation: self.logger.log("LVM OS validation skipped (no encryption operation)") return elif encryption_operation.lower() == CommonVariables.QueryEncryptionStatus.lower(): self.logger.log("LVM OS validation skipped (Encryption Operation: QueryEncryptionStatus)") return volume_type = public_settings.get(CommonVariables.VolumeTypeKey) if not volume_type: self.logger.log("LVM OS validation skipped (no volume type)") return elif volume_type.lower() == CommonVariables.VolumeTypeData.lower(): self.logger.log("LVM OS validation skipped (Volume Type: DATA)") return # run lvm check if volume type, encryption operation were specified and OS type is LVM detected = False # first, check if the root OS volume type is LVM if ( encryption_operation and volume_type and os.system("lsblk -o TYPE,MOUNTPOINT | grep lvm | grep -q '/$'") == 0): # next, check that all required logical volume names exist ( swaplv is not required ) lvlist = ['rootvg-tmplv', 'rootvg-usrlv', 'rootvg-optlv', 'rootvg-homelv', 'rootvg-varlv', 'rootvg-rootlv'] for lvname in lvlist: if not os.system("lsblk -o NAME | grep -q '" + lvname + "'") == 0: self.logger.log('LVM OS scheme is missing LV [' + lvname + ']') detected = True if detected: raise Exception("LVM OS disk layout does not satisfy prerequisites ( see https://aka.ms/adelvm )") def validate_vfat(self): """ Check for vfat module using modprobe and raise exception if not found """ try: executor = CommandExecutor(self.logger) executor.Execute("modprobe vfat", True) except: raise RuntimeError('Incompatible system, prerequisite vfat module was not found.') def validate_aad(self, public_settings): 
encryption_operation = public_settings.get(CommonVariables.EncryptionEncryptionOperationKey) if encryption_operation not in [CommonVariables.EnableEncryption, CommonVariables.EnableEncryptionFormat, CommonVariables.EnableEncryptionFormatAll]: # skip if not an encryption operation, valid aad client id is only needed for encryption operations return aad_client_id = public_settings.get(CommonVariables.AADClientIDKey) uuid_pattern = r"^([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}){1}$" if aad_client_id: if not re.match(uuid_pattern, aad_client_id, re.IGNORECASE): message = 'AADClientID value is missing or invalid.' # provide an extra hint if Unicode curly quotes were pasted in if (u'\u201c' in aad_client_id) or (u'\u201d' in aad_client_id): message += ' Please remove Unicode quotation marks.' raise Exception(message + '\nActual Value: [' + aad_client_id + ']\nExpected Format: [nnnnnnnn-nnnn-nnnn-nnnn-nnnnnnnnnnnn]') else: raise Exception(CommonVariables.AADClientIDKey + ' property was not found in settings') def validate_memory_os_encryption(self, public_settings, encryption_status): is_enable_operation = False encryption_operation = public_settings.get(CommonVariables.EncryptionEncryptionOperationKey) if encryption_operation in [CommonVariables.EnableEncryption, CommonVariables.EnableEncryptionFormat, CommonVariables.EnableEncryptionFormatAll]: is_enable_operation = True volume_type = public_settings.get(CommonVariables.VolumeTypeKey) if is_enable_operation and not volume_type.lower() == CommonVariables.VolumeTypeData.lower() and encryption_status["os"] == "NotEncrypted": if self.is_insufficient_memory(): raise Exception("Not enough memory for enabling encryption on OS volume. 
8 GB memory is recommended.") def is_supported_os(self, public_settings, DistroPatcher, encryption_status): encryption_operation = public_settings.get(CommonVariables.EncryptionEncryptionOperationKey) if encryption_operation in [CommonVariables.QueryEncryptionStatus]: self.logger.log("Query encryption operation detected. Skipping OS encryption validation check.") return volume_type = public_settings.get(CommonVariables.VolumeTypeKey) # If volume type is data allow the operation (At this point we are sure a patch file for the distro exist) if volume_type.lower() == CommonVariables.VolumeTypeData.lower(): self.logger.log("Volume Type is DATA. Skipping OS encryption validation check.") return # If OS volume is already encrypted just return (Should not break already encryted VM's) if encryption_status["os"] != "NotEncrypted": self.logger.log("OS volume already encrypted. Skipping OS encryption validation check.") return distro_name = DistroPatcher.distro_info[0] distro_version = DistroPatcher.distro_info[1] supported_os_file = os.path.join(os.getcwd(), 'main/SupportedOS.json') with open(supported_os_file) as json_file: data = json.load(json_file) if distro_name in data: versions = data[distro_name] for version in versions: if distro_version.startswith(version['Version']): if 'Kernel' in version and LooseVersion(DistroPatcher.kernel_version) < LooseVersion(version['Kernel']): raise Exception('Kernel version {0} is not supported. 
Upgrade to kernel version {1}'.format(DistroPatcher.kernel_version, version['Kernel'])) else: return raise Exception('Distro {0} {1} is not supported for OS encryption'.format(distro_name, distro_version)) def precheck_for_fatal_failures(self, public_settings, encryption_status, DistroPatcher): """ run all fatal prechecks, they should throw an exception if anything is wrong """ self.validate_key_vault_params(public_settings) self.validate_volume_type(public_settings) self.validate_lvm_os(public_settings) self.validate_vfat() self.validate_aad(public_settings) self.validate_memory_os_encryption(public_settings, encryption_status) self.is_supported_os(public_settings, DistroPatcher, encryption_status) def is_non_fatal_precheck_failure(self): """ run all prechecks """ detected = False if self.is_app_compat_issue_detected(): detected = True self.logger.log("PRECHECK: Likely app compat issue detected") if self.is_insufficient_memory(): detected = True self.logger.log("PRECHECK: Low memory condition detected") if self.is_unsupported_mount_scheme(): detected = True self.logger.log("PRECHECK: Unsupported mount scheme detected") return detected ================================================ FILE: VMEncryption/main/handle.py ================================================ #!/usr/bin/env python # # Azure Disk Encryption For Linux Extension # # Copyright 2019 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import filecmp
import json
import os
import os.path
import re
import subprocess
import sys
import time
import tempfile
import traceback
import uuid
import shutil

from Utils import HandlerUtil
from Common import CommonVariables, CryptItem
from ExtensionParameter import ExtensionParameter
from DiskUtil import DiskUtil
from ResourceDiskUtil import ResourceDiskUtil
from BackupLogger import BackupLogger
from KeyVaultUtil import KeyVaultUtil
from EncryptionConfig import EncryptionConfig
from patch import GetDistroPatcher
from BekUtil import BekUtil
from check_util import CheckUtil
from DecryptionMarkConfig import DecryptionMarkConfig
from EncryptionMarkConfig import EncryptionMarkConfig
from EncryptionEnvironment import EncryptionEnvironment
from OnGoingItemConfig import OnGoingItemConfig
from ProcessLock import ProcessLock
from CommandExecutor import CommandExecutor, ProcessCommunicator
# Python 2 only: this module targets the Python 2 runtime (see basestring/file usage below).
from __builtin__ import int


def install():
    # Handler 'install' entry point: parse handler context then report success.
    hutil.do_parse_context('Install')
    hutil.do_exit(0, 'Install', CommonVariables.extension_success_status, str(CommonVariables.success), 'Install Succeeded')


def disable():
    # Handler 'disable' entry point.
    hutil.do_parse_context('Disable')
    # Archive configs at disable to make them available to new extension version prior to update
    # The extension update handshake is [old:disable][new:update][old:uninstall][new:install]
    hutil.archive_old_configs()
    hutil.do_exit(0, 'Disable', CommonVariables.extension_success_status, '0', 'Disable succeeded')


def uninstall():
    # Handler 'uninstall' entry point: no cleanup performed here.
    hutil.do_parse_context('Uninstall')
    hutil.do_exit(0, 'Uninstall', CommonVariables.extension_success_status, '0', 'Uninstall succeeded')


def disable_encryption():
    # Decrypt all data volumes: add a cleartext key to every LUKS crypt item
    # and mark the decryption so the daemon can finish the work. Exits the
    # process via hutil.do_exit in every path.
    hutil.do_parse_context('DisableEncryption')

    logger.log('Disabling encryption')

    decryption_marker = DecryptionMarkConfig(logger, encryption_environment)

    if decryption_marker.config_file_exists():
        # A decryption is already marked; just (re)start the background daemon.
        logger.log(msg="decryption is marked, starting daemon.", level=CommonVariables.InfoLevel)
        start_daemon('DisableEncryption')

        hutil.do_exit(exit_code=0,
                      operation='DisableEncryption',
                      status=CommonVariables.extension_success_status,
                      code=str(CommonVariables.success),
                      message='Decryption started')

    exit_status = {
        'operation': 'DisableEncryption',
        'status': CommonVariables.extension_success_status,
        'status_code': str(CommonVariables.success),
        'message': 'Decryption completed'
    }

    # If this sequence number was already handled, report the same status and exit.
    hutil.exit_if_same_seq(exit_status)
    hutil.save_seq()

    try:
        extension_parameter = ExtensionParameter(hutil, logger, DistroPatcher, encryption_environment, get_protected_settings(), get_public_settings())

        disk_util = DiskUtil(hutil=hutil, patching=DistroPatcher, logger=logger, encryption_environment=encryption_environment)

        encryption_status = json.loads(disk_util.get_encryption_status())

        if encryption_status["os"] != "NotEncrypted":
            raise Exception("Disabling encryption is not supported when OS volume is encrypted")

        bek_util = BekUtil(disk_util, logger)
        encryption_config = EncryptionConfig(encryption_environment, logger)
        bek_passphrase_file = bek_util.get_bek_passphrase_file(encryption_config)
        crypt_items = disk_util.get_crypt_items()

        logger.log('Found {0} items to decrypt'.format(len(crypt_items)))

        for crypt_item in crypt_items:
            disk_util.create_cleartext_key(crypt_item.mapper_name)

            add_result = disk_util.luks_add_cleartext_key(bek_passphrase_file,
                                                          crypt_item.dev_path,
                                                          crypt_item.mapper_name,
                                                          crypt_item.luks_header_path)
            if add_result != CommonVariables.process_success:
                if disk_util.is_luks_device(crypt_item.dev_path, crypt_item.luks_header_path):
                    raise Exception("luksAdd failed with return code {0}".format(add_result))
                else:
                    # Best-effort: a non-LUKS device in the crypt item list is logged and skipped.
                    logger.log("luksAdd failed with return code {0}".format(add_result))
                    logger.log("Ignoring for now, as device ({0}) does not seem to be a luks device".format(crypt_item.dev_path))
                    continue

            if crypt_item.dev_path.startswith("/dev/sd"):
                # /dev/sdX names are not stable across reboots; switch to a persistent path.
                logger.log('Updating crypt item entry to use mapper name')
                logger.log('Device name before update: {0}'.format(crypt_item.dev_path))
                crypt_item.dev_path = disk_util.get_persistent_path_by_sdx_path(crypt_item.dev_path)
                logger.log('Device name after update: {0}'.format(crypt_item.dev_path))

            crypt_item.uses_cleartext_key = True
            disk_util.update_crypt_item(crypt_item, None)

            logger.log('Added cleartext key for {0}'.format(crypt_item))

        decryption_marker.command = extension_parameter.command
        decryption_marker.volume_type = extension_parameter.VolumeType
        decryption_marker.commit()

        hutil.do_exit(exit_code=0,
                      operation='DisableEncryption',
                      status=CommonVariables.extension_success_status,
                      code=str(CommonVariables.success),
                      message='Decryption started')

    except Exception as e:
        message = "Failed to disable the extension with error: {0}, stack trace: {1}".format(e, traceback.format_exc())
        logger.log(msg=message, level=CommonVariables.ErrorLevel)
        hutil.do_exit(exit_code=CommonVariables.unknown_error,
                      operation='DisableEncryption',
                      status=CommonVariables.extension_error_status,
                      code=str(CommonVariables.unknown_error),
                      message=message)


def get_public_settings():
    # Return the handler public settings as a dict; the agent may supply them
    # either as a JSON string or as an already-parsed object.
    public_settings_str = hutil._context._config['runtimeSettings'][0]['handlerSettings'].get('publicSettings')
    if isinstance(public_settings_str, basestring):
        return json.loads(public_settings_str)
    else:
        return public_settings_str


def get_protected_settings():
    # Return the handler protected settings as a dict (same dual-form handling
    # as get_public_settings).
    protected_settings_str = hutil._context._config['runtimeSettings'][0]['handlerSettings'].get('protectedSettings')
    if isinstance(protected_settings_str, basestring):
        return json.loads(protected_settings_str)
    else:
        return protected_settings_str


def update_encryption_settings():
    # Rotate the disk-encryption secret: add a new passphrase to all LUKS
    # devices, publish it to Key Vault, then (on a later call) remove the old
    # passphrase once the BEK volume carries the new one.
    hutil.do_parse_context('UpdateEncryptionSettings')
    logger.log('Updating encryption settings')

    # re-install extra packages like cryptsetup if no longer on system from earlier enable
    try:
        DistroPatcher.install_extras()
    except Exception as e:
        message = "Failed to update encryption settings with error: {0}, stack trace: {1}".format(e, traceback.format_exc())
        hutil.do_exit(exit_code=CommonVariables.missing_dependency,
                      operation='UpdateEncryptionSettings',
                      status=CommonVariables.extension_error_status,
                      code=str(CommonVariables.missing_dependency),
                      message=message)

    encryption_config = EncryptionConfig(encryption_environment, logger)
    config_secret_seq = encryption_config.get_secret_seq_num()
    # -1 when no secret has been created yet, so any update call triggers rotation.
    current_secret_seq_num = int(config_secret_seq if config_secret_seq else -1)
    update_call_seq_num = hutil.get_current_seq()

    logger.log("Current secret was created in operation #{0}".format(current_secret_seq_num))
    logger.log("The update call is operation #{0}".format(update_call_seq_num))

    executor = CommandExecutor(logger)
    # /boot must be mounted so initrd images can be regenerated below.
    executor.Execute("mount /boot")

    try:
        disk_util = DiskUtil(hutil=hutil, patching=DistroPatcher, logger=logger, encryption_environment=encryption_environment)
        bek_util = BekUtil(disk_util, logger)
        extension_parameter = ExtensionParameter(hutil, logger, DistroPatcher, encryption_environment, get_protected_settings(), get_public_settings())

        existing_passphrase_file = bek_util.get_bek_passphrase_file(encryption_config)

        if current_secret_seq_num < update_call_seq_num:
            # First phase of rotation: generate/add the new key and create the new secret.
            if extension_parameter.passphrase is None or extension_parameter.passphrase == "":
                extension_parameter.passphrase = bek_util.generate_passphrase(extension_parameter.KeyEncryptionAlgorithm)

            logger.log('Recreating secret to store in the KeyVault')

            keyVaultUtil = KeyVaultUtil(logger)

            # The new passphrase is written to a temp file so cryptsetup can read it;
            # the file is unlinked after all devices are updated.
            temp_keyfile = tempfile.NamedTemporaryFile(delete=False)
            temp_keyfile.write(extension_parameter.passphrase)
            temp_keyfile.close()

            for crypt_item in disk_util.get_crypt_items():
                if not crypt_item:
                    continue

                before_keyslots = disk_util.luks_dump_keyslots(crypt_item.dev_path, crypt_item.luks_header_path)
                logger.log("Before key addition, keyslots for {0}: {1}".format(crypt_item.dev_path, before_keyslots))

                logger.log("Adding new key for {0}".format(crypt_item.dev_path))

                luks_add_result = disk_util.luks_add_key(passphrase_file=existing_passphrase_file,
                                                         dev_path=crypt_item.dev_path,
                                                         mapper_name=crypt_item.mapper_name,
                                                         header_file=crypt_item.luks_header_path,
                                                         new_key_path=temp_keyfile.name)
                logger.log("luks add result is {0}".format(luks_add_result))

                after_keyslots = disk_util.luks_dump_keyslots(crypt_item.dev_path, crypt_item.luks_header_path)
                logger.log("After key addition, keyslots for {0}: {1}".format(crypt_item.dev_path, after_keyslots))

                # The new slot is the first position where the before/after dumps differ.
                # NOTE(review): .index(True) raises ValueError if no slot changed — confirm
                # luks_add_key failure is surfaced before reaching this line.
                new_keyslot = list(map(lambda x: x[0] != x[1], zip(before_keyslots, after_keyslots))).index(True)
                logger.log("New key was added in keyslot {0}".format(new_keyslot))

                # crypt_item.current_luks_slot = new_keyslot

                # disk_util.update_crypt_item(crypt_item)

            logger.log("New key successfully added to all encrypted devices")

            # Regenerate the initrd so early boot can unlock with the updated header.
            if DistroPatcher.distro_info[0] == "Ubuntu":
                logger.log("Updating initrd image with new osluksheader.")
                executor.Execute("update-initramfs -u -k all", True)

            if DistroPatcher.distro_info[0] == "redhat" or DistroPatcher.distro_info[0] == "centos":
                distro_version = DistroPatcher.distro_info[1]

                if distro_version.startswith('7.'):
                    logger.log("Updating initrd image with new osluksheader.")
                    executor.ExecuteInBash("/usr/sbin/dracut -f -v --kver `grubby --default-kernel | sed 's|/boot/vmlinuz-||g'`", True)

            os.unlink(temp_keyfile.name)

            # install Python ADAL support if using client certificate authentication
            if extension_parameter.AADClientCertThumbprint:
                DistroPatcher.install_adal()

            kek_secret_id_created = keyVaultUtil.create_kek_secret(Passphrase=extension_parameter.passphrase,
                                                                   KeyVaultURL=extension_parameter.KeyVaultURL,
                                                                   KeyEncryptionKeyURL=extension_parameter.KeyEncryptionKeyURL,
                                                                   AADClientID=extension_parameter.AADClientID,
                                                                   AADClientCertThumbprint=extension_parameter.AADClientCertThumbprint,
                                                                   KeyEncryptionAlgorithm=extension_parameter.KeyEncryptionAlgorithm,
                                                                   AADClientSecret=extension_parameter.AADClientSecret,
                                                                   DiskEncryptionKeyFileName=extension_parameter.DiskEncryptionKeyFileName)

            if kek_secret_id_created is None:
                hutil.do_exit(exit_code=CommonVariables.create_encryption_secret_failed,
                              operation='UpdateEncryptionSettings',
                              status=CommonVariables.extension_error_status,
                              code=str(CommonVariables.create_encryption_secret_failed),
                              message='UpdateEncryptionSettings failed.')
            else:
                # Persist the new secret metadata and back up the old BEK so the
                # second phase can remove the old key after the volume handshake.
                encryption_config.passphrase_file_name = extension_parameter.DiskEncryptionKeyFileName
                encryption_config.secret_id = kek_secret_id_created
                encryption_config.secret_seq_num = hutil.get_current_seq()
                encryption_config.commit()

                shutil.copy(existing_passphrase_file, encryption_environment.bek_backup_path)
                logger.log("Backed up BEK at {0}".format(encryption_environment.bek_backup_path))

                hutil.do_exit(exit_code=0,
                              operation='UpdateEncryptionSettings',
                              status=CommonVariables.extension_success_status,
                              code=str(CommonVariables.success),
                              message=str(kek_secret_id_created))
        else:
            # Second phase of rotation: the secret is current; remove the old key.
            logger.log('Secret has already been updated')
            mount_encrypted_disks(disk_util, bek_util, existing_passphrase_file, encryption_config)
            disk_util.log_lsblk_output()
            hutil.exit_if_same_seq()

            # remount bek volume
            existing_passphrase_file = bek_util.get_bek_passphrase_file(encryption_config)

            if extension_parameter.passphrase and extension_parameter.passphrase != file(existing_passphrase_file).read():
                logger.log("The new passphrase has not been placed in BEK volume yet")
                logger.log("Skipping removal of old passphrase")
                exit_without_status_report()

            logger.log('Removing old passphrase')

            for crypt_item in disk_util.get_crypt_items():
                if not crypt_item:
                    continue

                if filecmp.cmp(existing_passphrase_file, encryption_environment.bek_backup_path):
                    logger.log('Current BEK and backup are the same, skipping removal')
                    continue

                logger.log('Removing old passphrase from {0}'.format(crypt_item.dev_path))

                keyslots = disk_util.luks_dump_keyslots(crypt_item.dev_path, crypt_item.luks_header_path)
                logger.log("Keyslots before removal: {0}".format(keyslots))

                luks_remove_result = disk_util.luks_remove_key(passphrase_file=encryption_environment.bek_backup_path,
                                                               dev_path=crypt_item.dev_path,
                                                               header_file=crypt_item.luks_header_path)
                logger.log("luks remove result is {0}".format(luks_remove_result))

                keyslots = disk_util.luks_dump_keyslots(crypt_item.dev_path, crypt_item.luks_header_path)
                logger.log("Keyslots after removal: {0}".format(keyslots))

            logger.log("Old key successfully removed from all encrypted devices")

            # Regenerate the initrd after key removal as well.
            if DistroPatcher.distro_info[0] == "Ubuntu":
                logger.log("Updating initrd image with new osluksheader.")
                executor.Execute("update-initramfs -u -k all", True)

            if DistroPatcher.distro_info[0] == "redhat" or DistroPatcher.distro_info[0] == "centos":
                distro_version = DistroPatcher.distro_info[1]

                if distro_version.startswith('7.'):
                    logger.log("Updating initrd image with new osluksheader.")
                    executor.ExecuteInBash("/usr/sbin/dracut -f -v --kver `grubby --default-kernel | sed 's|/boot/vmlinuz-||g'`", True)

            hutil.save_seq()
            extension_parameter.commit()
            os.unlink(encryption_environment.bek_backup_path)

            hutil.do_exit(exit_code=0,
                          operation='UpdateEncryptionSettings',
                          status=CommonVariables.extension_success_status,
                          code=str(CommonVariables.success),
                          message='Encryption settings updated')
    except Exception as e:
        message = "Failed to update encryption settings with error: {0}, stack trace: {1}".format(e, traceback.format_exc())
        logger.log(msg=message, level=CommonVariables.ErrorLevel)
        hutil.do_exit(exit_code=CommonVariables.unknown_error,
                      operation='UpdateEncryptionSettings',
                      status=CommonVariables.extension_error_status,
                      code=str(CommonVariables.unknown_error),
                      message=message)


def update():
    # The extension update handshake is [old:disable][new:update][old:uninstall][new:install]
    # this method is called when updating an older version of the extension to a newer version
    hutil.do_parse_context('Update')
    logger.log("Installing pre-requisites")
    DistroPatcher.install_extras()
    DistroPatcher.update_prereq()
    hutil.do_exit(0, 'Update', CommonVariables.extension_success_status, '0', 'Update Succeeded')


def exit_without_status_report():
    # Exit immediately without writing a handler status file.
    sys.exit(0)


def not_support_header_option_distro(patching):
    # True for distros whose cryptsetup is too old for the detached-header
    # (--header) option: CentOS/RHEL 6.x and SUSE 11.
    if patching.distro_info[0].lower() == "centos" and patching.distro_info[1].startswith('6.'):
        return True
    if patching.distro_info[0].lower() == "redhat" and patching.distro_info[1].startswith('6.'):
        return True
    if patching.distro_info[0].lower() == "suse" and patching.distro_info[1].startswith('11'):
        return True
    return False


def none_or_empty(obj):
    # True when obj is None or the empty string.
    if obj is None or obj == "":
        return True
    else:
        return False


def toggle_se_linux_for_centos7(disable):
    # CentOS 7.0 workaround: temporarily disable SELinux enforcement while
    # manipulating devices. Returns True only when enforcement was actually
    # turned off by this call.
    if DistroPatcher.distro_info[0].lower() == 'centos' and DistroPatcher.distro_info[1].startswith('7.0'):
        if disable:
            se_linux_status = encryption_environment.get_se_linux()
            if se_linux_status.lower() == 'enforcing':
                encryption_environment.disable_se_linux()
                return True
        else:
            encryption_environment.enable_se_linux()
    return False


def mount_encrypted_disks(disk_util, bek_util, passphrase_file, encryption_config):
    # Open and mount every configured encrypted volume using passphrase_file,
    # including the encrypted resource disk for Data/All volume types.

    # mount encrypted resource disk
    volume_type = encryption_config.get_volume_type().lower()
    if volume_type == CommonVariables.VolumeTypeData.lower() or volume_type == CommonVariables.VolumeTypeAll.lower():
        resource_disk_util = ResourceDiskUtil(logger, disk_util, passphrase_file, get_public_settings(), DistroPatcher.distro_info)
        resource_disk_util.automount()
        logger.log("mounted encrypted resource disk")

    # add walkaround for the centos 7.0
    se_linux_status = None
    if DistroPatcher.distro_info[0].lower() == 'centos' and DistroPatcher.distro_info[1].startswith('7.0'):
        se_linux_status = encryption_environment.get_se_linux()
        if se_linux_status.lower() == 'enforcing':
            encryption_environment.disable_se_linux()

    # mount any data disks - make sure the azure disk config path exists.
    for crypt_item in disk_util.get_crypt_items():
        if not crypt_item:
            continue

        # Only luksOpen devices that are not already mapped.
        if not os.path.exists(os.path.join(CommonVariables.dev_mapper_root, crypt_item.mapper_name)):
            luks_open_result = disk_util.luks_open(passphrase_file=passphrase_file,
                                                   dev_path=crypt_item.dev_path,
                                                   mapper_name=crypt_item.mapper_name,
                                                   header_file=crypt_item.luks_header_path,
                                                   uses_cleartext_key=crypt_item.uses_cleartext_key)
            logger.log("luks open result is {0}".format(luks_open_result))

        disk_util.mount_crypt_item(crypt_item, passphrase_file)

    # Restore SELinux enforcement if the CentOS 7.0 workaround disabled it above.
    if DistroPatcher.distro_info[0].lower() == 'centos' and DistroPatcher.distro_info[1].startswith('7.0'):
        if se_linux_status is not None and se_linux_status.lower() == 'enforcing':
            encryption_environment.enable_se_linux()


def main():
    # Extension entry point: initialize the shared globals used by every
    # handler operation, then dispatch on the command-line argument.
    global hutil, DistroPatcher, logger, encryption_environment

    HandlerUtil.LoggerInit('/var/log/waagent.log', '/dev/stdout')
    HandlerUtil.waagent.Log("{0} started to handle.".format(CommonVariables.extension_name))

    hutil = HandlerUtil.HandlerUtility(HandlerUtil.waagent.Log, HandlerUtil.waagent.Error, CommonVariables.extension_name)
    logger = BackupLogger(hutil)
    DistroPatcher = GetDistroPatcher(logger)
    hutil.patching = DistroPatcher

    encryption_environment = EncryptionEnvironment(patching=DistroPatcher, logger=logger)

    disk_util = DiskUtil(hutil=hutil, patching=DistroPatcher, logger=logger, encryption_environment=encryption_environment)
    hutil.disk_util = disk_util

    if DistroPatcher is None:
        hutil.do_exit(exit_code=CommonVariables.os_not_supported,
                      operation='Enable',
                      status=CommonVariables.extension_error_status,
                      code=(CommonVariables.os_not_supported),
                      message='Enable failed: the os is not supported')

    for a in sys.argv[1:]:
        # Arguments may be prefixed with '-' or '/' (e.g. -enable, /enable).
        if re.match("^([-/]*)(disable)", a):
            disable()
        elif re.match("^([-/]*)(uninstall)", a):
            uninstall()
        elif re.match("^([-/]*)(install)", a):
            install()
        elif re.match("^([-/]*)(enable)", a):
            enable()
        elif re.match("^([-/]*)(update)", a):
            update()
        elif re.match("^([-/]*)(daemon)", a):
            daemon()


def mark_encryption(command, volume_type, disk_format_query):
    # Persist an encryption marker so the background daemon (and post-reboot
    # runs) know which encryption operation to carry out.
    encryption_marker = EncryptionMarkConfig(logger, encryption_environment)
    encryption_marker.command = command
    encryption_marker.volume_type = volume_type
    encryption_marker.diskFormatQuery = disk_format_query
    encryption_marker.commit()
    return encryption_marker


def is_daemon_running():
    # Heuristic: scan `ps aux` for a process running this handler script with
    # the -daemon argument.
    handler_path = os.path.join(os.getcwd(), __file__)
    daemon_arg = "-daemon"

    psproc = subprocess.Popen(['ps', 'aux'], stdout=subprocess.PIPE)
    pslist, _ = psproc.communicate()

    for line in pslist.split("\n"):
        if handler_path in line and daemon_arg in line:
            return True

    return False


def enable():
    # Handler 'enable' entry point: run prechecks, then dispatch on the
    # EncryptionOperation public setting (enable/disable/query encryption).
    while True:
        hutil.do_parse_context('Enable')
        logger.log('Enabling extension')

        public_settings = get_public_settings()
        logger.log('Public settings:\n{0}'.format(json.dumps(public_settings, sort_keys=True, indent=4)))

        cutil = CheckUtil(logger)

        # Mount already encrypted disks before running fatal prechecks
        disk_util = DiskUtil(hutil=hutil, patching=DistroPatcher, logger=logger, encryption_environment=encryption_environment)
        bek_util = BekUtil(disk_util, logger)
        existing_passphrase_file = None
        encryption_config = EncryptionConfig(encryption_environment=encryption_environment, logger=logger)

        existing_passphrase_file = bek_util.get_bek_passphrase_file(encryption_config)
        if existing_passphrase_file is not None:
            mount_encrypted_disks(disk_util=disk_util,
                                  bek_util=bek_util,
                                  encryption_config=encryption_config,
                                  passphrase_file=existing_passphrase_file)

            # Migrate to early unlock if using crypt mount
            if disk_util.should_use_azure_crypt_mount():
                disk_util.migrate_crypt_items(existing_passphrase_file)

        encryption_status = json.loads(disk_util.get_encryption_status())

        # run fatal prechecks, report error if exceptions are caught
        try:
            cutil.precheck_for_fatal_failures(public_settings, encryption_status, DistroPatcher)
        except Exception as e:
            logger.log("PRECHECK: Fatal Exception thrown during precheck")
            logger.log(traceback.format_exc())
            msg = e.message
            hutil.do_exit(exit_code=CommonVariables.configuration_error,
                          operation='Enable',
                          status=CommonVariables.extension_error_status,
                          code=(CommonVariables.configuration_error),
                          message=msg)

        hutil.disk_util.log_lsblk_output()

        # run prechecks and log any failures detected
        try:
            if cutil.is_non_fatal_precheck_failure():
                logger.log("PRECHECK: Precheck failure, incompatible environment suspected")
            else:
                logger.log("PRECHECK: Prechecks successful")
        except Exception:
            # Non-fatal prechecks never block enable; failures are only logged.
            logger.log("PRECHECK: Exception thrown during precheck")
            logger.log(traceback.format_exc())

        encryption_operation = public_settings.get(CommonVariables.EncryptionEncryptionOperationKey)

        if encryption_operation in [CommonVariables.EnableEncryption, CommonVariables.EnableEncryptionFormat, CommonVariables.EnableEncryptionFormatAll]:
            logger.log("handle.py found enable encryption operation")

            extension_parameter = ExtensionParameter(hutil, logger, DistroPatcher, encryption_environment, get_protected_settings(), public_settings)

            # A pending BEK backup or changed settings mean a key rotation is in flight.
            if os.path.exists(encryption_environment.bek_backup_path) or (extension_parameter.config_file_exists() and extension_parameter.config_changed()):
                logger.log("Config has changed, updating encryption settings")
                update_encryption_settings()
                extension_parameter.commit()
            else:
                logger.log("Config did not change or first call, enabling encryption")
                enable_encryption()
        elif encryption_operation == CommonVariables.DisableEncryption:
            logger.log("handle.py found disable encryption operation")
            disable_encryption()
        elif encryption_operation == CommonVariables.QueryEncryptionStatus:
            logger.log("handle.py found query operation")

            encryption_marker = EncryptionMarkConfig(logger, encryption_environment)
            if is_daemon_running() or (encryption_marker and not encryption_marker.config_file_exists()):
                logger.log("A daemon is already running or no operation in progress, exiting without status report")
                hutil.redo_last_status()
                exit_without_status_report()
            else:
                logger.log("No daemon found, trying to find the last non-query operation")
                hutil.find_last_nonquery_operation = True
        else:
            msg = "Encryption operation {0} is not supported".format(encryption_operation)
            logger.log(msg)
            hutil.do_exit(exit_code=CommonVariables.configuration_error,
                          operation='Enable',
                          status=CommonVariables.extension_error_status,
                          code=(CommonVariables.configuration_error),
                          message=msg)


def enable_encryption():
    # Start (or resume) an enable-encryption operation: mount already-encrypted
    # volumes, create the Key Vault secret on a fresh call, write the
    # encryption marker, and hand long-running work to the daemon.
    hutil.do_parse_context('EnableEncryption')
    # we need to start another subprocess to do it, because the initial process
    # would be killed by the wala in 5 minutes.
    logger.log('Enabling encryption')

    """
    trying to mount the crypted items.
    """
    disk_util = DiskUtil(hutil=hutil, patching=DistroPatcher, logger=logger, encryption_environment=encryption_environment)
    bek_util = BekUtil(disk_util, logger)

    existing_passphrase_file = None
    encryption_config = EncryptionConfig(encryption_environment=encryption_environment, logger=logger)
    config_path_result = disk_util.make_sure_path_exists(encryption_environment.encryption_config_path)

    if config_path_result != CommonVariables.process_success:
        logger.log(msg="azure encryption path creation failed.", level=CommonVariables.ErrorLevel)

    if encryption_config.config_file_exists():
        existing_passphrase_file = bek_util.get_bek_passphrase_file(encryption_config)

        if existing_passphrase_file is not None:
            mount_encrypted_disks(disk_util=disk_util,
                                  bek_util=bek_util,
                                  encryption_config=encryption_config,
                                  passphrase_file=existing_passphrase_file)
        else:
            logger.log(msg="EncryptionConfig is present, but could not get the BEK file.", level=CommonVariables.WarningLevel)
            hutil.redo_last_status()
            exit_without_status_report()

    # Detect an in-flight OS-disk copy (dd onto the osencrypt mapper) and bail out.
    ps = subprocess.Popen(["ps", "aux"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    ps_stdout, ps_stderr = ps.communicate()
    if re.search(r"dd.*of=/dev/mapper/osencrypt", ps_stdout):
        logger.log(msg="OS disk encryption already in progress, exiting", level=CommonVariables.WarningLevel)
        exit_without_status_report()

    # handle the re-call scenario. the re-call would resume?
    # if there's one tag for the next reboot.
    encryption_marker = EncryptionMarkConfig(logger, encryption_environment)

    try:
        protected_settings_str = hutil._context._config['runtimeSettings'][0]['handlerSettings'].get('protectedSettings')
        public_settings_str = hutil._context._config['runtimeSettings'][0]['handlerSettings'].get('publicSettings')

        # Settings may arrive as JSON strings or as parsed objects.
        if isinstance(public_settings_str, basestring):
            public_settings = json.loads(public_settings_str)
        else:
            public_settings = public_settings_str

        if isinstance(protected_settings_str, basestring):
            protected_settings = json.loads(protected_settings_str)
        else:
            protected_settings = protected_settings_str

        extension_parameter = ExtensionParameter(hutil, logger, DistroPatcher, encryption_environment, protected_settings, public_settings)

        kek_secret_id_created = None

        encryption_marker = EncryptionMarkConfig(logger, encryption_environment)
        if encryption_marker.config_file_exists():
            # verify the encryption mark
            logger.log(msg="encryption mark is there, starting daemon.", level=CommonVariables.InfoLevel)
            start_daemon('EnableEncryption')
        else:
            encryption_config = EncryptionConfig(encryption_environment, logger)

            exit_status = None
            if encryption_config.config_file_exists():
                exit_status = {
                    'operation': 'EnableEncryption',
                    'status': CommonVariables.extension_success_status,
                    'status_code': str(CommonVariables.success),
                    'message': encryption_config.get_secret_id()
                }

            hutil.exit_if_same_seq(exit_status)
            hutil.save_seq()

            encryption_config.volume_type = extension_parameter.VolumeType
            encryption_config.commit()

            if encryption_config.config_file_exists() and existing_passphrase_file is not None:
                # Re-call with existing secret and BEK: just mark and start the daemon.
                logger.log(msg="config file exists and passphrase file exists.", level=CommonVariables.WarningLevel)
                encryption_marker = mark_encryption(command=extension_parameter.command,
                                                    volume_type=extension_parameter.VolumeType,
                                                    disk_format_query=extension_parameter.DiskFormatQuery)
                start_daemon('EnableEncryption')
            else:
                """
                creating the secret, the secret would be transferred to a bek volume after the updatevm called in powershell.
                """
                # store the luks passphrase in the secret.
                keyVaultUtil = KeyVaultUtil(logger)

                """
                validate the parameters
                """
                if(extension_parameter.VolumeType is None or not any([extension_parameter.VolumeType.lower() == vt.lower() for vt in CommonVariables.SupportedVolumeTypes])):
                    if encryption_config.config_file_exists():
                        existing_passphrase_file = bek_util.get_bek_passphrase_file(encryption_config)

                        if existing_passphrase_file is None:
                            logger.log("Unsupported volume type specified and BEK volume does not exist, clearing encryption config")
                            encryption_config.clear_config()

                    hutil.do_exit(exit_code=CommonVariables.configuration_error,
                                  operation='EnableEncryption',
                                  status=CommonVariables.extension_error_status,
                                  code=str(CommonVariables.configuration_error),
                                  message='VolumeType "{0}" is not supported'.format(extension_parameter.VolumeType))

                if extension_parameter.command not in [CommonVariables.EnableEncryption, CommonVariables.EnableEncryptionFormat, CommonVariables.EnableEncryptionFormatAll]:
                    hutil.do_exit(exit_code=CommonVariables.configuration_error,
                                  operation='EnableEncryption',
                                  status=CommonVariables.extension_error_status,
                                  code=str(CommonVariables.configuration_error),
                                  message='Command "{0}" is not supported'.format(extension_parameter.command))

                """
                this is the fresh call case
                """
                # handle the passphrase related
                if existing_passphrase_file is None:
                    if extension_parameter.passphrase is None or extension_parameter.passphrase == "":
                        extension_parameter.passphrase = bek_util.generate_passphrase(extension_parameter.KeyEncryptionAlgorithm)
                    else:
                        logger.log(msg="the extension_parameter.passphrase is already defined")

                    # install Python ADAL support if using client certificate authentication
                    if extension_parameter.AADClientCertThumbprint:
                        DistroPatcher.install_adal()

                    kek_secret_id_created = keyVaultUtil.create_kek_secret(Passphrase=extension_parameter.passphrase,
                                                                           KeyVaultURL=extension_parameter.KeyVaultURL,
                                                                           KeyEncryptionKeyURL=extension_parameter.KeyEncryptionKeyURL,
                                                                           AADClientID=extension_parameter.AADClientID,
                                                                           AADClientCertThumbprint=extension_parameter.AADClientCertThumbprint,
                                                                           KeyEncryptionAlgorithm=extension_parameter.KeyEncryptionAlgorithm,
                                                                           AADClientSecret=extension_parameter.AADClientSecret,
                                                                           DiskEncryptionKeyFileName=extension_parameter.DiskEncryptionKeyFileName)

                    if kek_secret_id_created is None:
                        encryption_config.clear_config()
                        hutil.do_exit(exit_code=CommonVariables.create_encryption_secret_failed,
                                      operation='EnableEncryption',
                                      status=CommonVariables.extension_error_status,
                                      code=str(CommonVariables.create_encryption_secret_failed),
                                      message='Enable failed.')
                    else:
                        encryption_config.passphrase_file_name = extension_parameter.DiskEncryptionKeyFileName
                        encryption_config.volume_type = extension_parameter.VolumeType
                        encryption_config.secret_id = kek_secret_id_created
                        encryption_config.secret_seq_num = hutil.get_current_seq()
                        encryption_config.commit()

                extension_parameter.commit()

                encryption_marker = mark_encryption(command=extension_parameter.command,
                                                    volume_type=extension_parameter.VolumeType,
                                                    disk_format_query=extension_parameter.DiskFormatQuery)

                if kek_secret_id_created:
                    hutil.do_exit(exit_code=0,
                                  operation='EnableEncryption',
                                  status=CommonVariables.extension_success_status,
                                  code=str(CommonVariables.success),
                                  message=str(kek_secret_id_created))
                else:
                    """
                    the enabling called again. the passphrase would be re-used.
                    """
                    hutil.do_exit(exit_code=0,
                                  operation='EnableEncryption',
                                  status=CommonVariables.extension_success_status,
                                  code=str(CommonVariables.encrypttion_already_enabled),
                                  message=str(kek_secret_id_created))
    except Exception as e:
        message = "Failed to enable the extension with error: {0}, stack trace: {1}".format(e, traceback.format_exc())
        logger.log(msg=message, level=CommonVariables.ErrorLevel)
        hutil.do_exit(exit_code=CommonVariables.unknown_error,
                      operation='EnableEncryption',
                      status=CommonVariables.extension_error_status,
                      code=str(CommonVariables.unknown_error),
                      message=message)


def enable_encryption_format(passphrase, disk_format_query, disk_util, force=False):
    # Encrypt-and-format devices described by disk_format_query (a JSON object
    # or list of objects with scsi/dev_path and optional file_system, name,
    # full_mount_point fields). Only devices with no existing filesystem are
    # touched unless force=True.
    logger.log('enable_encryption_format')
    logger.log('disk format query is {0}'.format(disk_format_query))

    json_parsed = json.loads(disk_format_query)

    if type(json_parsed) is dict:
        encryption_format_items = [json_parsed, ]
    elif type(json_parsed) is list:
        encryption_format_items = json_parsed
    else:
        raise Exception("JSON parse error. Input: {0}".format(disk_format_query))

    for encryption_item in encryption_format_items:
        dev_path_in_query = None

        # 'dev_path', when present, takes precedence over 'scsi'.
        if "scsi" in encryption_item and encryption_item["scsi"] != '':
            dev_path_in_query = disk_util.query_dev_sdx_path_by_scsi_id(encryption_item["scsi"])
        if "dev_path" in encryption_item and encryption_item["dev_path"] != '':
            dev_path_in_query = encryption_item["dev_path"]

        if not dev_path_in_query:
            raise Exception("Could not find a device path for Encryption Item: {0}".format(json.dumps(encryption_item)))

        devices = disk_util.get_device_items(dev_path_in_query)
        if len(devices) != 1:
            logger.log(msg=("the device with this path {0} have more than one sub device. so skip it.".format(dev_path_in_query)), level=CommonVariables.WarningLevel)
            continue
        else:
            device_item = devices[0]
            if device_item.file_system is None or device_item.file_system == "" or force:
                if device_item.mount_point:
                    disk_util.swapoff()
                    disk_util.umount(device_item.mount_point)
                mapper_name = str(uuid.uuid4())
                logger.log("encrypting " + str(device_item))
                encrypted_device_path = os.path.join(CommonVariables.dev_mapper_root, mapper_name)
                try:
                    # CentOS 7.0 workaround: lift SELinux enforcement around encrypt_disk.
                    se_linux_status = None
                    if DistroPatcher.distro_info[0].lower() == 'centos' and DistroPatcher.distro_info[1].startswith('7.0'):
                        se_linux_status = encryption_environment.get_se_linux()
                        if se_linux_status.lower() == 'enforcing':
                            encryption_environment.disable_se_linux()
                    encrypt_result = disk_util.encrypt_disk(dev_path=dev_path_in_query, passphrase_file=passphrase, mapper_name=mapper_name, header_file=None)
                finally:
                    if DistroPatcher.distro_info[0].lower() == 'centos' and DistroPatcher.distro_info[1].startswith('7.0'):
                        if se_linux_status is not None and se_linux_status.lower() == 'enforcing':
                            encryption_environment.enable_se_linux()
                if encrypt_result == CommonVariables.process_success:
                    # TODO: let customer specify the default file system in the
                    # parameter
                    file_system = None
                    if "file_system" in encryption_item and encryption_item["file_system"] != "":
                        file_system = encryption_item["file_system"]
                    else:
                        file_system = CommonVariables.default_file_system
                    format_disk_result = disk_util.format_disk(dev_path=encrypted_device_path, file_system=file_system)
                    if format_disk_result != CommonVariables.process_success:
                        logger.log(msg=("format of disk {0} failed with result: {1}".format(encrypted_device_path, format_disk_result)), level=CommonVariables.ErrorLevel)
                    crypt_item_to_update = CryptItem()
                    crypt_item_to_update.mapper_name = mapper_name
                    crypt_item_to_update.dev_path = dev_path_in_query
                    crypt_item_to_update.luks_header_path = None
                    crypt_item_to_update.file_system = file_system
                    crypt_item_to_update.uses_cleartext_key = False
crypt_item_to_update.current_luks_slot = 0 if "name" in encryption_item and encryption_item["name"] != "": crypt_item_to_update.mount_point = os.path.join("/mnt/", str(encryption_item["name"])) else: crypt_item_to_update.mount_point = os.path.join("/mnt/", mapper_name) # allow override through the new full_mount_point field if "full_mount_point" in encryption_item and encryption_item["full_mount_point"] != "": crypt_item_to_update.mount_point = os.path.join(str(encryption_item["full_mount_point"])) logger.log(msg="modifying/removing the entry for unencrypted drive in fstab", level=CommonVariables.InfoLevel) disk_util.modify_fstab_entry_encrypt(crypt_item_to_update.mount_point, os.path.join(CommonVariables.dev_mapper_root, mapper_name)) disk_util.make_sure_path_exists(crypt_item_to_update.mount_point) update_crypt_item_result = disk_util.add_crypt_item(crypt_item_to_update, passphrase) if not update_crypt_item_result: logger.log(msg="update crypt item failed", level=CommonVariables.ErrorLevel) mount_result = disk_util.mount_filesystem(dev_path=encrypted_device_path, mount_point=crypt_item_to_update.mount_point) logger.log(msg=("mount result is {0}".format(mount_result))) else: logger.log(msg="encryption failed with code {0}".format(encrypt_result), level=CommonVariables.ErrorLevel) else: logger.log(msg=("the item fstype is not empty {0}".format(device_item.file_system))) def encrypt_inplace_without_seperate_header_file(passphrase_file, device_item, disk_util, bek_util, status_prefix='', ongoing_item_config=None): """ if ongoing_item_config is not None, then this is a resume case. 
    this function will return the phase
    """
    logger.log("encrypt_inplace_without_seperate_header_file")
    current_phase = CommonVariables.EncryptionPhaseBackupHeader
    if ongoing_item_config is None:
        # Fresh run: persist a new progress record so the operation can be
        # resumed across reboots/daemon restarts.
        ongoing_item_config = OnGoingItemConfig(encryption_environment=encryption_environment, logger=logger)
        ongoing_item_config.current_block_size = CommonVariables.default_block_size
        ongoing_item_config.current_slice_index = 0
        ongoing_item_config.device_size = device_item.size
        ongoing_item_config.file_system = device_item.file_system
        ongoing_item_config.luks_header_file_path = None
        ongoing_item_config.mapper_name = str(uuid.uuid4())
        ongoing_item_config.mount_point = device_item.mount_point
        if os.path.exists(os.path.join('/dev/', device_item.name)):
            ongoing_item_config.original_dev_name_path = os.path.join('/dev/', device_item.name)
            ongoing_item_config.original_dev_path = os.path.join('/dev/', device_item.name)
        else:
            ongoing_item_config.original_dev_name_path = os.path.join('/dev/mapper/', device_item.name)
            ongoing_item_config.original_dev_path = os.path.join('/dev/mapper/', device_item.name)
        ongoing_item_config.phase = CommonVariables.EncryptionPhaseBackupHeader
        ongoing_item_config.commit()
    else:
        logger.log(msg="ongoing item config is not none, this is resuming, info: {0}".format(ongoing_item_config),
                   level=CommonVariables.WarningLevel)

    logger.log(msg=("encrypting device item: {0}".format(ongoing_item_config.get_original_dev_path())))
    # we only support ext file systems.
    current_phase = ongoing_item_config.get_phase()

    original_dev_path = ongoing_item_config.get_original_dev_path()
    mapper_name = ongoing_item_config.get_mapper_name()
    device_size = ongoing_item_config.get_device_size()

    luks_header_size = CommonVariables.luks_header_size
    # NOTE(review): on Python 3 this is float division — presumably an integer
    # sector count is intended; confirm how check_shrink_fs consumes it.
    size_shrink_to = (device_size - luks_header_size) / CommonVariables.sector_size

    # State machine: BackupHeader -> EncryptDevice -> CopyData -> RecoverHeader -> Done.
    while current_phase != CommonVariables.EncryptionPhaseDone:
        if current_phase == CommonVariables.EncryptionPhaseBackupHeader:
            logger.log(msg="the current phase is " + str(CommonVariables.EncryptionPhaseBackupHeader),
                       level=CommonVariables.InfoLevel)

            # log an appropriate warning if the file system type is not supported
            device_fs = ongoing_item_config.get_file_system().lower()
            if not device_fs in CommonVariables.inplace_supported_file_systems:
                if device_fs in CommonVariables.format_supported_file_systems:
                    msg = "Encrypting {0} file system is not supported for data-preserving encryption. Consider using the encrypt-format-all option.".format(device_fs)
                else:
                    msg = "AzureDiskEncryption does not support the {0} file system".format(device_fs)
                logger.log(msg=msg, level=CommonVariables.WarningLevel)
                ongoing_item_config.clear_config()
                return current_phase

            chk_shrink_result = disk_util.check_shrink_fs(dev_path=original_dev_path, size_shrink_to=size_shrink_to)
            if chk_shrink_result != CommonVariables.process_success:
                logger.log(msg="check shrink fs failed with code {0} for {1}".format(chk_shrink_result, original_dev_path),
                           level=CommonVariables.ErrorLevel)
                logger.log(msg="your file system may not have enough space to do the encryption.")
                # remove the ongoing item.
                ongoing_item_config.clear_config()
                return current_phase
            else:
                # Back up the first block of the device (where the LUKS header
                # will later be written) to a slice file.
                ongoing_item_config.current_slice_index = 0
                ongoing_item_config.current_source_path = original_dev_path
                ongoing_item_config.current_destination = encryption_environment.copy_header_slice_file_path
                ongoing_item_config.current_total_copy_size = CommonVariables.default_block_size
                ongoing_item_config.from_end = False
                ongoing_item_config.header_slice_file_path = encryption_environment.copy_header_slice_file_path
                ongoing_item_config.original_dev_path = original_dev_path
                ongoing_item_config.commit()
                if os.path.exists(encryption_environment.copy_header_slice_file_path):
                    logger.log(msg="the header slice file is there, remove it.", level=CommonVariables.WarningLevel)
                    os.remove(encryption_environment.copy_header_slice_file_path)

                copy_result = disk_util.copy(ongoing_item_config=ongoing_item_config, status_prefix=status_prefix)

                if copy_result != CommonVariables.process_success:
                    logger.log(msg="copy the header block failed, return code is: {0}".format(copy_result),
                               level=CommonVariables.ErrorLevel)
                    return current_phase
                else:
                    ongoing_item_config.current_slice_index = 0
                    ongoing_item_config.phase = CommonVariables.EncryptionPhaseEncryptDevice
                    ongoing_item_config.commit()
                    current_phase = CommonVariables.EncryptionPhaseEncryptDevice

        elif current_phase == CommonVariables.EncryptionPhaseEncryptDevice:
            logger.log(msg="the current phase is {0}".format(CommonVariables.EncryptionPhaseEncryptDevice),
                       level=CommonVariables.InfoLevel)

            encrypt_result = disk_util.encrypt_disk(dev_path=original_dev_path,
                                                    passphrase_file=passphrase_file,
                                                    mapper_name=mapper_name,
                                                    header_file=None)

            # after the encrypt_disk without seperate header, then the uuid would change.
            if encrypt_result != CommonVariables.process_success:
                logger.log(msg="encrypt file system failed.", level=CommonVariables.ErrorLevel)
                return current_phase
            else:
                ongoing_item_config.current_slice_index = 0
                ongoing_item_config.phase = CommonVariables.EncryptionPhaseCopyData
                ongoing_item_config.commit()
                current_phase = CommonVariables.EncryptionPhaseCopyData

        elif current_phase == CommonVariables.EncryptionPhaseCopyData:
            logger.log(msg="the current phase is {0}".format(CommonVariables.EncryptionPhaseCopyData),
                       level=CommonVariables.InfoLevel)
            device_mapper_path = os.path.join(CommonVariables.dev_mapper_root, mapper_name)
            # Copy the filesystem data (minus the header-sized hole) from the
            # raw device into the opened crypt mapper, starting from the end.
            ongoing_item_config.current_destination = device_mapper_path
            ongoing_item_config.current_source_path = original_dev_path
            ongoing_item_config.current_total_copy_size = (device_size - luks_header_size)
            ongoing_item_config.from_end = True
            ongoing_item_config.phase = CommonVariables.EncryptionPhaseCopyData
            ongoing_item_config.commit()

            copy_result = disk_util.copy(ongoing_item_config=ongoing_item_config, status_prefix=status_prefix)
            if copy_result != CommonVariables.process_success:
                logger.log(msg="copy the main content block failed, return code is: {0}".format(copy_result),
                           level=CommonVariables.ErrorLevel)
                return current_phase
            else:
                ongoing_item_config.phase = CommonVariables.EncryptionPhaseRecoverHeader
                ongoing_item_config.commit()
                current_phase = CommonVariables.EncryptionPhaseRecoverHeader

        elif current_phase == CommonVariables.EncryptionPhaseRecoverHeader:
            logger.log(msg="the current phase is " + str(CommonVariables.EncryptionPhaseRecoverHeader),
                       level=CommonVariables.InfoLevel)
            # Restore the backed-up first block into the start of the mapper.
            ongoing_item_config.from_end = False
            backed_up_header_slice_file_path = ongoing_item_config.get_header_slice_file_path()
            ongoing_item_config.current_slice_index = 0
            ongoing_item_config.current_source_path = backed_up_header_slice_file_path
            device_mapper_path = os.path.join(CommonVariables.dev_mapper_root, mapper_name)
            ongoing_item_config.current_destination = device_mapper_path
            ongoing_item_config.current_total_copy_size = CommonVariables.default_block_size
            ongoing_item_config.commit()

            copy_result = disk_util.copy(ongoing_item_config=ongoing_item_config, status_prefix=status_prefix)

            if copy_result == CommonVariables.process_success:
                crypt_item_to_update = CryptItem()
                crypt_item_to_update.mapper_name = mapper_name
                original_dev_name_path = ongoing_item_config.get_original_dev_name_path()
                crypt_item_to_update.dev_path = disk_util.get_persistent_path_by_sdx_path(original_dev_name_path)
                crypt_item_to_update.luks_header_path = "None"
                crypt_item_to_update.file_system = ongoing_item_config.get_file_system()
                crypt_item_to_update.uses_cleartext_key = False
                crypt_item_to_update.current_luks_slot = 0

                # if the original mountpoint is empty, then leave it as None
                mount_point = ongoing_item_config.get_mount_point()
                if mount_point == "" or mount_point is None:
                    crypt_item_to_update.mount_point = "None"
                else:
                    crypt_item_to_update.mount_point = mount_point

                update_crypt_item_result = disk_util.add_crypt_item(crypt_item_to_update, passphrase_file)
                if not update_crypt_item_result:
                    logger.log(msg="update crypt item failed", level=CommonVariables.ErrorLevel)

                if mount_point:
                    logger.log(msg="removing entry for unencrypted drive from fstab",
                               level=CommonVariables.InfoLevel)
                    disk_util.modify_fstab_entry_encrypt(mount_point,
                                                         os.path.join(CommonVariables.dev_mapper_root, mapper_name))
                else:
                    logger.log(msg=original_dev_name_path + " is not defined in fstab, no need to update",
                               level=CommonVariables.InfoLevel)

                if os.path.exists(encryption_environment.copy_header_slice_file_path):
                    os.remove(encryption_environment.copy_header_slice_file_path)

                current_phase = CommonVariables.EncryptionPhaseDone
                ongoing_item_config.phase = current_phase
                ongoing_item_config.commit()
                # Grow the filesystem back to fill the (smaller) mapper device.
                expand_fs_result = disk_util.expand_fs(dev_path=device_mapper_path)

                if crypt_item_to_update.mount_point != "None":
                    disk_util.mount_filesystem(device_mapper_path, ongoing_item_config.get_mount_point())
                else:
                    logger.log("the crypt_item_to_update.mount_point is None, so we do not mount it.")

                ongoing_item_config.clear_config()
                if expand_fs_result != CommonVariables.process_success:
                    logger.log(msg="expand fs result is: {0}".format(expand_fs_result),
                               level=CommonVariables.ErrorLevel)
                return current_phase
            else:
                logger.log(msg="recover header failed result is: {0}".format(copy_result),
                           level=CommonVariables.ErrorLevel)
                return current_phase


def encrypt_inplace_with_seperate_header_file(passphrase_file,
                                              device_item,
                                              disk_util,
                                              bek_util,
                                              status_prefix='',
                                              ongoing_item_config=None):
    """
    if ongoing_item_config is not None, then this is a resume case.
    """
    logger.log("encrypt_inplace_with_seperate_header_file")
    current_phase = CommonVariables.EncryptionPhaseEncryptDevice
    if ongoing_item_config is None:
        # Fresh run: persist a resumable progress record (see the function
        # above for the no-header variant).
        ongoing_item_config = OnGoingItemConfig(encryption_environment=encryption_environment, logger=logger)
        mapper_name = str(uuid.uuid4())
        ongoing_item_config.current_block_size = CommonVariables.default_block_size
        ongoing_item_config.current_slice_index = 0
        ongoing_item_config.device_size = device_item.size
        ongoing_item_config.file_system = device_item.file_system
        ongoing_item_config.mapper_name = mapper_name
        ongoing_item_config.mount_point = device_item.mount_point
        # TODO improve this.
        if os.path.exists(os.path.join('/dev/', device_item.name)):
            ongoing_item_config.original_dev_name_path = os.path.join('/dev/', device_item.name)
        else:
            ongoing_item_config.original_dev_name_path = os.path.join('/dev/mapper/', device_item.name)
        # Use the stable by-uuid path as the canonical device reference.
        ongoing_item_config.original_dev_path = os.path.join('/dev/disk/by-uuid', device_item.uuid)

        luks_header_file_path = disk_util.create_luks_header(mapper_name=mapper_name)
        if luks_header_file_path is None:
            logger.log(msg="create header file failed", level=CommonVariables.ErrorLevel)
            return current_phase
        else:
            ongoing_item_config.luks_header_file_path = luks_header_file_path
            ongoing_item_config.phase = CommonVariables.EncryptionPhaseEncryptDevice
            ongoing_item_config.commit()
    else:
        logger.log(msg="ongoing item config is not none, this is resuming: {0}".format(ongoing_item_config),
                   level=CommonVariables.WarningLevel)
        current_phase = ongoing_item_config.get_phase()

    # State machine: EncryptDevice -> CopyData -> Done.  SELinux is toggled
    # off around each phase on CentOS 7.0 (see toggle_se_linux_for_centos7).
    while current_phase != CommonVariables.EncryptionPhaseDone:
        if current_phase == CommonVariables.EncryptionPhaseEncryptDevice:
            try:
                mapper_name = ongoing_item_config.get_mapper_name()
                original_dev_path = ongoing_item_config.get_original_dev_path()
                luks_header_file_path = ongoing_item_config.get_header_file_path()

                toggle_se_linux_for_centos7(True)

                encrypt_result = disk_util.encrypt_disk(dev_path=original_dev_path,
                                                        passphrase_file=passphrase_file,
                                                        mapper_name=mapper_name,
                                                        header_file=luks_header_file_path)

                if encrypt_result != CommonVariables.process_success:
                    logger.log(msg="the encrypton for {0} failed".format(original_dev_path),
                               level=CommonVariables.ErrorLevel)
                    return current_phase
                else:
                    ongoing_item_config.phase = CommonVariables.EncryptionPhaseCopyData
                    ongoing_item_config.commit()
                    current_phase = CommonVariables.EncryptionPhaseCopyData
            finally:
                toggle_se_linux_for_centos7(False)

        elif current_phase == CommonVariables.EncryptionPhaseCopyData:
            try:
                mapper_name = ongoing_item_config.get_mapper_name()
                original_dev_path = ongoing_item_config.get_original_dev_path()
                luks_header_file_path = ongoing_item_config.get_header_file_path()

                toggle_se_linux_for_centos7(True)

                device_mapper_path = os.path.join("/dev/mapper", mapper_name)
                if not os.path.exists(device_mapper_path):
                    open_result = disk_util.luks_open(passphrase_file=passphrase_file,
                                                      dev_path=original_dev_path,
                                                      mapper_name=mapper_name,
                                                      header_file=luks_header_file_path,
                                                      uses_cleartext_key=False)
                    if open_result != CommonVariables.process_success:
                        logger.log(msg="the luks open for {0} failed.".format(original_dev_path),
                                   level=CommonVariables.ErrorLevel)
                        return current_phase
                else:
                    logger.log(msg="the device mapper path existed, so skip the luks open.",
                               level=CommonVariables.InfoLevel)

                device_size = ongoing_item_config.get_device_size()

                current_slice_index = ongoing_item_config.get_current_slice_index()
                if current_slice_index is None:
                    ongoing_item_config.current_slice_index = 0
                ongoing_item_config.current_source_path = original_dev_path
                ongoing_item_config.current_destination = device_mapper_path
                ongoing_item_config.current_total_copy_size = device_size
                ongoing_item_config.from_end = True
                ongoing_item_config.commit()

                copy_result = disk_util.copy(ongoing_item_config=ongoing_item_config, status_prefix=status_prefix)

                # NOTE(review): other copy checks in this file compare against
                # CommonVariables.process_success — confirm success and
                # process_success are equivalent values here.
                if copy_result != CommonVariables.success:
                    error_message = "the copying result is {0} so skip the mounting".format(copy_result)
                    logger.log(msg=(error_message), level=CommonVariables.ErrorLevel)
                    return current_phase
                else:
                    crypt_item_to_update = CryptItem()
                    crypt_item_to_update.mapper_name = mapper_name
                    original_dev_name_path = ongoing_item_config.get_original_dev_name_path()
                    crypt_item_to_update.dev_path = disk_util.get_persistent_path_by_sdx_path(original_dev_name_path)
                    crypt_item_to_update.luks_header_path = luks_header_file_path
                    crypt_item_to_update.file_system = ongoing_item_config.get_file_system()
                    crypt_item_to_update.uses_cleartext_key = False
                    crypt_item_to_update.current_luks_slot = 0

                    # if the original mountpoint is empty, then leave it as None
                    mount_point = ongoing_item_config.get_mount_point()
                    if mount_point is None or mount_point == "":
                        crypt_item_to_update.mount_point = "None"
                    else:
                        crypt_item_to_update.mount_point = mount_point

                    update_crypt_item_result = disk_util.add_crypt_item(crypt_item_to_update, passphrase_file)
                    if not update_crypt_item_result:
                        logger.log(msg="update crypt item failed", level=CommonVariables.ErrorLevel)

                    if crypt_item_to_update.mount_point != "None":
                        disk_util.mount_filesystem(device_mapper_path, mount_point)
                    else:
                        logger.log("the crypt_item_to_update.mount_point is None, so we do not mount it.")

                    if mount_point:
                        logger.log(msg="removing entry for unencrypted drive from fstab",
                                   level=CommonVariables.InfoLevel)
                        disk_util.modify_fstab_entry_encrypt(mount_point,
                                                             os.path.join(CommonVariables.dev_mapper_root, mapper_name))
                    else:
                        logger.log(msg=original_dev_name_path + " is not defined in fstab, no need to update",
                                   level=CommonVariables.InfoLevel)

                    current_phase = CommonVariables.EncryptionPhaseDone
                    ongoing_item_config.phase = current_phase
                    ongoing_item_config.commit()
                    ongoing_item_config.clear_config()
                    return current_phase
            finally:
                toggle_se_linux_for_centos7(False)


def decrypt_inplace_copy_data(passphrase_file,
                              crypt_item,
                              raw_device_item,
                              mapper_device_item,
                              disk_util,
                              status_prefix='',
                              ongoing_item_config=None):
    """
    Copy data back from the open crypt mapper onto the raw device, phase by
    phase, persisting progress so the operation can be resumed.
    Returns the phase reached (DecryptionPhaseDone on success).
    """
    logger.log(msg="decrypt_inplace_copy_data")

    if ongoing_item_config:
        logger.log(msg="ongoing item config is not none, resuming decryption, info: {0}".format(ongoing_item_config),
                   level=CommonVariables.WarningLevel)
    else:
        logger.log(msg="starting decryption of {0}".format(crypt_item))
        ongoing_item_config = OnGoingItemConfig(encryption_environment=encryption_environment, logger=logger)
        ongoing_item_config.current_destination = crypt_item.dev_path
        ongoing_item_config.current_source_path = os.path.join(CommonVariables.dev_mapper_root, crypt_item.mapper_name)
        ongoing_item_config.current_total_copy_size = mapper_device_item.size
        ongoing_item_config.from_end = True
        ongoing_item_config.phase = \
CommonVariables.DecryptionPhaseCopyData
        ongoing_item_config.current_slice_index = 0
        ongoing_item_config.current_block_size = CommonVariables.default_block_size
        ongoing_item_config.mount_point = crypt_item.mount_point
        ongoing_item_config.commit()

    current_phase = ongoing_item_config.get_phase()

    while current_phase != CommonVariables.DecryptionPhaseDone:
        # NOTE(review): this logs the encryption backup-header phase constant,
        # not the decryption phase — looks like a copy/paste slip in the log
        # message only.
        logger.log(msg=("the current phase is {0}".format(CommonVariables.EncryptionPhaseBackupHeader)),
                   level=CommonVariables.InfoLevel)

        if current_phase == CommonVariables.DecryptionPhaseCopyData:
            copy_result = disk_util.copy(ongoing_item_config=ongoing_item_config, status_prefix=status_prefix)
            if copy_result == CommonVariables.process_success:
                mount_point = ongoing_item_config.get_mount_point()
                if mount_point and mount_point != "None":
                    logger.log(msg="restoring entry for unencrypted drive from fstab",
                               level=CommonVariables.InfoLevel)
                    disk_util.restore_mount_info(ongoing_item_config.get_mount_point())
                elif crypt_item.mapper_name:
                    disk_util.restore_mount_info(crypt_item.mapper_name)
                else:
                    logger.log(msg=crypt_item.dev_path + " was not in fstab when encryption was enabled, no need to restore",
                               level=CommonVariables.InfoLevel)

                ongoing_item_config.phase = CommonVariables.DecryptionPhaseDone
                ongoing_item_config.commit()
                current_phase = CommonVariables.DecryptionPhaseDone
            else:
                logger.log(msg="decryption: block copy failed, result: {0}".format(copy_result),
                           level=CommonVariables.ErrorLevel)
                return current_phase

    ongoing_item_config.clear_config()
    return current_phase


def decrypt_inplace_without_separate_header_file(passphrase_file,
                                                 crypt_item,
                                                 raw_device_item,
                                                 mapper_device_item,
                                                 disk_util,
                                                 status_prefix='',
                                                 ongoing_item_config=None):
    """
    Decrypt a volume whose LUKS header lives on the device itself.
    Sanity-check that the raw/mapper size difference equals the LUKS payload
    offset before delegating to decrypt_inplace_copy_data.
    """
    logger.log(msg="decrypt_inplace_without_separate_header_file")

    proc_comm = ProcessCommunicator()
    executor = CommandExecutor(logger)
    executor.Execute(DistroPatcher.cryptsetup_path + " luksDump " + crypt_item.dev_path,
                     communicator=proc_comm)

    # Payload offset (in sectors) from the luksDump output gives the on-disk
    # header size.
    luks_header_size = int(re.findall(r"Payload.*?(\d+)", proc_comm.stdout)[0]) * CommonVariables.sector_size

    if raw_device_item.size - mapper_device_item.size != luks_header_size:
        logger.log(msg="mismatch between raw and mapper device found for crypt_item {0}".format(crypt_item),
                   level=CommonVariables.ErrorLevel)
        logger.log(msg="raw_device_item: {0}".format(raw_device_item), level=CommonVariables.ErrorLevel)
        logger.log(msg="mapper_device_item {0}".format(mapper_device_item), level=CommonVariables.ErrorLevel)
        return None

    return decrypt_inplace_copy_data(passphrase_file,
                                     crypt_item,
                                     raw_device_item,
                                     mapper_device_item,
                                     disk_util,
                                     status_prefix,
                                     ongoing_item_config)


def decrypt_inplace_with_separate_header_file(passphrase_file,
                                              crypt_item,
                                              raw_device_item,
                                              mapper_device_item,
                                              disk_util,
                                              status_prefix='',
                                              ongoing_item_config=None):
    """
    Decrypt a volume whose LUKS header lives in a detached file; raw and
    mapper devices must then be exactly the same size.
    """
    logger.log(msg="decrypt_inplace_with_separate_header_file")

    if raw_device_item.size != mapper_device_item.size:
        logger.log(msg="mismatch between raw and mapper device found for crypt_item {0}".format(crypt_item),
                   level=CommonVariables.ErrorLevel)
        logger.log(msg="raw_device_item: {0}".format(raw_device_item), level=CommonVariables.ErrorLevel)
        logger.log(msg="mapper_device_item {0}".format(mapper_device_item), level=CommonVariables.ErrorLevel)
        return

    return decrypt_inplace_copy_data(passphrase_file,
                                     crypt_item,
                                     raw_device_item,
                                     mapper_device_item,
                                     disk_util,
                                     status_prefix,
                                     ongoing_item_config)


def enable_encryption_all_format(passphrase_file, encryption_marker, disk_util, bek_util):
    """
    In case of success return None, otherwise return the device item which failed.
""" logger.log(msg="executing the enable_encryption_all_format command") device_items = find_all_devices_to_encrypt(encryption_marker, disk_util, bek_util) # Don't encrypt partitions that are not even mounted device_items_to_encrypt = filter(lambda di: di.mount_point is not None and di.mount_point != "", device_items) dev_path_reference_table = disk_util.get_block_device_to_azure_udev_table() device_items_to_encrypt = filter(lambda di: os.path.join('/dev/', di.name) in dev_path_reference_table, device_items_to_encrypt) msg = 'Encrypting and formatting {0} data volumes'.format(len(device_items_to_encrypt)) logger.log(msg) hutil.do_status_report(operation='EnableEncryptionFormatAll', status=CommonVariables.extension_success_status, status_code=str(CommonVariables.success), message=msg) return encrypt_format_device_items(passphrase_file, device_items_to_encrypt, disk_util, True) def encrypt_format_device_items(passphrase, device_items, disk_util, force=False): """ Formats the block devices represented by the supplied device_item. This is done by constructing a disk format query based on the supplied device items and passing it on to the enable_encryption_format method. Returns None if all items are successfully format-encrypted Otherwise returns the device item which failed. 
""" # use the new udev names for formatting and later on for cryptmounting dev_path_reference_table = disk_util.get_block_device_to_azure_udev_table() def single_device_item_to_format_query_dict(device_item): """ Converts a single device_item into an dictionary than will be later "json-stringified" """ format_query_element = {} dev_path = os.path.join('/dev/', device_item.name) if dev_path in dev_path_reference_table: format_query_element["dev_path"] = dev_path_reference_table[dev_path] else: format_query_element["dev_path"] = dev_path # introduce a new "full_mount_point" field below to avoid the /mnt/ prefix that automatically gets appended format_query_element["full_mount_point"] = str(device_item.mount_point) format_query_element["file_system"] = str(device_item.file_system) return format_query_element disk_format_query = json.dumps(map(single_device_item_to_format_query_dict, device_items)) return enable_encryption_format(passphrase, disk_format_query, disk_util, force) def find_all_devices_to_encrypt(encryption_marker, disk_util, bek_util): device_items = disk_util.get_device_items(None) device_items_to_encrypt = [] special_azure_devices_to_skip = disk_util.get_azure_devices() for device_item in device_items: logger.log("device_item == " + str(device_item)) should_skip = disk_util.should_skip_for_inplace_encryption(device_item, special_azure_devices_to_skip, encryption_marker.get_volume_type()) if not should_skip and \ not any(di.name == device_item.name for di in device_items_to_encrypt): device_items_to_encrypt.append(device_item) return device_items_to_encrypt def enable_encryption_all_in_place(passphrase_file, encryption_marker, disk_util, bek_util): """ if return None for the success case, or return the device item which failed. 
""" logger.log(msg="executing the enable_encryption_all_in_place command.") device_items_to_encrypt = find_all_devices_to_encrypt(encryption_marker, disk_util, bek_util) msg = 'Encrypting {0} data volumes'.format(len(device_items_to_encrypt)) logger.log(msg) hutil.do_status_report(operation='EnableEncryption', status=CommonVariables.extension_success_status, status_code=str(CommonVariables.success), message=msg) for device_num, device_item in enumerate(device_items_to_encrypt): umount_status_code = CommonVariables.success if device_item.mount_point is not None and device_item.mount_point != "": umount_status_code = disk_util.umount(device_item.mount_point) if umount_status_code != CommonVariables.success: logger.log("error occured when do the umount for: {0} with code: {1}".format(device_item.mount_point, umount_status_code)) else: logger.log(msg=("encrypting: {0}".format(device_item))) no_header_file_support = not_support_header_option_distro(DistroPatcher) status_prefix = "Encrypting data volume {0}/{1}".format(device_num + 1, len(device_items_to_encrypt)) # TODO check the file system before encrypting it. if no_header_file_support: logger.log(msg="this is the centos 6 or redhat 6 or sles 11 series, need to resize data drive", level=CommonVariables.WarningLevel) encryption_result_phase = encrypt_inplace_without_seperate_header_file(passphrase_file=passphrase_file, device_item=device_item, disk_util=disk_util, bek_util=bek_util, status_prefix=status_prefix) else: encryption_result_phase = encrypt_inplace_with_seperate_header_file(passphrase_file=passphrase_file, device_item=device_item, disk_util=disk_util, bek_util=bek_util, status_prefix=status_prefix) if encryption_result_phase == CommonVariables.EncryptionPhaseDone: continue else: # do exit to exit from this round return device_item return None def disable_encryption_all_in_place(passphrase_file, decryption_marker, disk_util): """ On success, returns None. 
Otherwise returns the crypt item for which decryption failed. """ logger.log(msg="executing disable_encryption_all_in_place") device_items = disk_util.get_device_items(None) crypt_items = disk_util.get_crypt_items() msg = 'Decrypting {0} data volumes'.format(len(crypt_items)) logger.log(msg) hutil.do_status_report(operation='DisableEncryption', status=CommonVariables.extension_success_status, status_code=str(CommonVariables.success), message=msg) for crypt_item_num, crypt_item in enumerate(crypt_items): logger.log("processing crypt_item: " + str(crypt_item)) def raw_device_item_match(device_item): sdx_device_name = os.path.join("/dev/", device_item.name) if crypt_item.dev_path.startswith(CommonVariables.disk_by_id_root): return crypt_item.dev_path == disk_util.query_dev_id_path_by_sdx_path(sdx_device_name) else: return crypt_item.dev_path == sdx_device_name def mapped_device_item_match(device_item): return crypt_item.mapper_name == device_item.name raw_device_item = next((d for d in device_items if raw_device_item_match(d)), None) mapper_device_item = next((d for d in device_items if mapped_device_item_match(d)), None) if not raw_device_item: logger.log("raw device not found for crypt_item {0}".format(crypt_item), level='Warn') logger.log("Skipping device", level='Warn') continue if not mapper_device_item: logger.log("mapper device not found for crypt_item {0}".format(crypt_item)) if disk_util.is_luks_device(crypt_item.dev_path, crypt_item.luks_header_path): logger.log("Found a luks device for this device item, yet couldn't open mapper: {0}".format(crypt_item)) logger.log("Failing".format(crypt_item)) return crypt_item else: continue decryption_result_phase = None status_prefix = "Decrypting data volume {0}/{1}".format(crypt_item_num + 1, len(crypt_items)) if crypt_item.luks_header_path: decryption_result_phase = decrypt_inplace_with_separate_header_file(passphrase_file=passphrase_file, crypt_item=crypt_item, raw_device_item=raw_device_item, 
mapper_device_item=mapper_device_item, disk_util=disk_util, status_prefix=status_prefix) else: decryption_result_phase = decrypt_inplace_without_separate_header_file(passphrase_file=passphrase_file, crypt_item=crypt_item, raw_device_item=raw_device_item, mapper_device_item=mapper_device_item, disk_util=disk_util, status_prefix=status_prefix) if decryption_result_phase == CommonVariables.DecryptionPhaseDone: disk_util.luks_close(crypt_item.mapper_name) disk_util.remove_crypt_item(crypt_item) #disk_util.mount_all() continue else: # decryption failed for a crypt_item, return the failed item to caller return crypt_item disk_util.mount_all() return None def daemon_encrypt(): # Ensure the same configuration is executed only once # If the previous enable failed, we do not have retry logic here. # TODO Remount all encryption_marker = EncryptionMarkConfig(logger, encryption_environment) if encryption_marker.config_file_exists(): logger.log("encryption is marked.") """ search for the bek volume, then mount it:) """ disk_util = DiskUtil(hutil, DistroPatcher, logger, encryption_environment) encryption_config = EncryptionConfig(encryption_environment, logger) bek_passphrase_file = None """ try to find the attached bek volume, and use the file to mount the crypted volumes, and if the passphrase file is found, then we will re-use it for the future. 
""" bek_util = BekUtil(disk_util, logger) if encryption_config.config_file_exists(): bek_passphrase_file = bek_util.get_bek_passphrase_file(encryption_config) if bek_passphrase_file is None: hutil.do_exit(exit_code=CommonVariables.passphrase_file_not_found, operation='EnableEncryption', status=CommonVariables.extension_error_status, code=CommonVariables.passphrase_file_not_found, message='Passphrase file not found.') executor = CommandExecutor(logger) is_not_in_stripped_os = bool(executor.Execute("mountpoint /oldroot")) volume_type = encryption_config.get_volume_type().lower() if (volume_type == CommonVariables.VolumeTypeData.lower() or volume_type == CommonVariables.VolumeTypeAll.lower()) and \ is_not_in_stripped_os: try: while not daemon_encrypt_data_volumes(encryption_marker=encryption_marker, encryption_config=encryption_config, disk_util=disk_util, bek_util=bek_util, bek_passphrase_file=bek_passphrase_file): logger.log("Calling daemon_encrypt_data_volumes again") except Exception as e: message = "Failed to encrypt data volumes with error: {0}, stack trace: {1}".format(e, traceback.format_exc()) logger.log(msg=message, level=CommonVariables.ErrorLevel) hutil.do_exit(exit_code=CommonVariables.encryption_failed, operation='EnableEncryptionDataVolumes', status=CommonVariables.extension_error_status, code=CommonVariables.encryption_failed, message=message) else: hutil.do_status_report(operation='EnableEncryptionDataVolumes', status=CommonVariables.extension_success_status, status_code=str(CommonVariables.success), message='Encryption succeeded for data volumes') disk_util.log_lsblk_output() mount_encrypted_disks(disk_util, bek_util, bek_passphrase_file, encryption_config) if volume_type == CommonVariables.VolumeTypeOS.lower() or \ volume_type == CommonVariables.VolumeTypeAll.lower(): # import OSEncryption here instead of at the top because it relies # on pre-req packages being installed (specifically, python-six on Ubuntu) distro_name = DistroPatcher.distro_info[0] 
distro_version = DistroPatcher.distro_info[1] os_encryption = None if (((distro_name == 'redhat' and distro_version == '7.3') or (distro_name == 'redhat' and distro_version == '7.4') or (distro_name == 'redhat' and distro_version == '7.5') or (distro_name == 'redhat' and distro_version == '7.6') or (distro_name == 'redhat' and distro_version == '7.7')) and (disk_util.is_os_disk_lvm() or os.path.exists('/volumes.lvm'))): from oscrypto.rhel_72_lvm import RHEL72LVMEncryptionStateMachine os_encryption = RHEL72LVMEncryptionStateMachine(hutil=hutil, distro_patcher=DistroPatcher, logger=logger, encryption_environment=encryption_environment) elif (((distro_name == 'centos' and distro_version == '7.3.1611') or (distro_name == 'centos' and distro_version.startswith('7.4')) or (distro_name == 'centos' and distro_version.startswith('7.5')) or (distro_name == 'centos' and distro_version.startswith('7.6')) or (distro_name == 'centos' and distro_version.startswith('7.7'))) and (disk_util.is_os_disk_lvm() or os.path.exists('/volumes.lvm'))): from oscrypto.rhel_72_lvm import RHEL72LVMEncryptionStateMachine os_encryption = RHEL72LVMEncryptionStateMachine(hutil=hutil, distro_patcher=DistroPatcher, logger=logger, encryption_environment=encryption_environment) elif ((distro_name == 'redhat' and distro_version == '7.2') or (distro_name == 'redhat' and distro_version == '7.3') or (distro_name == 'redhat' and distro_version == '7.4') or (distro_name == 'redhat' and distro_version == '7.5') or (distro_name == 'redhat' and distro_version == '7.6') or (distro_name == 'redhat' and distro_version == '7.7') or (distro_name == 'centos' and distro_version.startswith('7.7')) or (distro_name == 'centos' and distro_version.startswith('7.6')) or (distro_name == 'centos' and distro_version.startswith('7.5')) or (distro_name == 'centos' and distro_version.startswith('7.4')) or (distro_name == 'centos' and distro_version == '7.3.1611') or (distro_name == 'centos' and distro_version == '7.2.1511')): from 
oscrypto.rhel_72 import RHEL72EncryptionStateMachine os_encryption = RHEL72EncryptionStateMachine(hutil=hutil, distro_patcher=DistroPatcher, logger=logger, encryption_environment=encryption_environment) elif distro_name == 'redhat' and distro_version == '6.8': from oscrypto.rhel_68 import RHEL68EncryptionStateMachine os_encryption = RHEL68EncryptionStateMachine(hutil=hutil, distro_patcher=DistroPatcher, logger=logger, encryption_environment=encryption_environment) elif distro_name == 'centos' and (distro_version == '6.8' or distro_version == '6.9'): from oscrypto.centos_68 import CentOS68EncryptionStateMachine os_encryption = CentOS68EncryptionStateMachine(hutil=hutil, distro_patcher=DistroPatcher, logger=logger, encryption_environment=encryption_environment) elif distro_name == 'Ubuntu' and distro_version in ['16.04', '18.04']: from oscrypto.ubuntu_1604 import Ubuntu1604EncryptionStateMachine os_encryption = Ubuntu1604EncryptionStateMachine(hutil=hutil, distro_patcher=DistroPatcher, logger=logger, encryption_environment=encryption_environment) elif distro_name == 'Ubuntu' and distro_version == '14.04': from oscrypto.ubuntu_1404 import Ubuntu1404EncryptionStateMachine os_encryption = Ubuntu1404EncryptionStateMachine(hutil=hutil, distro_patcher=DistroPatcher, logger=logger, encryption_environment=encryption_environment) else: message = "OS volume encryption is not supported on {0} {1}".format(distro_name, distro_version) logger.log(msg=message, level=CommonVariables.ErrorLevel) hutil.do_exit(exit_code=CommonVariables.encryption_failed, operation='EnableEncryptionOSVolume', status=CommonVariables.extension_error_status, code=CommonVariables.encryption_failed, message=message) try: os_encryption.start_encryption() if not os_encryption.state == 'completed': raise Exception("did not reach completed state") else: encryption_marker.clear_config() except Exception as e: message = "Failed to encrypt OS volume with error: {0}, stack trace: {1}, machine state: {2}".format(e, 
            traceback.format_exc(), os_encryption.state)
        logger.log(msg=message, level=CommonVariables.ErrorLevel)
        hutil.do_exit(exit_code=CommonVariables.encryption_failed,
                      operation='EnableEncryptionOSVolume',
                      status=CommonVariables.extension_error_status,
                      code=CommonVariables.encryption_failed,
                      message=message)

    # OS-volume encryption reached the 'completed' state; report success back
    # to the agent.  The message distinguishes "All" from OS-only requests.
    message = ''
    if volume_type == CommonVariables.VolumeTypeAll.lower():
        message = 'Encryption succeeded for all volumes'
    else:
        message = 'Encryption succeeded for OS volume'

    logger.log(msg=message)
    hutil.do_status_report(operation='EnableEncryptionOSVolume',
                           status=CommonVariables.extension_success_status,
                           status_code=str(CommonVariables.success),
                           message=message)


def daemon_encrypt_data_volumes(encryption_marker, encryption_config, disk_util, bek_util, bek_passphrase_file):
    # Encrypt (or resume encrypting after a reboot) the data volumes as
    # instructed by the on-disk encryption marker.  Returns True on success
    # and raises on any failure.  A pre-existing OnGoingItemConfig means a
    # previous in-place encryption was interrupted and must be resumed first.
    try:
        """ check whether there's a scheduled encryption task """
        mount_all_result = disk_util.mount_all()

        if mount_all_result != CommonVariables.process_success:
            # Best-effort: log and continue; a failed mount-all does not by
            # itself abort the encryption pass.
            logger.log(msg="mount all failed with code:{0}".format(mount_all_result),
                       level=CommonVariables.ErrorLevel)
        """
        TODO: resuming the encryption for rebooting suddenly scenario
        we need the special handling is because the half done device can be a error state: say, the file system header missing.so it could be identified.
        """
        ongoing_item_config = OnGoingItemConfig(encryption_environment=encryption_environment,
                                                logger=logger)

        if ongoing_item_config.config_file_exists():
            logger.log("OngoingItemConfig exists.")

            ongoing_item_config.load_value_from_file()

            header_file_path = ongoing_item_config.get_header_file_path()
            mount_point = ongoing_item_config.get_mount_point()
            status_prefix = "Resuming encryption after reboot"

            if not none_or_empty(mount_point):
                # The half-encrypted device must not stay mounted while we
                # resume copying data into the crypt layer.
                logger.log("mount point is not empty {0}, trying to unmount it first.".format(mount_point))
                umount_status_code = disk_util.umount(mount_point)
                logger.log("unmount return code is {0}".format(umount_status_code))

            # A separate LUKS header file is used when the filesystem could
            # not be shrunk to make room for an on-device header.
            if none_or_empty(header_file_path):
                encryption_result_phase = encrypt_inplace_without_seperate_header_file(passphrase_file=bek_passphrase_file,
                                                                                      device_item=None,
                                                                                      disk_util=disk_util,
                                                                                      bek_util=bek_util,
                                                                                      status_prefix=status_prefix,
                                                                                      ongoing_item_config=ongoing_item_config)
                # TODO mount it back when shrink failed
            else:
                encryption_result_phase = encrypt_inplace_with_seperate_header_file(passphrase_file=bek_passphrase_file,
                                                                                    device_item=None,
                                                                                    disk_util=disk_util,
                                                                                    bek_util=bek_util,
                                                                                    status_prefix=status_prefix,
                                                                                    ongoing_item_config=ongoing_item_config)
            """
            if the resuming failed, we should fail.
            """
            if encryption_result_phase != CommonVariables.EncryptionPhaseDone:
                # NOTE(review): get_original_dev_path is referenced without
                # calling it (missing parentheses), so the error message will
                # embed a bound-method repr instead of the device path --
                # confirm and fix separately.
                original_dev_path = ongoing_item_config.get_original_dev_path
                message = 'EnableEncryption: resuming encryption for {0} failed'.format(original_dev_path)
                raise Exception(message)
            else:
                ongoing_item_config.clear_config()
        else:
            logger.log("OngoingItemConfig does not exist")

        failed_item = None

        if not encryption_marker.config_file_exists():
            logger.log("Data volumes are not marked for encryption")
            return True

        # Dispatch on the requested command: in-place encryption, format a
        # specific set of disks, or format-all.
        if encryption_marker.get_current_command() == CommonVariables.EnableEncryption:
            failed_item = enable_encryption_all_in_place(passphrase_file=bek_passphrase_file,
                                                         encryption_marker=encryption_marker,
                                                         disk_util=disk_util,
                                                         bek_util=bek_util)
        elif encryption_marker.get_current_command() == CommonVariables.EnableEncryptionFormat:
            disk_format_query = encryption_marker.get_encryption_disk_format_query()
            failed_item = enable_encryption_format(passphrase=bek_passphrase_file,
                                                   disk_format_query=disk_format_query,
                                                   disk_util=disk_util)
        elif encryption_marker.get_current_command() == CommonVariables.EnableEncryptionFormatAll:
            failed_item = enable_encryption_all_format(passphrase_file=bek_passphrase_file,
                                                       encryption_marker=encryption_marker,
                                                       disk_util=disk_util,
                                                       bek_util=bek_util)
        else:
            message = "Command {0} not supported.".format(encryption_marker.get_current_command())
            logger.log(msg=message, level=CommonVariables.ErrorLevel)
            raise Exception(message)

        if failed_item:
            message = 'Encryption failed for {0}'.format(failed_item)
            raise Exception(message)
        else:
            return True
    except Exception:
        # Re-raise so the daemon's top-level handler reports the failure.
        raise


def daemon_decrypt():
    # Perform decryption if (and only if) a decryption marker exists on disk.
    # On success this clears both the encryption config and the decryption
    # marker and exits the process via hutil.do_exit.
    decryption_marker = DecryptionMarkConfig(logger, encryption_environment)

    if not decryption_marker.config_file_exists():
        logger.log("decryption is not marked.")
        return

    logger.log("decryption is marked.")

    # mount and then unmount all the encrypted items
    # in order to set-up all the mapper devices
    # we don't need the BEK since all the drives that need decryption were
    # made cleartext-key unlockable by first call to disable
    disk_util = DiskUtil(hutil, DistroPatcher, logger, encryption_environment)
    encryption_config = EncryptionConfig(encryption_environment, logger)
    mount_encrypted_disks(disk_util=disk_util,
                          bek_util=None,
                          encryption_config=encryption_config,
                          passphrase_file=None)
    disk_util.umount_all_crypt_items()

    # at this point all the /dev/mapper/* crypt devices should be open

    ongoing_item_config = OnGoingItemConfig(encryption_environment=encryption_environment,
                                            logger=logger)

    if ongoing_item_config.config_file_exists():
        logger.log("ongoing item config exists.")
    else:
        logger.log("ongoing item config does not exist.")

    failed_item = None

    if decryption_marker.get_current_command() == CommonVariables.DisableEncryption:
        failed_item = disable_encryption_all_in_place(passphrase_file=None,
                                                      decryption_marker=decryption_marker,
                                                      disk_util=disk_util)
    else:
        raise Exception("command {0} not supported.".format(decryption_marker.get_current_command()))

    if failed_item is not None:
        hutil.do_exit(exit_code=CommonVariables.encryption_failed,
                      operation='Disable',
                      status=CommonVariables.extension_error_status,
                      code=CommonVariables.encryption_failed,
                      message='Decryption failed for {0}'.format(failed_item))
    else:
        encryption_config.clear_config()
        logger.log("clearing the decryption mark after successful decryption")
        decryption_marker.clear_config()
        hutil.do_exit(exit_code=0,
                      operation='Disable',
                      status=CommonVariables.extension_success_status,
                      code=str(CommonVariables.success),
                      message='Decryption succeeded')


def daemon():
    # Long-running background worker started by start_daemon().  Serializes
    # itself via a lock file, then runs either the decrypt or the encrypt
    # flow depending on which marker is present on disk.
    hutil.find_last_nonquery_operation = True
    hutil.do_parse_context('Executing')

    lock = ProcessLock(logger, encryption_environment.daemon_lock_file_path)
    if not lock.try_lock():
        logger.log("there's another daemon running, please wait it to exit.",
                   level=CommonVariables.WarningLevel)
        return

    logger.log("daemon lock acquired sucessfully.")

    logger.log("waiting for 2 minutes before continuing the daemon")
    time.sleep(120)

    logger.log("Installing pre-requisites")
    DistroPatcher.install_extras()

    # try decrypt, if decryption marker exists
    decryption_marker = DecryptionMarkConfig(logger, encryption_environment)
    if decryption_marker.config_file_exists():
        try:
            daemon_decrypt()
        except Exception as e:
            error_msg = ("Failed to disable the extension with error: {0}, stack trace: {1}".format(e, traceback.format_exc()))
            logger.log(msg=error_msg,
                       level=CommonVariables.ErrorLevel)
            hutil.do_exit(exit_code=CommonVariables.encryption_failed,
                          operation='Disable',
                          status=CommonVariables.extension_error_status,
                          code=str(CommonVariables.encryption_failed),
                          message=error_msg)
        finally:
            lock.release_lock()
            logger.log("returned to daemon")

        logger.log("exiting daemon")
        return

    # try encrypt, in absence of decryption marker
    try:
        daemon_encrypt()
    except Exception as e:
        # mount the file systems back.
        error_msg = ("Failed to enable the extension with error: {0}, stack trace: {1}".format(e, traceback.format_exc()))
        logger.log(msg=error_msg,
                   level=CommonVariables.ErrorLevel)
        hutil.do_exit(exit_code=CommonVariables.encryption_failed,
                      operation='Enable',
                      status=CommonVariables.extension_error_status,
                      code=str(CommonVariables.encryption_failed),
                      message=error_msg)
    else:
        encryption_marker = EncryptionMarkConfig(logger, encryption_environment)
        # TODO not remove it, backed it up.
        logger.log("returned to daemon successfully after encryption")
        logger.log("clearing the encryption mark.")
        encryption_marker.clear_config()
        hutil.redo_current_status()
    finally:
        lock.release_lock()
        logger.log("exiting daemon")


def start_daemon(operation):
    # This process will start a new background process by calling
    #     extension_shim.sh -c handle.py -daemon
    # to run the script and will exit itself immediately.
    shim_path = os.path.join(os.getcwd(), CommonVariables.extension_shim_filename)
    shim_opts = '-c ' + os.path.join(os.getcwd(), __file__) + ' -daemon'
    args = [shim_path, shim_opts]
    logger.log("start_daemon with args: {0}".format(args))
    # Redirect stdout and stderr to /dev/null.
Otherwise daemon process will # throw broken pipe exception when parent process exit. devnull = open(os.devnull, 'w') subprocess.Popen(args, stdout=devnull, stderr=devnull) encryption_config = EncryptionConfig(encryption_environment, logger) if encryption_config.config_file_exists(): hutil.do_exit(exit_code=0, operation=operation, status=CommonVariables.extension_success_status, code=str(CommonVariables.success), message=encryption_config.get_secret_id()) else: hutil.do_exit(exit_code=CommonVariables.encryption_failed, operation=operation, status=CommonVariables.extension_error_status, code=str(CommonVariables.encryption_failed), message='Encryption config not found.') if __name__ == '__main__': main() ================================================ FILE: VMEncryption/main/oscrypto/91ade/50-udev-ade.rules ================================================ ACTION=="add|change", SUBSYSTEM=="block", ATTRS{device_id}=="?00000000-0000-*", ATTR{partition}=="ENCRYPTED_DISK_PARTITION" GOTO="process_disk" GOTO="disk_end" LABEL="process_disk" ATTR{partition}=="ENCRYPTED_DISK_PARTITION", ENV{ID_FS_UUID}="osencrypt-locked" ATTR{partition}=="ENCRYPTED_DISK_PARTITION", ENV{ID_FS_UUID_ENC}="osencrypt-locked" ATTR{partition}=="ENCRYPTED_DISK_PARTITION", ENV{ID_FS_USAGE}="crypto" LABEL="disk_end" ================================================ FILE: VMEncryption/main/oscrypto/91ade/cryptroot-ask-ade.sh ================================================ #!/bin/sh # -*- mode: shell-script; indent-tabs-mode: nil; sh-basic-offset: 4; -*- # ex: ts=8 sw=4 sts=4 et filetype=sh set -x PATH=/usr/sbin:/usr/bin:/sbin:/bin NEWROOT=${NEWROOT:-"/sysroot"} # do not ask, if we already have root [ -f $NEWROOT/proc ] && exit 0 # check if destination already exists [ -b /dev/mapper/$2 ] && exit 0 # we already asked for this device [ -f /tmp/cryptroot-ade-asked-$2 ] && exit 0 # load dm_crypt if it is not already loaded [ -d /sys/module/dm_crypt ] || modprobe dm_crypt . 
/lib/dracut-crypt-lib.sh # default luksname - luks-UUID luksname=$2 # fallback to passphrase ask_passphrase=1 # if device name is /dev/dm-X, convert to /dev/mapper/name if [ "${1##/dev/dm-}" != "$1" ]; then device="/dev/mapper/$(dmsetup info -c --noheadings -o name "$1")" else device="$1" fi numtries=${3:-10} # # Open LUKS device # info "luksOpen $device $luksname" ls /mnt/azure_bek_disk/LinuxPassPhraseFileName* || (mkdir -p /mnt/azure_bek_disk/ && mount -L "BEK VOLUME" /mnt/azure_bek_disk/) for luksfile in $(ls /mnt/azure_bek_disk/LinuxPassPhraseFileName*); do break; done cryptsetupopts="--header /osluksheader" if [ -n "$luksfile" -a "$luksfile" != "none" -a -e "$luksfile" ]; then if cryptsetup --key-file "$luksfile" $cryptsetupopts luksOpen "$device" "$luksname"; then ask_passphrase=0 fi else if [ $numtries -eq 0 ]; then warn "No key found for $device. Fallback to passphrase mode." else sleep 1 info "No key found for $device. Will try $numtries time(s) more later." initqueue --unique --onetime --settled \ --name cryptroot-ask-ade-$luksname \ $(command -v cryptroot-ask-ade) "$device" "$luksname" "$(($numtries-1))" exit 0 fi fi if [ $ask_passphrase -ne 0 ]; then luks_open="$(command -v cryptsetup) $cryptsetupopts luksOpen" ask_for_password --ply-tries 5 \ --ply-cmd "$luks_open -T1 $device $luksname" \ --ply-prompt "Password ($device)" \ --tty-tries 1 \ --tty-cmd "$luks_open -T5 $device $luksname" unset luks_open fi umount /mnt/azure_bek_disk unset device luksname luksfile # mark device as asked >> /tmp/cryptroot-ade-asked-$2 need_shutdown udevsettle exit 0 ================================================ FILE: VMEncryption/main/oscrypto/91ade/module-setup.sh ================================================ #!/bin/bash # vim: set tabstop=8 shiftwidth=4 softtabstop=4 expandtab smarttab colorcolumn=80: # depends() { echo crypt systemd return 0 } install() { inst_script "$moddir"/cryptroot-ask-ade.sh /sbin/cryptroot-ask-ade inst_hook cmdline 30 
"$moddir/parse-crypt-ade.sh" inst_rules "$moddir/50-udev-ade.rules" inst_multiple /etc/services inst /boot/luks/osluksheader /osluksheader dracut_need_initqueue } ================================================ FILE: VMEncryption/main/oscrypto/91ade/parse-crypt-ade.sh ================================================ #!/bin/sh # -*- mode: shell-script; indent-tabs-mode: nil; sh-basic-offset: 4; -*- # ex: ts=8 sw=4 sts=4 et filetype=sh set -x { echo 'SUBSYSTEM!="block", GOTO="luks_ade_end"' echo 'ACTION!="add|change", GOTO="luks_ade_end"' } > /etc/udev/rules.d/70-luks-ade.rules.new { printf -- 'ATTRS{device_id}=="?00000000-0000-*", ENV{ID_FS_UUID}=="osencrypt-locked",' printf -- 'RUN+="%s --settled --unique --onetime ' $(command -v initqueue) printf -- '--name cryptroot-ask-ade-%%k %s ' $(command -v cryptroot-ask-ade) printf -- '$env{DEVNAME} osencrypt"\n' } >> /etc/udev/rules.d/70-luks-ade.rules.new echo 'LABEL="luks_ade_end"' >> /etc/udev/rules.d/70-luks-ade.rules.new mv /etc/udev/rules.d/70-luks-ade.rules.new /etc/udev/rules.d/70-luks-ade.rules ================================================ FILE: VMEncryption/main/oscrypto/OSEncryptionState.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#
# Requires Python 2.7+
#

import os.path
import re

from collections import namedtuple
from uuid import UUID

from Common import *
from CommandExecutor import *
from BekUtil import *
from DiskUtil import *
from EncryptionConfig import *


class OSEncryptionState(object):
    """Base class for one step of OS-disk encryption.

    At construction time it discovers the root and boot block devices for
    the current machine.  Each state is made idempotent across reboots via a
    marker file: should_enter() refuses to re-run a state whose marker
    exists, and should_exit() creates the marker once the state finishes.

    NOTE(review): file(), s.decode('string_escape') and the bare ``except``
    below are Python-2-only constructs (the header says "Requires Python
    2.7+"); this module will not run unmodified on Python 3.
    """

    def __init__(self, state_name, context):
        super(OSEncryptionState, self).__init__()

        self.state_name = state_name
        self.context = context
        self.state_executed = False
        # Per-state marker file that records completion across reboots.
        self.state_marker = os.path.join(self.context.encryption_environment.os_encryption_markers_path,
                                         self.state_name)
        self.command_executor = CommandExecutor(self.context.logger)
        self.disk_util = DiskUtil(hutil=self.context.hutil,
                                  patching=self.context.distro_patcher,
                                  logger=self.context.logger,
                                  encryption_environment=self.context.encryption_environment)
        self.bek_util = BekUtil(disk_util=self.disk_util,
                                logger=self.context.logger)
        self.encryption_config = EncryptionConfig(encryption_environment=self.context.encryption_environment,
                                                  logger=self.context.logger)

        # When running from the in-memory stripped-down root, the real root
        # filesystem is remounted at /oldroot.
        rootfs_mountpoint = '/'

        if self._is_in_memfs_root():
            rootfs_mountpoint = '/oldroot'

        self.rootfs_sdx_path = self._get_fs_partition(rootfs_mountpoint)[0]

        if self.rootfs_sdx_path == "none":
            self.context.logger.log("self.rootfs_sdx_path is none, parsing UUID from fstab")
            self.rootfs_sdx_path = self._parse_uuid_from_fstab('/')
            self.context.logger.log("rootfs_uuid: {0}".format(self.rootfs_sdx_path))

        # Normalize a UUID (or by-uuid symlink) to its /dev/sdX device path.
        if self.rootfs_sdx_path and (self.rootfs_sdx_path.startswith("/dev/disk/by-uuid/") or self._is_uuid(self.rootfs_sdx_path)):
            self.rootfs_sdx_path = self.disk_util.query_dev_sdx_path_by_uuid(self.rootfs_sdx_path)

        self.context.logger.log("self.rootfs_sdx_path: {0}".format(self.rootfs_sdx_path))

        self.rootfs_disk = None
        self.rootfs_block_device = None
        self.bootfs_block_device = None

        if self.disk_util.is_os_disk_lvm():
            # LVM layout: find the physical volume backing "rootvg" via pvs.
            # NOTE(review): assumes the PV name ends in a single partition
            # digit (e.g. /dev/sda1 -> disk /dev/sda) -- confirm for NVMe-style
            # names where this slicing would be wrong.
            proc_comm = ProcessCommunicator()
            self.command_executor.Execute('pvs', True, communicator=proc_comm)

            for line in proc_comm.stdout.split("\n"):
                if "rootvg" in line:
                    self.rootfs_block_device = line.strip().split()[0]
                    self.rootfs_disk = self.rootfs_block_device[:-1]
                    self.bootfs_block_device = self.rootfs_disk + '2'
        elif not self.rootfs_sdx_path:
            # Could not resolve the root device at all; fall back to the
            # conventional Azure layout.
            self.rootfs_disk = '/dev/sda'
            self.rootfs_block_device = '/dev/sda2'
            self.bootfs_block_device = '/dev/sda1'
        elif self.rootfs_sdx_path == '/dev/mapper/osencrypt' or self.rootfs_sdx_path.startswith('/dev/dm-'):
            # Root is already behind the osencrypt dm-crypt mapping.
            self.rootfs_block_device = '/dev/mapper/osencrypt'
            bootfs_uuid = self._parse_uuid_from_fstab('/boot')
            self.context.logger.log("bootfs_uuid: {0}".format(bootfs_uuid))
            self.bootfs_block_device = self.disk_util.query_dev_sdx_path_by_uuid(bootfs_uuid)
        else:
            # Plain partition layout: work with stable /dev/disk/by-id paths.
            self.rootfs_block_device = self.disk_util.query_dev_id_path_by_sdx_path(self.rootfs_sdx_path)
            if not self.rootfs_block_device.startswith('/dev/disk/by-id/'):
                self.context.logger.log("rootfs_block_device: {0}".format(self.rootfs_block_device))
                raise Exception("Could not find rootfs block device")
            self.rootfs_disk = self.rootfs_block_device[:self.rootfs_block_device.index("-part")]
            self.bootfs_block_device = self.rootfs_disk + "-part2"

            # Heuristic: the boot partition is the smaller of the two; swap
            # if the assumed boot device is actually larger than root.
            if self._get_block_device_size(self.bootfs_block_device) > self._get_block_device_size(self.rootfs_block_device):
                self.context.logger.log("Swapping partition identifiers for rootfs and bootfs")
                self.rootfs_block_device, self.bootfs_block_device = self.bootfs_block_device, self.rootfs_block_device

        self.context.logger.log("rootfs_disk: {0}".format(self.rootfs_disk))
        self.context.logger.log("rootfs_block_device: {0}".format(self.rootfs_block_device))
        self.context.logger.log("bootfs_block_device: {0}".format(self.bootfs_block_device))

    def should_enter(self):
        # A state may run only once per boot (state_executed) and only if its
        # on-disk marker is absent.
        self.context.logger.log("OSEncryptionState.should_enter() called for {0}".format(self.state_name))

        if self.state_executed:
            self.context.logger.log("State {0} has already executed, not entering".format(self.state_name))
            return False

        if not os.path.exists(self.state_marker):
            self.context.logger.log("State marker {0} does not exist, state {1} can be entered".format(self.state_marker,
                                                                                                       self.state_name))
            return True
        else:
            self.context.logger.log("State marker {0} exists, state {1} has already executed".format(self.state_marker,
                                                                                                     self.state_name))
            return False

    def should_exit(self):
        # Persist the completion marker (creating the marker directory on
        # first use) and flag the state as executed for this process.
        self.context.logger.log("OSEncryptionState.should_exit() called for {0}".format(self.state_name))

        if not os.path.exists(self.state_marker):
            self.disk_util.make_sure_path_exists(self.context.encryption_environment.os_encryption_markers_path)
            self.context.logger.log("Creating state marker {0}".format(self.state_marker))
            self.disk_util.touch_file(self.state_marker)

        self.state_executed = True

        self.context.logger.log("state_executed for {0}: {1}".format(self.state_name, self.state_executed))

        return self.state_executed

    def _get_fs_partition(self, fs):
        # Return the (device, mountpoint, fstype) /proc/mounts entry whose
        # device matches the filesystem that contains path `fs`.
        result = None
        dev = os.lstat(fs).st_dev

        for line in file('/proc/mounts'):
            line = [s.decode('string_escape') for s in line.split()[:3]]
            if dev == os.lstat(line[1]).st_dev:
                result = tuple(line)

        return result

    def _is_in_memfs_root(self):
        # True when / is a tmpfs, i.e. we are running from the in-memory root.
        mounts = file('/proc/mounts', 'r').read()
        return bool(re.search(r'/\s+tmpfs', mounts))

    def _parse_uuid_from_fstab(self, mountpoint):
        # Extract the UUID=... value for `mountpoint` from /etc/fstab;
        # returns None implicitly when no entry matches.
        contents = file('/etc/fstab', 'r').read()
        matches = re.findall(r'UUID=(.*?)\s+{0}\s+'.format(mountpoint), contents)
        if matches:
            return matches[0]

    def _get_block_device_size(self, dev):
        # Size of `dev` in bytes via blockdev; 0 when the device is absent.
        if not os.path.exists(dev):
            return 0

        proc_comm = ProcessCommunicator()
        self.command_executor.Execute('blockdev --getsize64 {0}'.format(dev),
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)
        return int(proc_comm.stdout.strip())

    def _is_uuid(self, s):
        # True when `s` parses as a UUID (UUID() raises ValueError otherwise).
        try:
            UUID(s)
        except:
            return False
        else:
            return True


# Immutable bundle of shared services handed to every state object.
OSEncryptionStateContext = namedtuple('OSEncryptionStateContext',
                                      ['hutil', 'distro_patcher', 'logger', 'encryption_environment'])


================================================
FILE: VMEncryption/main/oscrypto/OSEncryptionStateMachine.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

import inspect
import os
import sys
import traceback

from time import sleep

# Make the main extension directory and the bundled `transitions` package
# importable regardless of the current working directory.
scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
maindir = os.path.abspath(os.path.join(scriptdir, '../'))
sys.path.append(maindir)
transitionsdir = os.path.abspath(os.path.join(scriptdir, '../../transitions'))
sys.path.append(transitionsdir)

from OSEncryptionState import *
from Common import *
from CommandExecutor import *
from DiskUtil import *

import logging


# No-op logging handler registered before importing `transitions`.
# NOTE(review): presumably this silences "No handlers could be found"
# warnings and backfills logging.NullHandler on Pythons that lack it --
# confirm against the bundled transitions version.
class NullHandler(logging.Handler):
    def emit(self, record):
        pass


logging.getLogger(__name__).addHandler(NullHandler())
logging.NullHandler = NullHandler

from transitions import *


class OSEncryptionStateMachine(object):
    """Base state machine for distro-specific OS-disk encryption.

    On its own it only knows the trivial path (uninitialized -> completed via
    skip_encryption); subclasses supply the real states, transitions and the
    state_objs mapping consumed by on_enter_state()/should_exit_previous_state().
    """

    states = [
        State(name='uninitialized'),
        State(name='completed')
    ]

    transitions = [
        {
            'trigger': 'skip_encryption',
            'source': 'uninitialized',
            'dest': 'completed'
        }
    ]

    def on_enter_state(self):
        # Delegate to the state object registered (by a subclass) for the
        # machine's current state.
        self.state_objs[self.state].enter()

    def should_exit_previous_state(self):
        # when this is called, self.state is still the "source" state in the transition
        return self.state_objs[self.state].should_exit()

    def __init__(self, hutil, distro_patcher, logger, encryption_environment):
        super(OSEncryptionStateMachine, self).__init__()

        self.hutil = hutil
        self.distro_patcher = distro_patcher
        self.logger = logger
        self.encryption_environment = encryption_environment
        self.command_executor = CommandExecutor(self.logger)

        # Shared context handed to every state object.
        self.context = OSEncryptionStateContext(hutil=self.hutil,
                                                distro_patcher=self.distro_patcher,
                                                logger=self.logger,
                                                encryption_environment=self.encryption_environment)

        self.state_machine = Machine(model=self,
                                     states=OSEncryptionStateMachine.states,
                                     transitions=OSEncryptionStateMachine.transitions,
                                     initial='uninitialized')

    def log_machine_state(self):
        self.logger.log("======= MACHINE STATE: {0} =======".format(self.state))

    def start_encryption(self):
        # Base implementation: nothing to encrypt, go straight to completed.
        self.skip_encryption()
        self.log_machine_state()

    def _reboot(self):
        self.command_executor.Execute('reboot')


================================================
FILE: VMEncryption/main/oscrypto/__init__.py
================================================
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from OSEncryptionState import *
from OSEncryptionStateMachine import *


================================================
FILE: VMEncryption/main/oscrypto/centos_68/CentOS68EncryptionStateMachine.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

import inspect
import os
import sys
import traceback

from time import sleep

# Make the main extension directory and the bundled `transitions` package
# importable regardless of the current working directory.
scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
maindir = os.path.abspath(os.path.join(scriptdir, '../../'))
sys.path.append(maindir)
transitionsdir = os.path.abspath(os.path.join(scriptdir, '../../transitions'))
sys.path.append(transitionsdir)

from oscrypto import *
from encryptstates import *
from Common import *
from CommandExecutor import *
from DiskUtil import *

from transitions import *


class CentOS68EncryptionStateMachine(OSEncryptionStateMachine):
    """OS-disk encryption flow for the CentOS 6.8 family.

    Walks prereq -> selinux -> stripdown -> unmount_oldroot ->
    split_root_partition -> encrypt_block_device -> patch_boot_system ->
    completed, then reboots.  Each enter_* trigger is guarded by
    should_exit_previous_state so already-completed states are skipped
    after a reboot (see OSEncryptionState marker files).
    """

    states = [
        State(name='uninitialized'),
        State(name='prereq', on_enter='on_enter_state'),
        State(name='selinux', on_enter='on_enter_state'),
        State(name='stripdown', on_enter='on_enter_state'),
        State(name='unmount_oldroot', on_enter='on_enter_state'),
        State(name='split_root_partition', on_enter='on_enter_state'),
        State(name='encrypt_block_device', on_enter='on_enter_state'),
        State(name='patch_boot_system', on_enter='on_enter_state'),
        State(name='completed'),
    ]

    transitions = [
        {
            'trigger': 'skip_encryption',
            'source': 'uninitialized',
            'dest': 'completed'
        },
        {
            'trigger': 'enter_prereq',
            'source': 'uninitialized',
            'dest': 'prereq'
        },
        {
            'trigger': 'enter_selinux',
            'source': 'prereq',
            'dest': 'selinux',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_stripdown',
            'source': 'selinux',
            'dest': 'stripdown',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_unmount_oldroot',
            'source': 'stripdown',
            'dest': 'unmount_oldroot',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'retry_unmount_oldroot',
            'source': 'unmount_oldroot',
            'dest': 'unmount_oldroot',
            'before': 'on_enter_state'
        },
        {
            'trigger': 'enter_split_root_partition',
            'source': 'unmount_oldroot',
            'dest': 'split_root_partition',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_encrypt_block_device',
            'source': 'split_root_partition',
            'dest': 'encrypt_block_device',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_patch_boot_system',
            'source': 'encrypt_block_device',
            'dest': 'patch_boot_system',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'stop_machine',
            'source': 'patch_boot_system',
            'dest': 'completed',
            'conditions': 'should_exit_previous_state'
        },
    ]

    def on_enter_state(self):
        super(CentOS68EncryptionStateMachine, self).on_enter_state()

    def should_exit_previous_state(self):
        # when this is called, self.state is still the "source" state in the transition
        return super(CentOS68EncryptionStateMachine, self).should_exit_previous_state()

    def __init__(self, hutil, distro_patcher, logger, encryption_environment):
        super(CentOS68EncryptionStateMachine, self).__init__(hutil, distro_patcher, logger, encryption_environment)

        # One state object per named state; consumed by on_enter_state().
        self.state_objs = {
            'prereq': PrereqState(self.context),
            'selinux': SelinuxState(self.context),
            'stripdown': StripdownState(self.context),
            'unmount_oldroot': UnmountOldrootState(self.context),
            'split_root_partition': SplitRootPartitionState(self.context),
            'encrypt_block_device': EncryptBlockDeviceState(self.context),
            'patch_boot_system': PatchBootSystemState(self.context),
        }

        self.state_machine = Machine(model=self,
                                     states=CentOS68EncryptionStateMachine.states,
                                     transitions=CentOS68EncryptionStateMachine.transitions,
                                     initial='uninitialized')

    def start_encryption(self):
        # If the osencrypt mapping already appears in `mount` output, the OS
        # volume is already encrypted -- nothing to do.
        proc_comm = ProcessCommunicator()
        self.command_executor.Execute(command_to_execute="mount",
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        if '/dev/mapper/osencrypt' in proc_comm.stdout:
            self.logger.log("OS volume is already encrypted")
            self.skip_encryption()
            self.log_machine_state()
            return

        self.log_machine_state()

        self.enter_prereq()
        self.log_machine_state()

        self.enter_selinux()
        self.log_machine_state()

        self.enter_stripdown()
        self.log_machine_state()

        # /oldroot can be held busy by lingering processes; retry the unmount
        # up to 10 times, reporting each failed attempt back to the agent.
        oldroot_unmounted_successfully = False
        attempt = 1

        while not oldroot_unmounted_successfully:
            self.logger.log("Attempt #{0} to unmount /oldroot".format(attempt))

            try:
                if attempt == 1:
                    self.enter_unmount_oldroot()
                elif attempt > 10:
                    raise Exception("Could not unmount /oldroot in 10 attempts")
                else:
                    self.retry_unmount_oldroot()

                self.log_machine_state()
            except Exception as e:
                message = "Attempt #{0} to unmount /oldroot failed with error: {1}, stack trace: {2}".format(attempt,
                                                                                                             e,
                                                                                                             traceback.format_exc())
                self.logger.log(msg=message)
                self.hutil.do_status_report(operation='EnableEncryptionOSVolume',
                                            status=CommonVariables.extension_error_status,
                                            status_code=str(CommonVariables.unmount_oldroot_error),
                                            message=message)
                sleep(10)

                if attempt > 10:
                    raise Exception(message)
            else:
                oldroot_unmounted_successfully = True
            finally:
                attempt += 1

        self.enter_split_root_partition()
        self.log_machine_state()

        self.enter_encrypt_block_device()
        self.log_machine_state()

        self.enter_patch_boot_system()
        self.log_machine_state()

        self.stop_machine()
        self.log_machine_state()

        self._reboot()


================================================
FILE: VMEncryption/main/oscrypto/centos_68/__init__.py
================================================
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from CentOS68EncryptionStateMachine import * ================================================ FILE: VMEncryption/main/oscrypto/centos_68/encryptpatches/centos_68_dracut.patch ================================================ diff -Naur 90crypt.orig/cryptroot-ask.sh 90crypt/cryptroot-ask.sh --- 90crypt.orig/cryptroot-ask.sh 2016-11-20 18:43:12.697422815 -0800 +++ 90crypt/cryptroot-ask.sh 2016-11-20 18:43:28.033101905 -0800 @@ -64,6 +64,25 @@ # Open LUKS device # +MountPoint=/tmp-keydisk-mount +KeyFileName=LinuxPassPhraseFileName +echo "Trying to get the key from disks ..." >&2 +mkdir -p $MountPoint >&2 +modprobe vfat >/dev/null >&2 +modprobe fuse >/dev/null >&2 +for SFS in /dev/sd*; do + echo "> Trying device:$SFS..." >&2 + mount ${SFS}1 $MountPoint -t vfat -r >&2 + if [ -f $MountPoint/$KeyFileName ]; then + echo "> keyfile got..." >&2 + cp $MountPoint/$KeyFileName /tmp-keyfile + luksfile=/tmp-keyfile + umount $MountPoint + break + fi +done + info "luksOpen $device $luksname $luksfile" if [ -n "$luksfile" -a "$luksfile" != "none" -a -e "$luksfile" ]; then diff -Naur 90crypt.orig/parse-crypt.sh 90crypt/parse-crypt.sh --- 90crypt.orig/parse-crypt.sh 2016-11-20 18:43:12.698422813 -0800 +++ 90crypt/parse-crypt.sh 2016-11-20 18:43:28.033101905 -0800 @@ -12,13 +12,13 @@ echo '. 
/lib/dracut-lib.sh' > /emergency/90-crypt.sh for luksid in $LUKS; do luksid=${luksid##luks-} - printf 'ENV{ID_FS_TYPE}=="crypto_LUKS", ENV{ID_FS_UUID}=="%s*", RUN+="/sbin/initqueue --unique --onetime --name cryptroot-ask-%%k /sbin/cryptroot-ask $env{DEVNAME} luks-$env{ID_FS_UUID}"\n' $luksid \ + printf 'KERNEL=="sda1", RUN+="/sbin/initqueue --unique --onetime --name cryptroot-ask-%%k /sbin/cryptroot-ask $env{DEVNAME} osencrypt"\n' $luksid \ >> /etc/udev/rules.d/70-luks.rules - printf '[ -e /dev/disk/by-uuid/*%s* ] || exit 1 \n' $luksid >> /initqueue-finished/crypt.sh - printf '[ -e /dev/disk/by-uuid/*%s* ] || warn "crypto LUKS UUID "%s" not found" \n' $luksid $luksid >> /emergency/90-crypt.sh + printf '[ -e /dev/mapper/osencrypt ] || ( /sbin/cryptroot-ask /dev/sda1 osencrypt && [ -e /dev/mapper/osencrypt ] ) || exit 1 \n' $luksid >> /initqueue-finished/crypt.sh + printf '[ -e /dev/mapper/osencrypt ] || warn "crypto LUKS UUID "%s" not found" \n' $luksid $luksid >> /emergency/90-crypt.sh done else - echo 'ENV{ID_FS_TYPE}=="crypto_LUKS", RUN+="/sbin/initqueue --unique --onetime --name cryptroot-ask-%k /sbin/cryptroot-ask $env{DEVNAME} luks-$env{ID_FS_UUID}"' \ + echo 'KERNEL="sda1", RUN+="/sbin/initqueue --unique --onetime --name cryptroot-ask-%k /sbin/cryptroot-ask $env{DEVNAME} osencrypt"' \ >> /etc/udev/rules.d/70-luks.rules fi echo 'LABEL="luks_end"' >> /etc/udev/rules.d/70-luks.rules ================================================ FILE: VMEncryption/main/oscrypto/centos_68/encryptstates/EncryptBlockDeviceState.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

import re
import os
import sys
from inspect import ismethod
from time import sleep

from OSEncryptionState import *


class EncryptBlockDeviceState(OSEncryptionState):
    """State that encrypts the root block device in place.

    Shrinks the root ext filesystem by 8192 sectors to make room for the
    LUKS header, then runs ``cryptsetup-reencrypt`` over the root block
    device using the BEK passphrase fetched from the BEK volume.
    """

    def __init__(self, context):
        super(EncryptBlockDeviceState, self).__init__('EncryptBlockDeviceState', context)

    def should_enter(self):
        """Return True when the state machine may enter this state."""
        self.context.logger.log("Verifying if machine should enter encrypt_block_device state")

        if not super(EncryptBlockDeviceState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for encrypt_block_device state")

        return True

    def enter(self):
        """Shrink the root filesystem and start in-place LUKS reencryption."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering encrypt_block_device state")

        self.context.logger.log("Resizing " + self.rootfs_block_device)

        current_rootfs_size = self._get_root_fs_size_in_sectors(sector_size=512)
        # Reserve 8192 sectors (4 MiB at 512 B/sector) at the end of the
        # filesystem for the LUKS header; must match --reduce-device-size
        # passed to cryptsetup-reencrypt below.
        desired_rootfs_size = current_rootfs_size - 8192

        self.command_executor.Execute('e2fsck -yf {0}'.format(self.rootfs_block_device), True)
        self.command_executor.Execute('resize2fs {0} {1}s'.format(self.rootfs_block_device, desired_rootfs_size), True)
        self.command_executor.Execute('mount /boot', False)

        # self._find_bek_and_execute_action('_dump_passphrase')

        self.context.hutil.do_status_report(operation='EnableEncryptionDataVolumes',
                                            status=CommonVariables.extension_success_status,
                                            status_code=str(CommonVariables.success),
                                            message='OS disk encryption started')

        self._find_bek_and_execute_action('_luks_reencrypt')

    def should_exit(self):
        """Verify the encrypted device opens and mounts before leaving."""
        self.context.logger.log("Verifying if machine should exit encrypt_block_device state")

        if not os.path.exists('/dev/mapper/osencrypt'):
            self._find_bek_and_execute_action('_luks_open')

        # Sanity check: the freshly encrypted root must be mountable.
        self.command_executor.Execute('mount /dev/mapper/osencrypt /oldroot', True)
        self.command_executor.Execute('umount /oldroot', True)

        return super(EncryptBlockDeviceState, self).should_exit()

    def _luks_open(self, bek_path):
        """Map the encrypted root device at /dev/mapper/osencrypt using the BEK."""
        self.command_executor.Execute('cryptsetup luksOpen {0} osencrypt -d {1}'.format(self.rootfs_block_device,
                                                                                       bek_path),
                                      raise_exception_on_failure=True)

    def _luks_reencrypt(self, bek_path):
        """Encrypt the root device in place; the passphrase is piped on stdin."""
        # -N creates a new LUKS header; --reduce-device-size matches the
        # 8192-sector shrink performed in enter().
        self.command_executor.ExecuteInBash('cat {0} | cryptsetup-reencrypt -N --reduce-device-size 8192s {1} -v'.format(bek_path,
                                                                                                                        self.rootfs_block_device),
                                            raise_exception_on_failure=True)

    def _dump_passphrase(self, bek_path):
        """Debug helper: log an octal dump of the BEK passphrase file."""
        proc_comm = ProcessCommunicator()

        self.command_executor.Execute(command_to_execute="od -c {0}".format(bek_path),
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        self.context.logger.log("Passphrase:")
        self.context.logger.log(proc_comm.stdout)

    def _find_bek_and_execute_action(self, callback_method_name):
        """Resolve the BEK passphrase file and invoke the named method with its path."""
        callback_method = getattr(self, callback_method_name)
        if not ismethod(callback_method):
            raise Exception("{0} is not a method".format(callback_method_name))

        bek_path = self.bek_util.get_bek_passphrase_file(self.encryption_config)
        callback_method(bek_path)

    def _get_root_fs_size_in_sectors(self, sector_size):
        """Return the root filesystem size in sectors, parsed from dumpe2fs."""
        proc_comm = ProcessCommunicator()

        self.command_executor.Execute(command_to_execute="dumpe2fs -h {0}".format(self.rootfs_block_device),
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        root_fs_block_count = re.findall(r'Block count:\s*(\d+)', proc_comm.stdout)
        root_fs_block_size = re.findall(r'Block size:\s*(\d+)', proc_comm.stdout)

        if not root_fs_block_count or not root_fs_block_size:
            raise Exception("Error parsing dumpe2fs output, count={0}, size={1}".format(root_fs_block_count,
                                                                                       root_fs_block_size))

        root_fs_block_count = int(root_fs_block_count[0])
        root_fs_block_size = int(root_fs_block_size[0])

        # Python 2 integer division; both operands are ints.
        return (root_fs_block_count * root_fs_block_size) / sector_size


================================================
FILE: VMEncryption/main/oscrypto/centos_68/encryptstates/PatchBootSystemState.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

import inspect
import re
import os
import sys
from time import sleep

from OSEncryptionState import *


class PatchBootSystemState(OSEncryptionState):
    """State that patches the boot chain of the (now encrypted) OS disk.

    Pivots into the encrypted root, patches the dracut crypt module and
    grub configuration so the system can unlock and boot from
    /dev/mapper/osencrypt, then pivots back out and schedules a WALA
    restart.
    """

    def __init__(self, context):
        super(PatchBootSystemState, self).__init__('PatchBootSystemState', context)

    def should_enter(self):
        """Enter only if the encrypted root device can be mounted."""
        self.context.logger.log("Verifying if machine should enter patch_boot_system state")

        if not super(PatchBootSystemState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for patch_boot_system state")

        self.command_executor.Execute('mount /dev/mapper/osencrypt /oldroot', True)
        self.command_executor.Execute('umount /oldroot', True)

        return True

    def enter(self):
        """Pivot into the encrypted root, patch it, and pivot back out.

        The try/except restores the original root on failure so the
        machine is left in a recoverable state.
        """
        if not self.should_enter():
            return

        self.context.logger.log("Entering patch_boot_system state")

        self.command_executor.Execute('mount /boot', False)
        self.command_executor.Execute('mount /dev/mapper/osencrypt /oldroot', True)
        self.command_executor.Execute('mount --make-rprivate /', True)
        self.command_executor.Execute('mkdir /oldroot/memroot', True)
        self.command_executor.Execute('pivot_root /oldroot /oldroot/memroot', True)
        self.command_executor.ExecuteInBash('for i in dev proc sys boot; do mount --move /memroot/$i /$i; done', True)

        try:
            self._modify_pivoted_oldroot()
        except Exception as e:
            # Undo the pivot before propagating so the caller sees the
            # original (tmpfs) root again.
            self.command_executor.Execute('mount --make-rprivate /')
            self.command_executor.Execute('pivot_root /memroot /memroot/oldroot')
            self.command_executor.Execute('rmdir /oldroot/memroot')
            self.command_executor.ExecuteInBash('for i in dev proc sys boot; do mount --move /oldroot/$i /$i; done')

            raise
        else:
            self.command_executor.Execute('mount --make-rprivate /')
            self.command_executor.Execute('pivot_root /memroot /memroot/oldroot')
            self.command_executor.Execute('rmdir /oldroot/memroot')
            self.command_executor.ExecuteInBash('for i in dev proc sys boot; do mount --move /oldroot/$i /$i; done')

            # Preserve the extension logs from the stripped-down session on
            # the encrypted disk for post-mortem debugging.
            extension_full_name = 'Microsoft.Azure.Security.' + CommonVariables.extension_name
            self.command_executor.Execute('cp -ax' +
                                          ' /var/log/azure/{0}'.format(extension_full_name) +
                                          ' /oldroot/var/log/azure/{0}.Stripdown'.format(extension_full_name),
                                          True)
            self.command_executor.Execute('umount /boot')
            self.command_executor.Execute('umount /oldroot')

            self.context.logger.log("Pivoted back into memroot successfully, restarting WALA")

            self.command_executor.Execute('service sshd restart')
            self.command_executor.Execute('service atd restart')

            # Schedule the lock removal and the WALA restart via atd, then
            # kill our own daemon process (pkill below matches it).
            with open("/restart-wala.sh", "w") as f:
                f.write("service waagent restart\n")

            with open("/delete-lock.sh", "w") as f:
                f.write("rm -f /var/lib/azure_disk_encryption_config/daemon_lock_file.lck\n")

            self.command_executor.Execute('at -f /delete-lock.sh now + 1 minutes', True)
            self.command_executor.Execute('at -f /restart-wala.sh now + 2 minutes', True)

            # Advance the state marker before this process is killed.
            self.should_exit()
            self.command_executor.ExecuteInBash('pkill -f .*ForLinux.*handle.py.*daemon.*', True)

    def should_exit(self):
        """Return True when the machine may leave this state."""
        self.context.logger.log("Verifying if machine should exit patch_boot_system state")

        return super(PatchBootSystemState, self).should_exit()

    def _append_contents_to_file(self, contents, path):
        """Append the given text to the file at *path*."""
        with open(path, 'a') as f:
            f.write(contents)

    def _modify_pivoted_oldroot(self):
        """Patch dracut, rebuild the initramfs and reinstall grub.

        Runs while pivoted inside the encrypted root.
        """
        self.context.logger.log("Pivoted into oldroot successfully")

        scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
        patchesdir = os.path.join(scriptdir, '../encryptpatches')
        patchpath = os.path.join(patchesdir, 'centos_68_dracut.patch')

        if not os.path.exists(patchpath):
            message = "Patch not found at path: {0}".format(patchpath)
            self.context.logger.log(message)
            raise Exception(message)
        else:
            self.context.logger.log("Patch found at path: {0}".format(patchpath))

        # fstab must mount /dev/mapper/osencrypt at / from now on.
        self.disk_util.remove_mount_info('/')
        self.disk_util.append_mount_info('/dev/mapper/osencrypt', '/')

        self.command_executor.ExecuteInBash('patch -b -d /usr/share/dracut/modules.d/90crypt -p1 <{0}'.format(patchpath), True)

        # vfat/fuse drivers are needed in the initramfs so the patched
        # cryptroot-ask.sh can mount the BEK key disk at boot.
        self._append_contents_to_file('\nadd_drivers+=" fuse vfat nls_cp437 nls_iso8859-1"\n',
                                      '/etc/dracut.conf')
        self._append_contents_to_file('\nadd_dracutmodules+=" crypt"\n',
                                      '/etc/dracut.conf')

        self.command_executor.Execute('/sbin/dracut -f -v', True)
        # /boot on the encrypted disk is now a mountpoint for the new boot
        # partition, whose contents live under /boot/boot.
        self.command_executor.ExecuteInBash('mv -f /boot/initramfs* /boot/boot/', True)

        with open("/boot/boot/grub/grub.conf", "r") as f:
            contents = f.read()

        # Point the kernel at the mapped device and the second partition
        # (hd0,1 = new boot partition created by SplitRootPartitionState).
        contents = re.sub(r"rd_NO_LUKS ", r"", contents)
        contents = re.sub(r"root=(.*?)\s", r"root=/dev/mapper/osencrypt rd_LUKS_UUID=osencrypt rdinitdebug ", contents)
        contents = re.sub(r"hd0,0", r"hd0,1", contents)

        with open("/boot/boot/grub/grub.conf", "w") as f:
            f.write(contents)

        # Reinstall the grub stage1/stage2 pointing at the boot partition.
        grub_input = "root (hd0,1)\nsetup (hd0)\nquit\n"
        self.command_executor.Execute('grub', input=grub_input, raise_exception_on_failure=True)


================================================
FILE: VMEncryption/main/oscrypto/centos_68/encryptstates/PrereqState.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

from OSEncryptionState import *
from pprint import pprint


class PrereqState(OSEncryptionState):
    """First OS-encryption state: validates the distro and prepares waagent."""

    def __init__(self, context):
        super(PrereqState, self).__init__('PrereqState', context)

    def should_enter(self):
        """Return True when the state machine may enter this state."""
        self.context.logger.log("Verifying if machine should enter prereq state")

        if not super(PrereqState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for prereq state")

        return True

    def enter(self):
        """Check that the distro is supported and install prerequisites.

        Raises if this state machine was invoked on anything other than
        RHEL 6.8 or CentOS 6.8/6.9.
        """
        if not self.should_enter():
            return

        self.context.logger.log("Entering prereq state")

        distro_info = self.context.distro_patcher.distro_info
        self.context.logger.log("Distro info: {0}".format(distro_info))

        if ((distro_info[0] == 'redhat' and distro_info[1] == '6.8') or
                (distro_info[0] == 'centos' and (distro_info[1] == '6.8' or distro_info[1] == '6.9'))):
            self.context.logger.log("Enabling OS volume encryption on {0} {1}".format(distro_info[0],
                                                                                     distro_info[1]))
        else:
            raise Exception("CentOS68EncryptionStateMachine called for distro {0} {1}".format(distro_info[0],
                                                                                              distro_info[1]))

        self.context.distro_patcher.install_extras()

        self._patch_waagent()
        # Ask init to re-execute itself so it picks up the updated config.
        self.command_executor.Execute('telinit u', True)

    def should_exit(self):
        """Return True when the machine may leave this state."""
        self.context.logger.log("Verifying if machine should exit prereq state")

        return super(PrereqState, self).should_exit()

    def _patch_waagent(self):
        """Disable waagent's resource-disk swap in /etc/waagent.conf.

        Swap on the resource disk would keep mounts busy during the later
        unmount/stripdown states.
        """
        self.context.logger.log("Patching waagent")

        contents = None

        with open('/etc/waagent.conf', 'r') as f:
            contents = f.read()

        # NOTE(review): `re` is not imported in this module directly; it is
        # assumed to be in scope via `from OSEncryptionState import *` —
        # confirm against OSEncryptionState.py.
        contents = re.sub(r'ResourceDisk.EnableSwap=.', 'ResourceDisk.EnableSwap=n', contents)

        with open('/etc/waagent.conf', 'w') as f:
            f.write(contents)

        self.context.logger.log("waagent patched successfully")


================================================
FILE: VMEncryption/main/oscrypto/centos_68/encryptstates/SelinuxState.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

from OSEncryptionState import *


class SelinuxState(OSEncryptionState):
    """State that disables SELinux enforcement for the encryption process."""

    def __init__(self, context):
        super(SelinuxState, self).__init__('SelinuxState', context)

    def should_enter(self):
        """Return True when the state machine may enter this state."""
        self.context.logger.log("Verifying if machine should enter selinux state")

        if not super(SelinuxState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for selinux state")

        return True

    def enter(self):
        """Switch SELinux out of enforcing mode if necessary."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering selinux state")

        se_linux_status = self.context.encryption_environment.get_se_linux()
        if se_linux_status.lower() == 'enforcing':
            self.context.logger.log("SELinux is in enforcing mode, disabling")
            self.context.encryption_environment.disable_se_linux()

    def should_exit(self):
        """Return True when the machine may leave this state."""
        self.context.logger.log("Verifying if machine should exit selinux state")

        return super(SelinuxState, self).should_exit()


================================================
FILE: VMEncryption/main/oscrypto/centos_68/encryptstates/SplitRootPartitionState.py
================================================
#!/usr/bin/env python
#
# VM
# Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

import os
import re
import sys
import parted
from time import sleep

from OSEncryptionState import *


class SplitRootPartitionState(OSEncryptionState):
    """State that carves a separate, unencrypted boot partition out of root.

    Shrinks the root filesystem and partition by 256 MiB, creates a new
    ext2 boot partition in the freed space, and moves the contents of
    /boot onto it (mounted with the grub files under /boot/boot).
    """

    def __init__(self, context):
        super(SplitRootPartitionState, self).__init__('SplitRootPartitionState', context)

    def should_enter(self):
        """Enter only if the root filesystem passes a forced fsck."""
        self.context.logger.log("Verifying if machine should enter split_root_partition state")

        if not super(SplitRootPartitionState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for split_root_partition state")

        self.command_executor.Execute("e2fsck -yf {0}".format(self.rootfs_block_device), True)

        return True

    def enter(self):
        """Resize the root partition and create/populate the boot partition."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering split_root_partition state")

        device = parted.getDevice(self.rootfs_disk)
        disk = parted.Disk(device)

        original_root_fs_size = self._get_root_fs_size_in(device.sectorSize)
        self.context.logger.log("Original root filesystem size (sectors): {0}".format(original_root_fs_size))

        desired_boot_partition_size = self._size_to_sectors(256, 'MiB', device.sectorSize)
        self.context.logger.log("Desired boot partition size (sectors): {0}".format(desired_boot_partition_size))

        # Assumes the root partition is the first partition on the disk.
        root_partition = disk.partitions[0]

        original_root_partition_start = root_partition.geometry.start
        original_root_partition_end = root_partition.geometry.end

        self.context.logger.log("Original root partition start (sectors): {0}".format(original_root_partition_start))
        self.context.logger.log("Original root partition end (sectors): {0}".format(original_root_partition_end))

        # Keep the start fixed and give the last 256 MiB to the new boot
        # partition.
        desired_root_partition_start = original_root_partition_start
        desired_root_partition_end = original_root_partition_end - desired_boot_partition_size
        desired_root_partition_size = desired_root_partition_end - desired_root_partition_start

        self.context.logger.log("Desired root partition start (sectors): {0}".format(desired_root_partition_start))
        self.context.logger.log("Desired root partition end (sectors): {0}".format(desired_root_partition_end))
        self.context.logger.log("Desired root partition size (sectors): {0}".format(desired_root_partition_size))

        self.context.logger.log("Resizing root filesystem")

        # Shrink the filesystem before shrinking the partition under it.
        desired_root_fs_size = desired_root_partition_size
        self._resize_root_fs_to_sectors(desired_root_fs_size, device.sectorSize)

        desired_root_partition_geometry = parted.Geometry(device=device,
                                                          start=desired_root_partition_start,
                                                          length=desired_root_partition_size)
        root_partition_constraint = parted.Constraint(exactGeom=desired_root_partition_geometry)
        disk.setPartitionGeometry(partition=root_partition,
                                  constraint=root_partition_constraint,
                                  start=desired_root_partition_start,
                                  end=desired_root_partition_end)

        # NOTE(review): free-space region [1] is assumed to be the gap just
        # freed at the end of the root partition — confirm region ordering.
        desired_boot_partition_start = disk.getFreeSpaceRegions()[1].start
        desired_boot_partition_end = disk.getFreeSpaceRegions()[1].end
        desired_boot_partition_size = disk.getFreeSpaceRegions()[1].length

        self.context.logger.log("Desired boot partition start (sectors): {0}".format(desired_boot_partition_start))
        self.context.logger.log("Desired boot partition end (sectors): {0}".format(desired_boot_partition_end))

        desired_boot_partition_geometry = parted.Geometry(device=device,
                                                          start=desired_boot_partition_start,
                                                          length=desired_boot_partition_size)
        boot_partition_constraint = parted.Constraint(exactGeom=desired_boot_partition_geometry)
        desired_boot_partition = parted.Partition(disk=disk,
                                                  type=parted.PARTITION_NORMAL,
                                                  geometry=desired_boot_partition_geometry)
        disk.addPartition(partition=desired_boot_partition,
                          constraint=boot_partition_constraint)

        disk.commit()

        probed_root_fs = parted.probeFileSystem(disk.partitions[0].geometry)
        if not probed_root_fs == 'ext4':
            raise Exception("Probed root fs is not ext4")

        disk.partitions[1].setFlag(parted.PARTITION_BOOT)
        disk.commit()

        # Wait (up to ~50 s) for the kernel to expose the new partition.
        self.command_executor.Execute("partprobe", False)

        retry_counter = 0
        while not os.path.exists(self.bootfs_block_device) and retry_counter < 10:
            sleep(5)
            self.command_executor.Execute("partprobe", False)
            retry_counter += 1

        self.command_executor.Execute("mkfs.ext2 {0}".format(self.bootfs_block_device), True)

        boot_partition_uuid = self._get_uuid(self.bootfs_block_device)

        # Move stuff from /oldroot/boot to new partition, make new partition mountable at the same spot
        self.command_executor.Execute("mount {0} /oldroot".format(self.rootfs_block_device), True)
        self.command_executor.Execute("mkdir /oldroot/memroot", True)
        self.command_executor.Execute("mount --make-rprivate /", True)
        self.command_executor.Execute("pivot_root /oldroot /oldroot/memroot", True)
        self.command_executor.ExecuteInBash("for i in dev proc sys; do mount --move /memroot/$i /$i; done", True)
        self.command_executor.Execute("mv /boot /boot.backup", True)
        self.command_executor.Execute("mkdir /boot", True)
        self.disk_util.remove_mount_info("/boot")
        self._append_boot_partition_uuid_to_fstab(boot_partition_uuid)
        self.command_executor.Execute("cp /etc/fstab /memroot/etc/fstab", True)
        self.command_executor.Execute("mount /boot", True)
        self.command_executor.Execute("mkdir /boot/boot", True)
        # dotglob so hidden files under /boot.backup are moved too.
        self.command_executor.ExecuteInBash("shopt -s dotglob && mv /boot.backup/* /boot/boot/", True)
        self.command_executor.Execute("rmdir /boot.backup", True)
        self.command_executor.Execute("mount --make-rprivate /", True)
        self.command_executor.Execute("pivot_root /memroot /memroot/oldroot", True)
        self.command_executor.Execute("rmdir /oldroot/memroot", True)
        self.command_executor.ExecuteInBash("for i in dev proc sys; do mount --move /oldroot/$i /$i; done", True)
        self.command_executor.Execute("umount /oldroot/boot", True)

        try:
            self.command_executor.Execute("umount /oldroot", True)
        except:
            # Best-effort recovery: schedule a WALA restart, drop the
            # unmount-state marker so the machine retries from there, and
            # re-raise.
            self.context.logger.log("Could not unmount /oldroot, attempting to restart WALA and unmount again")
            self.command_executor.Execute('at -f /restart-wala.sh now + 1 minutes', True)
            self.command_executor.Execute('service waagent stop', True)
            os.unlink('/var/lib/azure_disk_encryption_config/os_encryption_markers/UnmountOldrootState')
            self.should_exit()
            raise

    def should_exit(self):
        """Verify the new boot partition mounts and contains grub."""
        self.context.logger.log("Verifying if machine should exit split_root_partition state")

        self.command_executor.ExecuteInBash("mount /boot || mountpoint /boot", True)
        self.command_executor.ExecuteInBash("[ -e /boot/boot/grub ]", True)
        self.command_executor.Execute("umount /boot", True)

        return super(SplitRootPartitionState, self).should_exit()

    def _size_to_sectors(self, bytes_, unit, sector_size):
        """Convert a size in the given SI/IEC unit to a sector count."""
        exponents = {
            "B":   1,        # byte
            "kB":  1000**1,  # kilobyte
            "MB":  1000**2,  # megabyte
            "GB":  1000**3,  # gigabyte
            "TB":  1000**4,  # terabyte
            "PB":  1000**5,  # petabyte
            "EB":  1000**6,  # exabyte
            "ZB":  1000**7,  # zettabyte
            "YB":  1000**8,  # yottabyte
            "KiB": 1024**1,  # kibibyte
            "MiB": 1024**2,  # mebibyte
            "GiB": 1024**3,  # gibibyte
            "TiB": 1024**4,  # tebibyte
            "PiB": 1024**5,  # pebibyte
            "EiB": 1024**6,  # exbibyte
            "ZiB": 1024**7,  # zebibyte
            "YiB": 1024**8   # yobibyte
        }

        if unit not in exponents.keys():
            raise SyntaxError("{:} is not a valid SI or IEC byte unit".format(unit))
        else:
            return bytes_ * exponents[unit] // sector_size

    def _get_uuid(self, partition_name):
        """Return the filesystem UUID of the given partition via blkid."""
        proc_comm = ProcessCommunicator()

        self.command_executor.Execute(command_to_execute="blkid -s UUID -o value {0}".format(partition_name),
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        return proc_comm.stdout.strip()

    def _append_boot_partition_uuid_to_fstab(self, boot_partition_uuid):
        """Add a UUID-based /boot entry to /etc/fstab."""
        self.context.logger.log("Updating fstab")

        contents = None

        with open('/etc/fstab', 'r') as f:
            contents = f.read()

        contents += '\n'
        contents += 'UUID={0}\t/boot\text2\tdefaults\t0 0'.format(boot_partition_uuid)
        contents += '\n'

        with open('/etc/fstab', 'w') as f:
            f.write(contents)

        self.context.logger.log("fstab updated successfully")

    def _get_root_fs_size_in(self, sector_size):
        """Return the root filesystem size in sectors of the given size."""
        proc_comm = ProcessCommunicator()

        self.command_executor.Execute(command_to_execute="dumpe2fs -h {0}".format(self.rootfs_block_device),
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        root_fs_block_count = re.findall(r'Block count:\s*(\d+)', proc_comm.stdout)
        root_fs_block_size = re.findall(r'Block size:\s*(\d+)', proc_comm.stdout)

        if not root_fs_block_count or not root_fs_block_size:
            raise Exception("Error parsing dumpe2fs output, count={0}, size={1}".format(root_fs_block_count,
                                                                                       root_fs_block_size))

        root_fs_block_count = int(root_fs_block_count[0])
        root_fs_block_size = int(root_fs_block_size[0])

        root_fs_size = self._size_to_sectors(root_fs_block_count * root_fs_block_size, 'B', sector_size)

        return root_fs_size

    def _resize_root_fs_to_sectors(self, desired_root_fs_size, sectorSize):
        """Shrink the root filesystem to the requested size and log the result."""
        self.context.logger.log("Desired root filesystem size (sectors): {0}".format(desired_root_fs_size))

        self.command_executor.Execute("resize2fs {0} {1}s".format(self.rootfs_block_device, desired_root_fs_size), True)

        resized_root_fs_size = self._get_root_fs_size_in(sectorSize)

        self.context.logger.log("Resized root filesystem size (sectors): {0}".format(resized_root_fs_size))


================================================
FILE: VMEncryption/main/oscrypto/centos_68/encryptstates/StripdownState.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

import os
import sys

from OSEncryptionState import *


class StripdownState(OSEncryptionState):
    """State that moves the OS into a minimal tmpfs root.

    Copies a stripped-down system into a tmpfs, pivots into it (leaving
    the real root at /oldroot), and restarts the extension from inside
    the tmpfs so the real root can later be unmounted and encrypted.
    """

    def __init__(self, context):
        super(StripdownState, self).__init__('StripdownState', context)

    def should_enter(self):
        """Enter only on a clean machine (no leftover tmproot/oldroot)."""
        self.context.logger.log("Verifying if machine should enter stripdown state")

        if not super(StripdownState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for stripdown state")

        self.command_executor.Execute('rm -rf /tmp/tmproot', True)
        self.command_executor.ExecuteInBash('! [ -e "/oldroot" ]', True)

        return True

    def enter(self):
        """Build the tmpfs root and pivot into it."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering stripdown state")

        self.command_executor.Execute('umount -a')

        self.command_executor.Execute('mkdir /tmp/tmproot', True)
        self.command_executor.Execute('mount -t tmpfs none /tmp/tmproot', True)

        # Recreate the directory skeleton, then copy just enough of the
        # system (cp -ax: archive, single filesystem) to keep running.
        self.command_executor.ExecuteInBash('for i in proc sys dev run usr var tmp root oldroot boot; do mkdir /tmp/tmproot/$i; done', True)
        self.command_executor.ExecuteInBash('for i in bin etc mnt sbin lib lib64 root; do cp -ax /$i /tmp/tmproot/; done', True)
        self.command_executor.ExecuteInBash('for i in bin sbin libexec lib lib64 share; do cp -ax /usr/$i /tmp/tmproot/usr/; done', True)
        self.command_executor.ExecuteInBash('for i in lib local lock opt run spool tmp; do cp -ax /var/$i /tmp/tmproot/var/; done', True)
        self.command_executor.ExecuteInBash('mkdir /tmp/tmproot/var/log', True)
        self.command_executor.ExecuteInBash('cp -ax /var/log/azure /tmp/tmproot/var/log/', True)

        self.command_executor.Execute('mount --make-rprivate /', True)

        # The pending-request queue must have made it into the tmpfs copy,
        # otherwise the restarted extension could not resume.
        self.command_executor.ExecuteInBash('[ -e "/tmp/tmproot/var/lib/azure_disk_encryption_config/azure_crypt_request_queue.ini" ]', True)

        self.command_executor.Execute('service waagent stop', True)
        self.command_executor.Execute('pivot_root /tmp/tmproot /tmp/tmproot/oldroot', True)
        self.command_executor.ExecuteInBash('for i in dev proc sys; do mount --move /oldroot/$i /$i; done', True)

    def should_exit(self):
        """Exit handling with a restart-once marker.

        On the first call (marker absent) this schedules a WALA restart
        via atd and terminates the current process; the restarted process
        finds the marker and proceeds.
        """
        self.context.logger.log("Verifying if machine should exit stripdown state")

        if not os.path.exists(self.state_marker):
            self.context.logger.log("First call to stripdown state (pid={0}), restarting process".format(os.getpid()))

            # create the marker, but do not advance the state machine
            super(StripdownState, self).should_exit()

            # the restarted process shall see the marker and advance the state machine
            self.command_executor.Execute('service atd restart', True)

            os.chdir('/')
            with open("/restart-wala.sh", "w") as f:
                f.write("service waagent restart\n")

            self.command_executor.Execute('at -f /restart-wala.sh now + 1 minutes', True)

            self.context.hutil.do_exit(exit_code=CommonVariables.encryption_failed,
                                       operation='EnableEncryptionOSVolume',
                                       status=CommonVariables.extension_error_status,
                                       code=CommonVariables.encryption_failed,
                                       message="Restarted extension from stripped down OS")
        else:
            self.context.logger.log("Second call to stripdown state (pid={0}), continuing process".format(os.getpid()))

        return True


================================================
FILE: VMEncryption/main/oscrypto/centos_68/encryptstates/UnmountOldrootState.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

import os
import re
import sys
from time import sleep

from OSEncryptionState import *


class UnmountOldrootState(OSEncryptionState):
    """State that frees and unmounts the original root at /oldroot.

    Runs from the tmpfs root created by StripdownState: restarts services
    off the old root, kills every process still holding /oldroot
    (including, eventually, this process itself on first pass), unmounts
    everything under /oldroot and waits for the root block device to
    become free.
    """

    def __init__(self, context):
        super(UnmountOldrootState, self).__init__('UnmountOldrootState', context)

    def should_enter(self):
        """Enter only when /oldroot exists and is a mountpoint."""
        self.context.logger.log("Verifying if machine should enter unmount_oldroot state")

        if not super(UnmountOldrootState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for unmount_oldroot state")

        self.command_executor.ExecuteInBash('[ -e "/oldroot" ]', True)

        if self.command_executor.Execute('mountpoint /oldroot') != 0:
            return False

        return True

    def enter(self):
        """Drive every user of /oldroot off it and unmount the old root."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering unmount_oldroot state")

        # Keep remote access alive from the tmpfs root.
        self.command_executor.ExecuteInBash('mkdir -p /var/empty/sshd', True)
        self.command_executor.Execute('service sshd restart')
        self.command_executor.Execute('dhclient')

        proc_comm = ProcessCommunicator()

        self.command_executor.Execute(command_to_execute="/sbin/service --status-all",
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        # Restart every running service (except waagent/ssh) so they are
        # re-executed from the tmpfs copy and release /oldroot binaries.
        for line in proc_comm.stdout.split('\n'):
            if not "running" in line:
                continue

            if "waagent" in line or "ssh" in line:
                continue

            splitted = line.split()
            if len(splitted):
                service = splitted[0]
                self.command_executor.Execute('service {0} restart'.format(service))

        self.command_executor.Execute('umount -a')
        # umount -a also drops proc/sys; remount them immediately.
        self.command_executor.Execute('mount -t proc proc /proc')
        self.command_executor.Execute('mount -t sysfs sysfs /sys')

        self.command_executor.Execute('swapoff -a', True)

        self.bek_util.umount_azure_passhprase(self.encryption_config, force=True)

        # Best-effort unmount of known mountpoints under the old root.
        if os.path.exists("/oldroot/mnt/resource"):
            self.command_executor.Execute('umount /oldroot/mnt/resource')

        if os.path.exists("/oldroot/mnt"):
            self.command_executor.Execute('umount /oldroot/mnt')

        if os.path.exists("/oldroot/mnt/azure_bek_disk"):
            self.command_executor.Execute('umount /oldroot/mnt/azure_bek_disk')

        if os.path.exists("/mnt"):
            self.command_executor.Execute('umount /mnt')

        if os.path.exists("/mnt/azure_bek_disk"):
            self.command_executor.Execute('umount /mnt/azure_bek_disk')

        self.command_executor.Execute('umount /oldroot/mnt/resource')
        self.command_executor.Execute('umount /oldroot/boot')
        self.command_executor.Execute('umount /oldroot/misc')
        self.command_executor.Execute('umount /oldroot/net')

        # NOTE(review): telinit u re-executes init; `kill 1` presumably
        # nudges init to drop handles on /oldroot — confirm intent.
        self.command_executor.Execute('telinit u', True)
        self.command_executor.Execute('kill 1', True)

        proc_comm = ProcessCommunicator()

        self.command_executor.Execute(command_to_execute="fuser -vm /oldroot",
                                      raise_exception_on_failure=False,
                                      communicator=proc_comm)

        self.context.logger.log("Processes using oldroot:\n{0}".format(proc_comm.stdout))

        procs_to_kill = filter(lambda p: p.isdigit(), proc_comm.stdout.split())
        procs_to_kill = reversed(sorted(procs_to_kill))

        for victim in procs_to_kill:
            if int(victim) == os.getpid():
                # This daemon itself holds /oldroot: schedule cleanup and a
                # WALA restart via atd, then kill our own process tree.
                self.context.logger.log("Restarting WALA before committing suicide")
                self.context.logger.log("Current executable path: " + sys.executable)
                self.context.logger.log("Current executable arguments: " + " ".join(sys.argv))

                # Kill any other daemons that are blocked and would be executed after this process commits
                # suicide
                self.command_executor.Execute('service atd restart')

                os.chdir('/')
                with open("/delete-lock.sh", "w") as f:
                    f.write("rm -f /var/lib/azure_disk_encryption_config/daemon_lock_file.lck\n")

                self.command_executor.Execute('at -f /delete-lock.sh now + 1 minutes', True)
                self.command_executor.Execute('at -f /restart-wala.sh now + 2 minutes', True)
                self.command_executor.ExecuteInBash('pkill -f .*ForLinux.*handle.py.*daemon.*', True)

            if int(victim) == 1:
                self.context.logger.log("Skipping init")
                continue

            self.command_executor.Execute('kill -9 {0}'.format(victim))

        sleep(3)

        # Unmount deepest mountpoints first (sort -r).
        self.command_executor.ExecuteInBash('for mp in `grep /oldroot /proc/mounts | cut -f2 -d\' \' | sort -r`; do umount $mp; done', True)

        sleep(3)

        # Restart udev until the root block device node reappears.
        attempt = 1

        while True:
            if attempt > 10:
                raise Exception("Block device {0} did not appear in 10 restart attempts".format(self.rootfs_block_device))

            self.context.logger.log("Attempt #{0} for reloading udev rules".format(attempt))

            self.command_executor.ExecuteInBash('pkill -f .*udev.*')
            self.command_executor.ExecuteInBash('udevd &')
            self.command_executor.ExecuteInBash('udevadm control --reload-rules && sleep 3')

            sleep(10)

            if self.command_executor.ExecuteInBash('[ -b {0} ]'.format(self.rootfs_block_device), False) == 0:
                break

            attempt += 1

        self.command_executor.Execute('e2fsck -yf {0}'.format(self.rootfs_block_device), True)

    def should_exit(self):
        """Exit only once /oldroot is truly unmounted."""
        self.context.logger.log("Verifying if machine should exit unmount_oldroot state")

        if os.path.exists('/oldroot/bin'):
            self.context.logger.log("/oldroot was not unmounted")
            return False

        return super(UnmountOldrootState, self).should_exit()


================================================
FILE: VMEncryption/main/oscrypto/centos_68/encryptstates/__init__.py
================================================
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. # # Requires Python 2.7+ # import inspect import os import sys import traceback from time import sleep scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) oscryptodir = os.path.abspath(os.path.join(scriptdir, '../../')) sys.path.append(oscryptodir) from OSEncryptionState import * from PrereqState import * from SelinuxState import * from StripdownState import * from UnmountOldrootState import * from SplitRootPartitionState import * from EncryptBlockDeviceState import * from PatchBootSystemState import * ================================================ FILE: VMEncryption/main/oscrypto/rhel_68/RHEL68EncryptionStateMachine.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class RHEL68EncryptionStateMachine(OSEncryptionStateMachine):
    """State machine driving offline OS-disk encryption on RHEL/CentOS 6.8.

    Flow: prereq -> selinux -> stripdown -> unmount_oldroot ->
    encrypt_block_device -> patch_boot_system -> completed, with a self-loop
    retry on unmount_oldroot and a direct skip to completed when the OS
    volume is already encrypted. Built on the `transitions` library.
    """

    # Declarative state list for the `transitions` Machine; 'on_enter_state'
    # invokes the matching state object's enter() via the base class.
    states = [
        State(name='uninitialized'),
        State(name='prereq', on_enter='on_enter_state'),
        State(name='selinux', on_enter='on_enter_state'),
        State(name='stripdown', on_enter='on_enter_state'),
        State(name='unmount_oldroot', on_enter='on_enter_state'),
        State(name='encrypt_block_device', on_enter='on_enter_state'),
        State(name='patch_boot_system', on_enter='on_enter_state'),
        State(name='completed'),
    ]

    # Each forward transition is gated on should_exit_previous_state;
    # retry_unmount_oldroot deliberately has no such condition so a failed
    # unmount can be re-attempted from the same state.
    transitions = [
        {
            'trigger': 'skip_encryption',
            'source': 'uninitialized',
            'dest': 'completed'
        },
        {
            'trigger': 'enter_prereq',
            'source': 'uninitialized',
            'dest': 'prereq'
        },
        {
            'trigger': 'enter_selinux',
            'source': 'prereq',
            'dest': 'selinux',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_stripdown',
            'source': 'selinux',
            'dest': 'stripdown',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_unmount_oldroot',
            'source': 'stripdown',
            'dest': 'unmount_oldroot',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'retry_unmount_oldroot',
            'source': 'unmount_oldroot',
            'dest': 'unmount_oldroot',
            'before': 'on_enter_state'
        },
        {
            'trigger': 'enter_encrypt_block_device',
            'source': 'unmount_oldroot',
            'dest': 'encrypt_block_device',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_patch_boot_system',
            'source': 'encrypt_block_device',
            'dest': 'patch_boot_system',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'stop_machine',
            'source': 'patch_boot_system',
            'dest': 'completed',
            'conditions': 'should_exit_previous_state'
        },
    ]

    def on_enter_state(self):
        """Delegate state entry to the base class (runs the state's enter())."""
        super(RHEL68EncryptionStateMachine, self).on_enter_state()

    def should_exit_previous_state(self):
        # when this is called, self.state is still the "source" state in the transition
        return super(RHEL68EncryptionStateMachine, self).should_exit_previous_state()

    def __init__(self, hutil, distro_patcher, logger, encryption_environment):
        """Build the state objects and wire them into a `transitions` Machine.

        Args:
            hutil: extension handler utility (status reporting / exit).
            distro_patcher: distro-specific package and config patcher.
            logger: extension logger.
            encryption_environment: encryption config/environment accessor.
        """
        super(RHEL68EncryptionStateMachine, self).__init__(hutil, distro_patcher, logger, encryption_environment)

        self.state_objs = {
            'prereq': PrereqState(self.context),
            'selinux': SelinuxState(self.context),
            'stripdown': StripdownState(self.context),
            'unmount_oldroot': UnmountOldrootState(self.context),
            'encrypt_block_device': EncryptBlockDeviceState(self.context),
            'patch_boot_system': PatchBootSystemState(self.context),
        }

        self.state_machine = Machine(model=self,
                                     states=RHEL68EncryptionStateMachine.states,
                                     transitions=RHEL68EncryptionStateMachine.transitions,
                                     initial='uninitialized')

    def start_encryption(self):
        """Drive the machine from 'uninitialized' to 'completed', then reboot.

        Skips straight to 'completed' if /dev/mapper/osencrypt is already
        mounted. The unmount_oldroot step is retried up to 10 times; each
        failure is reported via do_status_report before the next attempt.
        """
        proc_comm = ProcessCommunicator()
        self.command_executor.Execute(command_to_execute="mount",
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        if '/dev/mapper/osencrypt' in proc_comm.stdout:
            self.logger.log("OS volume is already encrypted")
            self.skip_encryption()
            self.log_machine_state()
            return

        self.log_machine_state()

        self.enter_prereq()
        self.log_machine_state()

        self.enter_selinux()
        self.log_machine_state()

        self.enter_stripdown()
        self.log_machine_state()

        oldroot_unmounted_successfully = False
        attempt = 1

        while not oldroot_unmounted_successfully:
            self.logger.log("Attempt #{0} to unmount /oldroot".format(attempt))

            try:
                if attempt == 1:
                    self.enter_unmount_oldroot()
                elif attempt > 10:
                    # raised inside try on purpose: caught below, reported,
                    # then re-raised as the final failure
                    raise Exception("Could not unmount /oldroot in 10 attempts")
                else:
                    self.retry_unmount_oldroot()
                self.log_machine_state()
            except Exception as e:
                message = "Attempt #{0} to unmount /oldroot failed with error: {1}, stack trace: {2}".format(attempt, e, traceback.format_exc())
                self.logger.log(msg=message)
                self.hutil.do_status_report(operation='EnableEncryptionOSVolume',
                                            status=CommonVariables.extension_error_status,
                                            status_code=str(CommonVariables.unmount_oldroot_error),
                                            message=message)
                sleep(10)
                if attempt > 10:
                    raise Exception(message)
            else:
                oldroot_unmounted_successfully = True
            finally:
                attempt += 1

        self.enter_encrypt_block_device()
        self.log_machine_state()

        self.enter_patch_boot_system()
        self.log_machine_state()

        self.stop_machine()
        self.log_machine_state()

        self._reboot()
from RHEL68EncryptionStateMachine import * ================================================ FILE: VMEncryption/main/oscrypto/rhel_68/encryptpatches/rhel_68_dracut.patch ================================================ diff -Naur 90crypt.orig/cryptroot-ask.sh 90crypt/cryptroot-ask.sh --- 90crypt.orig/cryptroot-ask.sh 2016-11-20 18:43:12.697422815 -0800 +++ 90crypt/cryptroot-ask.sh 2016-11-20 18:43:28.033101905 -0800 @@ -64,6 +64,25 @@ # Open LUKS device # +MountPoint=/tmp-keydisk-mount +KeyFileName=LinuxPassPhraseFileName +echo "Trying to get the key from disks ..." >&2 +mkdir -p $MountPoint >&2 +modprobe vfat >/dev/null >&2 +modprobe fuse >/dev/null >&2 +for SFS in /dev/sd*; do + echo "> Trying device:$SFS..." >&2 + mount ${SFS}1 $MountPoint -t vfat -r >&2 + if [ -f $MountPoint/$KeyFileName ]; then + echo "> keyfile got..." >&2 + cp $MountPoint/$KeyFileName /tmp-keyfile + luksfile=/tmp-keyfile + umount $MountPoint + break + fi +done + info "luksOpen $device $luksname $luksfile" if [ -n "$luksfile" -a "$luksfile" != "none" -a -e "$luksfile" ]; then diff -Naur 90crypt.orig/parse-crypt.sh 90crypt/parse-crypt.sh --- 90crypt.orig/parse-crypt.sh 2016-11-20 18:43:12.698422813 -0800 +++ 90crypt/parse-crypt.sh 2016-11-20 18:43:28.033101905 -0800 @@ -12,13 +12,13 @@ echo '. 
class EncryptBlockDeviceState(OSEncryptionState):
    """Shrinks the root filesystem and LUKS-encrypts the root block device.

    The filesystem is shrunk by 8192 sectors to make room for the LUKS
    header, then cryptsetup-reencrypt performs the in-place encryption using
    the BEK passphrase.
    """

    def __init__(self, context):
        super(EncryptBlockDeviceState, self).__init__('EncryptBlockDeviceState', context)

    def should_enter(self):
        """Enter whenever the generic state preconditions hold."""
        self.context.logger.log("Verifying if machine should enter encrypt_block_device state")

        if not super(EncryptBlockDeviceState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for encrypt_block_device state")

        return True

    def enter(self):
        """Resize the root fs down by 8192 sectors and start in-place encryption."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering encrypt_block_device state")

        self.context.logger.log("Resizing " + self.rootfs_block_device)

        current_rootfs_size = self._get_root_fs_size_in_sectors(sector_size=512)
        # Leave 8192 sectors (4 MiB) of headroom for the LUKS header.
        desired_rootfs_size = current_rootfs_size - 8192

        self.command_executor.Execute('resize2fs {0} {1}s'.format(self.rootfs_block_device, desired_rootfs_size), True)

        self.command_executor.Execute('mount /boot', False)

        # self._find_bek_and_execute_action('_dump_passphrase')

        # NOTE(review): operation name says 'DataVolumes' although this is the
        # OS-disk path -- confirm against status-reporting conventions.
        self.context.hutil.do_status_report(operation='EnableEncryptionDataVolumes',
                                            status=CommonVariables.extension_success_status,
                                            status_code=str(CommonVariables.success),
                                            message='OS disk encryption started')

        self._find_bek_and_execute_action('_luks_reencrypt')

    def should_exit(self):
        """Exit after verifying the encrypted device can be opened and mounted."""
        self.context.logger.log("Verifying if machine should exit encrypt_block_device state")

        if not os.path.exists('/dev/mapper/osencrypt'):
            self._find_bek_and_execute_action('_luks_open')

        self.command_executor.Execute('mount /dev/mapper/osencrypt /oldroot', True)
        self.command_executor.Execute('umount /oldroot', True)

        return super(EncryptBlockDeviceState, self).should_exit()

    def _luks_open(self, bek_path):
        """Open the encrypted root device as /dev/mapper/osencrypt."""
        self.command_executor.Execute('cryptsetup luksOpen {0} osencrypt -d {1}'.format(self.rootfs_block_device, bek_path),
                                      raise_exception_on_failure=True)

    def _luks_reencrypt(self, bek_path):
        """Encrypt the root device in place, reserving 8192s for the header."""
        self.command_executor.ExecuteInBash('cat {0} | cryptsetup-reencrypt -N --reduce-device-size 8192s {1} -v'.format(bek_path, self.rootfs_block_device),
                                            raise_exception_on_failure=True)

    def _dump_passphrase(self, bek_path):
        """Debug helper: log the BEK passphrase as an octal dump."""
        proc_comm = ProcessCommunicator()

        self.command_executor.Execute(command_to_execute="od -c {0}".format(bek_path),
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        self.context.logger.log("Passphrase:")
        self.context.logger.log(proc_comm.stdout)

    def _find_bek_and_execute_action(self, callback_method_name):
        """Locate the BEK passphrase file and invoke the named method with it."""
        callback_method = getattr(self, callback_method_name)
        if not ismethod(callback_method):
            raise Exception("{0} is not a method".format(callback_method_name))

        bek_path = self.bek_util.get_bek_passphrase_file(self.encryption_config)
        callback_method(bek_path)

    def _get_root_fs_size_in_sectors(self, sector_size):
        """Return the root filesystem size as a whole number of sectors.

        Parses `dumpe2fs -h` output for block count and block size.

        Raises:
            Exception: if either value cannot be parsed.
        """
        proc_comm = ProcessCommunicator()

        self.command_executor.Execute(command_to_execute="dumpe2fs -h {0}".format(self.rootfs_block_device),
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        root_fs_block_count = re.findall(r'Block count:\s*(\d+)', proc_comm.stdout)
        root_fs_block_size = re.findall(r'Block size:\s*(\d+)', proc_comm.stdout)

        if not root_fs_block_count or not root_fs_block_size:
            raise Exception("Error parsing dumpe2fs output, count={0}, size={1}".format(root_fs_block_count,
                                                                                        root_fs_block_size))

        root_fs_block_count = int(root_fs_block_count[0])
        root_fs_block_size = int(root_fs_block_size[0])

        # FIX: floor division. Under Python 3 the original `/` produced a
        # float, which the caller then formatted into `resize2fs ... {1}s`
        # as e.g. "1234.0s" -- an invalid sector argument.
        return (root_fs_block_count * root_fs_block_size) // sector_size
class PatchBootSystemState(OSEncryptionState):
    """Patches the boot chain (dracut + grub) so the encrypted root can boot.

    Pivots from the tmpfs root into the freshly encrypted root, applies the
    dracut keydisk patch, rebuilds the initramfs and grub config, then pivots
    back and schedules a WALA restart.
    """

    def __init__(self, context):
        super(PatchBootSystemState, self).__init__('PatchBootSystemState', context)

    def should_enter(self):
        """Enter only if the encrypted device mounts and unmounts cleanly."""
        self.context.logger.log("Verifying if machine should enter patch_boot_system state")

        if not super(PatchBootSystemState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for patch_boot_system state")

        # Sanity check: the encrypted root must be mountable before patching.
        self.command_executor.Execute('mount /dev/mapper/osencrypt /oldroot', True)
        self.command_executor.Execute('umount /oldroot', True)

        return True

    def enter(self):
        """Pivot into the encrypted root, patch it, pivot back, restart WALA.

        The pivot_root dance is symmetric: dev/proc/sys/boot are moved into
        the new root before patching and moved back during unwind, whether
        patching succeeded or raised.
        """
        if not self.should_enter():
            return

        self.context.logger.log("Entering patch_boot_system state")

        self.command_executor.Execute('mount /boot', False)
        self.command_executor.Execute('mount /dev/mapper/osencrypt /oldroot', True)
        self.command_executor.Execute('mount --make-rprivate /', True)
        self.command_executor.Execute('mkdir /oldroot/memroot', True)
        self.command_executor.Execute('pivot_root /oldroot /oldroot/memroot', True)
        self.command_executor.ExecuteInBash('for i in dev proc sys boot; do mount --move /memroot/$i /$i; done', True)

        try:
            self._modify_pivoted_oldroot()
        except Exception as e:
            # Unwind the pivot before propagating the failure.
            self.command_executor.Execute('mount --make-rprivate /')
            self.command_executor.Execute('pivot_root /memroot /memroot/oldroot')
            self.command_executor.Execute('rmdir /oldroot/memroot')
            self.command_executor.ExecuteInBash('for i in dev proc sys boot; do mount --move /oldroot/$i /$i; done')

            raise
        else:
            # Same unwind on success.
            self.command_executor.Execute('mount --make-rprivate /')
            self.command_executor.Execute('pivot_root /memroot /memroot/oldroot')
            self.command_executor.Execute('rmdir /oldroot/memroot')
            self.command_executor.ExecuteInBash('for i in dev proc sys boot; do mount --move /oldroot/$i /$i; done')

            # Preserve the extension logs accumulated during stripdown on the
            # encrypted root before unmounting it.
            extension_full_name = 'Microsoft.Azure.Security.' + CommonVariables.extension_name
            self.command_executor.Execute('cp -ax' +
                                          ' /var/log/azure/{0}'.format(extension_full_name) +
                                          ' /oldroot/var/log/azure/{0}.Stripdown'.format(extension_full_name), True)
            self.command_executor.Execute('umount /boot')
            self.command_executor.Execute('umount /oldroot')

            self.context.logger.log("Pivoted back into memroot successfully, restarting WALA")

            self.command_executor.Execute('service sshd restart')
            self.command_executor.Execute('service atd restart')

            with open("/restart-wala.sh", "w") as f:
                f.write("service waagent restart\n")

            with open("/delete-lock.sh", "w") as f:
                f.write("rm -f /var/lib/azure_disk_encryption_config/daemon_lock_file.lck\n")

            # Schedule lock cleanup and WALA restart via atd, then kill our
            # own daemon process tree.
            self.command_executor.Execute('at -f /delete-lock.sh now + 1 minutes', True)
            self.command_executor.Execute('at -f /restart-wala.sh now + 2 minutes', True)

            # Advance the state marker before this process is killed below.
            self.should_exit()

            self.command_executor.ExecuteInBash('pkill -f .*ForLinux.*handle.py.*daemon.*', True)

    def should_exit(self):
        """Exit whenever the generic state postconditions hold."""
        self.context.logger.log("Verifying if machine should exit patch_boot_system state")

        return super(PatchBootSystemState, self).should_exit()

    def _append_contents_to_file(self, contents, path):
        # Append-only helper used for /etc/dracut.conf additions.
        with open(path, 'a') as f:
            f.write(contents)

    def _modify_pivoted_oldroot(self):
        """Patch dracut's 90crypt module, rebuild initramfs, rewrite grub.conf.

        Must run while pivoted into the encrypted root. Raises if the bundled
        dracut patch file is missing.
        """
        self.context.logger.log("Pivoted into oldroot successfully")

        scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
        patchesdir = os.path.join(scriptdir, '../encryptpatches')
        patchpath = os.path.join(patchesdir, 'rhel_68_dracut.patch')

        if not os.path.exists(patchpath):
            message = "Patch not found at path: {0}".format(patchpath)
            self.context.logger.log(message)
            raise Exception(message)
        else:
            self.context.logger.log("Patch found at path: {0}".format(patchpath))

        # Point fstab's root entry at the encrypted mapper device.
        self.disk_util.remove_mount_info('/')
        self.disk_util.append_mount_info('/dev/mapper/osencrypt', '/')

        self.command_executor.ExecuteInBash('patch -b -d /usr/share/dracut/modules.d/90crypt -p1 <{0}'.format(patchpath), True)

        # fuse/vfat drivers are needed in the initramfs to read the BEK disk.
        self._append_contents_to_file('\nadd_drivers+=" fuse vfat nls_cp437 nls_iso8859-1"\n',
                                      '/etc/dracut.conf')
        self._append_contents_to_file('\nadd_dracutmodules+=" crypt"\n',
                                      '/etc/dracut.conf')

        self.command_executor.Execute('/sbin/dracut -f -v', True)

        with open("/boot/grub/grub.conf", "r") as f:
            contents = f.read()

        # Drop rd_NO_LUKS and point the kernel root= at the mapper device.
        contents = re.sub(r"rd_NO_LUKS ", r"", contents)
        contents = re.sub(r"root=(.*?)\s", r"root=/dev/mapper/osencrypt rd_LUKS_UUID=osencrypt rdinitdebug ", contents)

        with open("/boot/grub/grub.conf", "w") as f:
            f.write(contents)
class PrereqState(OSEncryptionState):
    """Verifies the distro is RHEL/CentOS 6.8 and installs encryption prereqs."""

    def __init__(self, context):
        super(PrereqState, self).__init__('PrereqState', context)

    def should_enter(self):
        """Enter whenever the generic state preconditions hold."""
        self.context.logger.log("Verifying if machine should enter prereq state")

        if not super(PrereqState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for prereq state")

        return True

    def enter(self):
        """Validate the distro, install extras, patch waagent, reload init.

        Raises:
            Exception: if the distro is not redhat/centos 6.8.
        """
        if not self.should_enter():
            return

        self.context.logger.log("Entering prereq state")

        distro_info = self.context.distro_patcher.distro_info

        self.context.logger.log("Distro info: {0}".format(distro_info))

        if ((distro_info[0] == 'redhat' and distro_info[1] == '6.8') or
            (distro_info[0] == 'centos' and distro_info[1] == '6.8')):
            self.context.logger.log("Enabling OS volume encryption on {0} {1}".format(distro_info[0],
                                                                                      distro_info[1]))
        else:
            raise Exception("RHEL68EncryptionStateMachine called for distro {0} {1}".format(distro_info[0],
                                                                                            distro_info[1]))

        self.context.distro_patcher.install_extras()

        self._patch_waagent()

        self.command_executor.Execute('telinit u', True)

    def should_exit(self):
        """Exit whenever the generic state postconditions hold."""
        self.context.logger.log("Verifying if machine should exit prereq state")

        return super(PrereqState, self).should_exit()

    def _patch_waagent(self):
        """Disable waagent's resource-disk swap so swapoff succeeds later."""
        # FIX: import `re` explicitly; this module previously relied on `re`
        # leaking through `from OSEncryptionState import *`.
        import re

        self.context.logger.log("Patching waagent")

        contents = None

        with open('/etc/waagent.conf', 'r') as f:
            contents = f.read()

        # FIX: escape the literal dot in the config key (previously matched
        # any character). The trailing `.` still matches the current y/n value.
        contents = re.sub(r'ResourceDisk\.EnableSwap=.', 'ResourceDisk.EnableSwap=n', contents)

        with open('/etc/waagent.conf', 'w') as f:
            f.write(contents)

        self.context.logger.log("waagent patched successfully")
class SelinuxState(OSEncryptionState):
    """Drops SELinux out of enforcing mode before the encryption pivot."""

    def __init__(self, context):
        super(SelinuxState, self).__init__('SelinuxState', context)

    def should_enter(self):
        """Enter whenever the generic state preconditions hold."""
        self.context.logger.log("Verifying if machine should enter selinux state")

        if super(SelinuxState, self).should_enter():
            self.context.logger.log("Performing enter checks for selinux state")
            return True

        return False

    def enter(self):
        """Disable SELinux if it is currently enforcing; otherwise do nothing."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering selinux state")

        status = self.context.encryption_environment.get_se_linux()

        if status.lower() != 'enforcing':
            return

        self.context.logger.log("SELinux is in enforcing mode, disabling")
        self.context.encryption_environment.disable_se_linux()

    def should_exit(self):
        """Exit whenever the generic state postconditions hold."""
        self.context.logger.log("Verifying if machine should exit selinux state")

        return super(SelinuxState, self).should_exit()
class StripdownState(OSEncryptionState):
    """Builds a minimal tmpfs root and pivots into it, freeing the OS disk.

    After the pivot this state deliberately exits the extension process and
    schedules a WALA restart via atd; the restarted process detects the state
    marker in should_exit() and continues the machine from the tmpfs root.
    """

    def __init__(self, context):
        super(StripdownState, self).__init__('StripdownState', context)

    def should_enter(self):
        """Enter only on a fresh run (no /oldroot left over from a prior pivot)."""
        self.context.logger.log("Verifying if machine should enter stripdown state")

        if not super(StripdownState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for stripdown state")

        self.command_executor.Execute('rm -rf /tmp/tmproot', True)
        # A pre-existing /oldroot means a previous pivot was not cleaned up.
        self.command_executor.ExecuteInBash('! [ -e "/oldroot" ]', True)

        return True

    def enter(self):
        """Copy a working subset of the OS into a tmpfs and pivot_root into it."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering stripdown state")

        self.command_executor.Execute('umount -a')
        self.command_executor.Execute('mkdir /tmp/tmproot', True)
        self.command_executor.Execute('mount -t tmpfs none /tmp/tmproot', True)
        # Skeleton directories, then copies of the binaries/config the
        # stripped-down environment needs to keep running.
        self.command_executor.ExecuteInBash('for i in proc sys dev run usr var tmp root oldroot boot; do mkdir /tmp/tmproot/$i; done', True)
        self.command_executor.ExecuteInBash('for i in bin etc mnt sbin lib lib64 root; do cp -ax /$i /tmp/tmproot/; done', True)
        self.command_executor.ExecuteInBash('for i in bin sbin libexec lib lib64 share; do cp -ax /usr/$i /tmp/tmproot/usr/; done', True)
        self.command_executor.ExecuteInBash('for i in lib local lock opt run spool tmp; do cp -ax /var/$i /tmp/tmproot/var/; done', True)
        self.command_executor.ExecuteInBash('mkdir /tmp/tmproot/var/log', True)
        self.command_executor.ExecuteInBash('cp -ax /var/log/azure /tmp/tmproot/var/log/', True)
        self.command_executor.Execute('mount --make-rprivate /', True)
        # The encryption request queue must have made it into the tmpfs copy.
        self.command_executor.ExecuteInBash('[ -e "/tmp/tmproot/var/lib/azure_disk_encryption_config/azure_crypt_request_queue.ini" ]', True)
        self.command_executor.Execute('service waagent stop', True)
        self.command_executor.Execute('pivot_root /tmp/tmproot /tmp/tmproot/oldroot', True)
        self.command_executor.ExecuteInBash('for i in dev proc sys; do mount --move /oldroot/$i /$i; done', True)

    def should_exit(self):
        """Two-phase exit: first call restarts the extension, second proceeds.

        On the first call (no state marker yet) this writes the marker,
        schedules a WALA restart via atd, and terminates the process through
        hutil.do_exit() -- it does not return. The restarted process finds
        the marker and returns True.
        """
        self.context.logger.log("Verifying if machine should exit stripdown state")

        if not os.path.exists(self.state_marker):
            self.context.logger.log("First call to stripdown state (pid={0}), restarting process".format(os.getpid()))

            # create the marker, but do not advance the state machine
            super(StripdownState, self).should_exit()
            # the restarted process shall see the marker and advance the state machine

            self.command_executor.Execute('service atd restart', True)

            os.chdir('/')
            with open("/restart-wala.sh", "w") as f:
                f.write("service waagent restart\n")

            self.command_executor.Execute('at -f /restart-wala.sh now + 1 minutes', True)

            self.context.hutil.do_exit(exit_code=CommonVariables.encryption_failed,
                                       operation='EnableEncryptionOSVolume',
                                       status=CommonVariables.extension_error_status,
                                       code=CommonVariables.encryption_failed,
                                       message="Restarted extension from stripped down OS")
        else:
            self.context.logger.log("Second call to stripdown state (pid={0}), continuing process".format(os.getpid()))
            return True
class UnmountOldrootState(OSEncryptionState):
    """Releases /oldroot so the root block device can be encrypted offline.

    RHEL 6.8 variant: services still holding files under /oldroot are
    restarted, swap and BEK mounts are released, lingering processes are
    killed, /oldroot submounts are unmounted, udev is restarted until the
    root block device reappears, and the filesystem is fsck'ed.
    """

    def __init__(self, context):
        super(UnmountOldrootState, self).__init__('UnmountOldrootState', context)

    def should_enter(self):
        """Enter only when /oldroot exists and is still a mountpoint."""
        self.context.logger.log("Verifying if machine should enter unmount_oldroot state")

        if not super(UnmountOldrootState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for unmount_oldroot state")

        self.command_executor.ExecuteInBash('[ -e "/oldroot" ]', True)

        if self.command_executor.Execute('mountpoint /oldroot') != 0:
            return False

        return True

    def enter(self):
        """Unmount everything under /oldroot and fsck the root block device.

        Order matters throughout: services are moved off the old root first,
        then mounts are released, then any remaining processes are killed
        (highest PID first) before the final unmount sweep.
        """
        if not self.should_enter():
            return

        self.context.logger.log("Entering unmount_oldroot state")

        self.command_executor.ExecuteInBash('mkdir -p /var/empty/sshd', True)
        self.command_executor.Execute('service sshd restart')
        self.command_executor.Execute('dhclient')

        proc_comm = ProcessCommunicator()
        self.command_executor.Execute(command_to_execute="/sbin/service --status-all",
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        # Restart every running service (except waagent/ssh) so its open file
        # handles move off /oldroot and onto the stripped-down root.
        for line in proc_comm.stdout.split('\n'):
            if not "running" in line:
                continue
            if "waagent" in line or "ssh" in line:
                continue

            splitted = line.split()
            if len(splitted):
                service = splitted[0]
                self.command_executor.Execute('service {0} restart'.format(service))

        self.command_executor.Execute('swapoff -a', True)

        self.bek_util.umount_azure_passhprase(self.encryption_config, force=True)

        if os.path.exists("/oldroot/mnt/resource"):
            self.command_executor.Execute('umount /oldroot/mnt/resource')

        if os.path.exists("/oldroot/mnt"):
            self.command_executor.Execute('umount /oldroot/mnt')

        if os.path.exists("/oldroot/mnt/azure_bek_disk"):
            self.command_executor.Execute('umount /oldroot/mnt/azure_bek_disk')

        if os.path.exists("/mnt"):
            self.command_executor.Execute('umount /mnt')

        if os.path.exists("/mnt/azure_bek_disk"):
            self.command_executor.Execute('umount /mnt/azure_bek_disk')

        # Best-effort unmounts; failures are tolerated (no raise flag).
        self.command_executor.Execute('umount /oldroot/mnt/resource')
        self.command_executor.Execute('umount /oldroot/boot')
        self.command_executor.Execute('umount /oldroot/misc')
        self.command_executor.Execute('umount /oldroot/net')

        self.command_executor.Execute('telinit u', True)
        self.command_executor.Execute('kill 1', True)

        proc_comm = ProcessCommunicator()
        self.command_executor.Execute(command_to_execute="fuser -vm /oldroot",
                                      raise_exception_on_failure=False,
                                      communicator=proc_comm)

        self.context.logger.log("Processes using oldroot:\n{0}".format(proc_comm.stdout))

        # FIX: sort PIDs numerically. The original sorted the PID *strings*
        # lexicographically (so e.g. '9' ordered after '100'), which broke the
        # intended highest-PID-first kill order.
        procs_to_kill = [p for p in proc_comm.stdout.split() if p.isdigit()]
        procs_to_kill = sorted(procs_to_kill, key=int, reverse=True)

        for victim in procs_to_kill:
            if int(victim) == os.getpid():
                self.context.logger.log("Restarting WALA before committing suicide")
                self.context.logger.log("Current executable path: " + sys.executable)
                self.context.logger.log("Current executable arguments: " + " ".join(sys.argv))

                # Kill any other daemons that are blocked and would be executed after this process commits
                # suicide
                self.command_executor.Execute('service atd restart')

                os.chdir('/')
                with open("/delete-lock.sh", "w") as f:
                    f.write("rm -f /var/lib/azure_disk_encryption_config/daemon_lock_file.lck\n")

                # /restart-wala.sh was written earlier by StripdownState.
                self.command_executor.Execute('at -f /delete-lock.sh now + 1 minutes', True)
                self.command_executor.Execute('at -f /restart-wala.sh now + 2 minutes', True)

                self.command_executor.ExecuteInBash('pkill -f .*ForLinux.*handle.py.*daemon.*', True)

            if int(victim) == 1:
                self.context.logger.log("Skipping init")
                continue

            self.command_executor.Execute('kill -9 {0}'.format(victim))

        sleep(3)

        self.command_executor.ExecuteInBash('for mp in `grep /oldroot /proc/mounts | cut -f2 -d\' \' | sort -r`; do umount $mp; done', True)

        sleep(3)

        attempt = 1

        # Restart udev until the root block device node reappears.
        while True:
            if attempt > 10:
                raise Exception("Block device {0} did not appear in 10 restart attempts".format(self.rootfs_block_device))

            self.context.logger.log("Attempt #{0} for reloading udev rules".format(attempt))

            self.command_executor.ExecuteInBash('pkill -f .*udev.*')
            self.command_executor.ExecuteInBash('udevd &')
            self.command_executor.ExecuteInBash('udevadm control --reload-rules && sleep 3')

            sleep(10)

            if self.command_executor.ExecuteInBash('[ -b {0} ]'.format(self.rootfs_block_device), False) == 0:
                break

            attempt += 1

        self.command_executor.Execute('e2fsck -yf {0}'.format(self.rootfs_block_device), True)

    def should_exit(self):
        """Exit only once /oldroot is actually unmounted."""
        self.context.logger.log("Verifying if machine should exit unmount_oldroot state")

        if os.path.exists('/oldroot/bin'):
            self.context.logger.log("/oldroot was not unmounted")
            return False

        return super(UnmountOldrootState, self).should_exit()
#
# Requires Python 2.7+
#

import inspect
import os
import sys
import traceback

from time import sleep

# Make the oscrypto package root importable so the shared OSEncryptionState
# base classes resolve from this nested encryptstates package.
scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
oscryptodir = os.path.abspath(os.path.join(scriptdir, '../../'))
sys.path.append(oscryptodir)

from OSEncryptionState import *
from PrereqState import *
from SelinuxState import *
from StripdownState import *
from UnmountOldrootState import *
from EncryptBlockDeviceState import *
from PatchBootSystemState import *


================================================
FILE: VMEncryption/main/oscrypto/rhel_72/RHEL72EncryptionStateMachine.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

import inspect
import os
import sys
import traceback

from time import sleep

# Make the main extension directory and the bundled 'transitions' state
# machine library importable from this per-distro subpackage.
scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
maindir = os.path.abspath(os.path.join(scriptdir, '../../'))
sys.path.append(maindir)
transitionsdir = os.path.abspath(os.path.join(scriptdir, '../../transitions'))
sys.path.append(transitionsdir)

from oscrypto import *
from encryptstates import *
from Common import *
from CommandExecutor import *
from DiskUtil import *
from transitions import *


class RHEL72EncryptionStateMachine(OSEncryptionStateMachine):
    """Drives OS-disk encryption on RHEL/CentOS 7.2+ through an ordered
    sequence of states:

        uninitialized -> prereq -> selinux -> stripdown -> unmount_oldroot
                      -> encrypt_block_device -> patch_boot_system -> completed

    Each 'enter_*' trigger is guarded by should_exit_previous_state, so a
    state only advances once its predecessor reports it has finished.
    """

    # State names must match the keys of self.state_objs below.
    states = [
        State(name='uninitialized'),
        State(name='prereq', on_enter='on_enter_state'),
        State(name='selinux', on_enter='on_enter_state'),
        State(name='stripdown', on_enter='on_enter_state'),
        State(name='unmount_oldroot', on_enter='on_enter_state'),
        State(name='encrypt_block_device', on_enter='on_enter_state'),
        State(name='patch_boot_system', on_enter='on_enter_state'),
        State(name='completed'),
    ]

    transitions = [
        # Short-circuit when the OS volume is already encrypted.
        { 'trigger': 'skip_encryption', 'source': 'uninitialized', 'dest': 'completed' },
        { 'trigger': 'enter_prereq', 'source': 'uninitialized', 'dest': 'prereq' },
        { 'trigger': 'enter_selinux', 'source': 'prereq', 'dest': 'selinux', 'before': 'on_enter_state', 'conditions': 'should_exit_previous_state' },
        { 'trigger': 'enter_stripdown', 'source': 'selinux', 'dest': 'stripdown', 'before': 'on_enter_state', 'conditions': 'should_exit_previous_state' },
        { 'trigger': 'enter_unmount_oldroot', 'source': 'stripdown', 'dest': 'unmount_oldroot', 'before': 'on_enter_state', 'conditions': 'should_exit_previous_state' },
        # Self-loop used by the retry logic in start_encryption; no exit
        # condition so a failed unmount attempt can be re-driven.
        { 'trigger': 'retry_unmount_oldroot', 'source': 'unmount_oldroot', 'dest': 'unmount_oldroot', 'before': 'on_enter_state' },
        { 'trigger': 'enter_encrypt_block_device', 'source': 'unmount_oldroot', 'dest': 'encrypt_block_device', 'before': 'on_enter_state', 'conditions': 'should_exit_previous_state' },
        { 'trigger': 'enter_patch_boot_system', 'source': 'encrypt_block_device', 'dest': 'patch_boot_system', 'before': 'on_enter_state', 'conditions': 'should_exit_previous_state' },
        { 'trigger': 'stop_machine', 'source': 'patch_boot_system', 'dest': 'completed', 'conditions': 'should_exit_previous_state' },
    ]

    def on_enter_state(self):
        # Delegate to the common base-class implementation.
        super(RHEL72EncryptionStateMachine, self).on_enter_state()

    def should_exit_previous_state(self):
        # when this is called, self.state is still the "source" state in the transition
        return super(RHEL72EncryptionStateMachine, self).should_exit_previous_state()

    def __init__(self, hutil, distro_patcher, logger, encryption_environment):
        super(RHEL72EncryptionStateMachine, self).__init__(hutil, distro_patcher, logger, encryption_environment)

        # One state object per (non-terminal) machine state; all share the
        # same context built by the base class.
        self.state_objs = {
            'prereq': PrereqState(self.context),
            'selinux': SelinuxState(self.context),
            'stripdown': StripdownState(self.context),
            'unmount_oldroot': UnmountOldrootState(self.context),
            'encrypt_block_device': EncryptBlockDeviceState(self.context),
            'patch_boot_system': PatchBootSystemState(self.context),
        }

        self.state_machine = Machine(model=self,
                                     states=RHEL72EncryptionStateMachine.states,
                                     transitions=RHEL72EncryptionStateMachine.transitions,
                                     initial='uninitialized')

    def start_encryption(self):
        """Run the whole encryption sequence and reboot on success.

        The unmount_oldroot state is retried up to 10 times (it fails when
        processes still hold the old root filesystem open); every other state
        runs exactly once.
        """
        proc_comm = ProcessCommunicator()
        self.command_executor.Execute(command_to_execute="mount",
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        # If the dm-crypt target already shows up in the mount table there is
        # nothing to do.
        if '/dev/mapper/osencrypt' in proc_comm.stdout:
            self.logger.log("OS volume is already encrypted")
            self.skip_encryption()
            self.log_machine_state()
            return

        self.log_machine_state()
        self.enter_prereq()
        self.log_machine_state()
        self.enter_selinux()
        self.log_machine_state()
        self.enter_stripdown()
        self.log_machine_state()

        oldroot_unmounted_successfully = False
        attempt = 1

        while not oldroot_unmounted_successfully:
            self.logger.log("Attempt #{0} to unmount /oldroot".format(attempt))
            try:
                if attempt == 1:
                    self.enter_unmount_oldroot()
                elif attempt > 10:
                    raise Exception("Could not unmount /oldroot in 10 attempts")
                else:
                    self.retry_unmount_oldroot()
                self.log_machine_state()
            except Exception as e:
                message = "Attempt #{0} to unmount /oldroot failed with error: {1}, stack trace: {2}".format(attempt, e, traceback.format_exc())
                self.logger.log(msg=message)
                # Surface the failure to the agent status channel before retrying.
                self.hutil.do_status_report(operation='EnableEncryptionOSVolume',
                                            status=CommonVariables.extension_error_status,
                                            status_code=str(CommonVariables.unmount_oldroot_error),
                                            message=message)
                sleep(10)
                if attempt > 10:
                    raise Exception(message)
            else:
                oldroot_unmounted_successfully = True
            finally:
                attempt += 1

        self.enter_encrypt_block_device()
        self.log_machine_state()
        self.enter_patch_boot_system()
        self.log_machine_state()
        self.stop_machine()
        self.log_machine_state()

        self._reboot()


================================================
FILE: VMEncryption/main/oscrypto/rhel_72/__init__.py
================================================
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from RHEL72EncryptionStateMachine import *


================================================
FILE: VMEncryption/main/oscrypto/rhel_72/encryptstates/EncryptBlockDeviceState.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

import os
import sys

from inspect import ismethod
from time import sleep

from OSEncryptionState import *
from distutils.version import LooseVersion


class EncryptBlockDeviceState(OSEncryptionState):
    """Encrypts the (already unmounted) root block device in place.

    The LUKS header is kept detached on the unencrypted /boot partition
    (/boot/luks/osluksheader) so the whole rootfs device can be overwritten
    with its own encrypted image via dd.
    """

    def __init__(self, context):
        super(EncryptBlockDeviceState, self).__init__('EncryptBlockDeviceState', context)

    def should_enter(self):
        self.context.logger.log("Verifying if machine should enter encrypt_block_device state")

        if not super(EncryptBlockDeviceState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for encrypt_block_device state")

        return True

    def enter(self):
        """Format the rootfs device as LUKS (detached header), open it as
        /dev/mapper/osencrypt, then copy the old plaintext contents into the
        encrypted mapping with dd."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering encrypt_block_device state")

        # /boot may already be mounted; failure here is tolerated.
        self.command_executor.Execute('mount /boot', False)
        # self._find_bek_and_execute_action('_dump_passphrase')
        self._find_bek_and_execute_action('_luks_format')
        self._find_bek_and_execute_action('_luks_open')

        self.context.hutil.do_status_report(operation='EnableEncryptionDataVolumes',
                                            status=CommonVariables.extension_success_status,
                                            status_code=str(CommonVariables.success),
                                            message='OS disk encryption started')

        # Enable used space encryption on RHEL 7.3 and above
        distro_info = self.context.distro_patcher.distro_info
        if LooseVersion(distro_info[1]) >= LooseVersion('7.3'):
            # conv=sparse skips writing all-zero blocks, so only used space is copied.
            self.command_executor.Execute('dd if={0} of=/dev/mapper/osencrypt conv=sparse bs=64K'.format(self.rootfs_block_device), True)
        else:
            self.command_executor.Execute('dd if={0} of=/dev/mapper/osencrypt bs=52428800'.format(self.rootfs_block_device), True)

    def should_exit(self):
        """Sanity-check that the encrypted mapping exists and is mountable
        before allowing the machine to advance."""
        self.context.logger.log("Verifying if machine should exit encrypt_block_device state")

        if not os.path.exists('/dev/mapper/osencrypt'):
            self._find_bek_and_execute_action('_luks_open')

        # Mount/umount round-trip proves the copied filesystem is intact.
        self.command_executor.Execute('mount /dev/mapper/osencrypt /oldroot', True)
        self.command_executor.Execute('umount /oldroot', True)

        return super(EncryptBlockDeviceState, self).should_exit()

    def _luks_format(self, bek_path):
        """luksFormat the rootfs device using a detached header in /boot/luks,
        keyed by the passphrase file at bek_path."""
        self.command_executor.Execute('mkdir /boot/luks', True)
        # Pre-allocate a 32 MiB file to hold the detached LUKS header.
        self.command_executor.Execute('dd if=/dev/zero of=/boot/luks/osluksheader bs=33554432 count=1', True)
        self.command_executor.Execute('cryptsetup luksFormat --header /boot/luks/osluksheader -d {0} {1} -q'.format(bek_path, self.rootfs_block_device),
                                      raise_exception_on_failure=True)

    def _luks_open(self, bek_path):
        """Open the LUKS device (detached header) as /dev/mapper/osencrypt."""
        self.command_executor.Execute('cryptsetup luksOpen --header /boot/luks/osluksheader {0} osencrypt -d {1}'.format(self.rootfs_block_device, bek_path),
                                      raise_exception_on_failure=True)

    def _dump_passphrase(self, bek_path):
        # Debug helper (normally not called; see commented call site in enter).
        proc_comm = ProcessCommunicator()

        self.command_executor.Execute(command_to_execute="od -c {0}".format(bek_path),
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        self.context.logger.log("Passphrase:")
        self.context.logger.log(proc_comm.stdout)

    def _find_bek_and_execute_action(self, callback_method_name):
        """Resolve the BEK passphrase file and invoke the named private
        method with its path."""
        callback_method = getattr(self, callback_method_name)
        if not ismethod(callback_method):
            raise Exception("{0} is not a method".format(callback_method_name))

        bek_path = self.bek_util.get_bek_passphrase_file(self.encryption_config)
        callback_method(bek_path)


================================================
FILE: VMEncryption/main/oscrypto/rhel_72/encryptstates/PatchBootSystemState.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

import inspect
import os
import re
import sys

from time import sleep

from OSEncryptionState import *


class PatchBootSystemState(OSEncryptionState):
    """Makes the freshly encrypted root filesystem bootable.

    Temporarily pivots from the stripped-down tmpfs root into the encrypted
    root (mounted from /dev/mapper/osencrypt), patches dracut/grub there, then
    pivots back. The statement order in enter() mirrors the mount/pivot dance
    exactly and must not be rearranged.
    """

    def __init__(self, context):
        super(PatchBootSystemState, self).__init__('PatchBootSystemState', context)

    def should_enter(self):
        self.context.logger.log("Verifying if machine should enter patch_boot_system state")

        if not super(PatchBootSystemState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for patch_boot_system state")

        # Verify the encrypted root can be mounted before committing to the pivot.
        self.command_executor.Execute('mount /dev/mapper/osencrypt /oldroot', True)
        self.command_executor.Execute('umount /oldroot', True)

        return True

    def enter(self):
        if not self.should_enter():
            return

        self.context.logger.log("Entering patch_boot_system state")

        self.command_executor.Execute('mount /boot', False)
        self.command_executor.Execute('mount /dev/mapper/osencrypt /oldroot', True)
        self.command_executor.Execute('mount --make-rprivate /', True)
        self.command_executor.Execute('mkdir /oldroot/memroot', True)
        # Swap roots: the encrypted volume becomes '/', the tmpfs root is
        # parked at /memroot.
        self.command_executor.Execute('pivot_root /oldroot /oldroot/memroot', True)
        self.command_executor.ExecuteInBash('for i in dev proc sys boot; do mount --move /memroot/$i /$i; done', True)
        self.command_executor.ExecuteInBash('[ -e "/boot/luks" ]', True)

        try:
            self._modify_pivoted_oldroot()
        except Exception as e:
            # Undo the pivot before propagating the failure.
            self.command_executor.Execute('mount --make-rprivate /')
            self.command_executor.Execute('pivot_root /memroot /memroot/oldroot')
            self.command_executor.Execute('rmdir /oldroot/memroot')
            self.command_executor.ExecuteInBash('for i in dev proc sys boot; do mount --move /oldroot/$i /$i; done')

            raise
        else:
            # Success: pivot back to the tmpfs root and preserve the
            # stripdown-phase extension logs on the encrypted volume.
            self.command_executor.Execute('mount --make-rprivate /')
            self.command_executor.Execute('pivot_root /memroot /memroot/oldroot')
            self.command_executor.Execute('rmdir /oldroot/memroot')
            self.command_executor.ExecuteInBash('for i in dev proc sys boot; do mount --move /oldroot/$i /$i; done')

            extension_full_name = 'Microsoft.Azure.Security.' + CommonVariables.extension_name
            self.command_executor.Execute('cp -ax' +
                                          ' /var/log/azure/{0}'.format(extension_full_name) +
                                          ' /oldroot/var/log/azure/{0}.Stripdown'.format(extension_full_name))
            self.command_executor.Execute('umount /boot')
            self.command_executor.Execute('umount /oldroot')
            self.command_executor.Execute('systemctl restart waagent')

            self.context.logger.log("Pivoted back into memroot successfully")

    def should_exit(self):
        self.context.logger.log("Verifying if machine should exit patch_boot_system state")

        return super(PatchBootSystemState, self).should_exit()

    def _append_contents_to_file(self, contents, path):
        # Append-only helper; never truncates the target file.
        with open(path, 'a') as f:
            f.write(contents)

    def _modify_pivoted_oldroot(self):
        """Runs inside the pivoted (encrypted) root: installs the 91ade dracut
        module, points its udev rule at the right partition, rebuilds the
        initramfs and reinstalls grub."""
        self.context.logger.log("Pivoted into oldroot successfully")

        scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
        ademoduledir = os.path.join(scriptdir, '../../91ade')
        dracutmodulesdir = '/lib/dracut/modules.d'
        udevaderulepath = os.path.join(dracutmodulesdir, '91ade/50-udev-ade.rules')

        proc_comm = ProcessCommunicator()

        self.command_executor.Execute('cp -r {0} /lib/dracut/modules.d/'.format(ademoduledir), True)

        # Discover the rootfs partition number from udev attributes and bake
        # it into the 91ade udev rule (replacing the template placeholder).
        udevadm_cmd = "udevadm info --attribute-walk --name={0}".format(self.rootfs_block_device)
        self.command_executor.Execute(command_to_execute=udevadm_cmd,
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        matches = re.findall(r'ATTR{partition}=="(.*)"', proc_comm.stdout)
        if not matches:
            raise Exception("Could not parse ATTR{partition} from udevadm info")
        partition = matches[0]

        sed_cmd = 'sed -i.bak s/ENCRYPTED_DISK_PARTITION/{0}/ "{1}"'.format(partition, udevaderulepath)
        self.command_executor.Execute(command_to_execute=sed_cmd,
                                      raise_exception_on_failure=True)

        self._append_contents_to_file('\nGRUB_CMDLINE_LINUX+=" rd.debug"\n',
                                      '/etc/default/grub')
        # fuse/vfat drivers are needed in the initramfs to read the BEK volume
        # at early boot; the crypt module unlocks the root device.
        self._append_contents_to_file('\nadd_drivers+=" fuse vfat nls_cp437 nls_iso8859-1"\n',
                                      '/etc/dracut.conf')
        self._append_contents_to_file('\nadd_dracutmodules+=" crypt"\n',
                                      '/etc/dracut.conf')

        self.command_executor.ExecuteInBash("/usr/sbin/dracut -f -v --kver `grubby --default-kernel | sed 's|/boot/vmlinuz-||g'`", True)
        self.command_executor.Execute('grub2-install --recheck --force {0}'.format(self.rootfs_disk), True)
        self.command_executor.Execute('grub2-mkconfig -o /boot/grub2/grub.cfg', True)


================================================
FILE: VMEncryption/main/oscrypto/rhel_72/encryptstates/PrereqState.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# # Requires Python 2.7+ # from OSEncryptionState import * from pprint import pprint class PrereqState(OSEncryptionState): def __init__(self, context): super(PrereqState, self).__init__('PrereqState', context) def should_enter(self): self.context.logger.log("Verifying if machine should enter prereq state") if not super(PrereqState, self).should_enter(): return False self.context.logger.log("Performing enter checks for prereq state") return True def enter(self): if not self.should_enter(): return self.context.logger.log("Entering prereq state") distro_info = self.context.distro_patcher.distro_info self.context.logger.log("Distro info: {0}".format(distro_info)) if ((distro_info[0] == 'redhat' and distro_info[1] == '7.2') or (distro_info[0] == 'redhat' and distro_info[1] == '7.3') or (distro_info[0] == 'redhat' and distro_info[1] == '7.4') or (distro_info[0] == 'redhat' and distro_info[1] == '7.5') or (distro_info[0] == 'redhat' and distro_info[1] == '7.6') or (distro_info[0] == 'redhat' and distro_info[1] == '7.7') or (distro_info[0] == 'centos' and distro_info[1].startswith('7.7')) or (distro_info[0] == 'centos' and distro_info[1].startswith('7.6')) or (distro_info[0] == 'centos' and distro_info[1].startswith('7.5')) or (distro_info[0] == 'centos' and distro_info[1].startswith('7.4')) or (distro_info[0] == 'centos' and distro_info[1] == '7.3.1611') or (distro_info[0] == 'centos' and distro_info[1] == '7.2.1511')): self.context.logger.log("Enabling OS volume encryption on {0} {1}".format(distro_info[0], distro_info[1])) else: raise Exception("RHEL72EncryptionStateMachine called for distro {0} {1}".format(distro_info[0], distro_info[1])) self.context.distro_patcher.install_extras() self._patch_waagent() self.command_executor.Execute('systemctl daemon-reload', True) def should_exit(self): self.context.logger.log("Verifying if machine should exit prereq state") return super(PrereqState, self).should_exit() def _patch_waagent(self): self.context.logger.log("Patching 
waagent") contents = None with open('/usr/lib/systemd/system/waagent.service', 'r') as f: contents = f.read() contents = re.sub(r'\[Service\]\n', '[Service]\nKillMode=process\n', contents) with open('/usr/lib/systemd/system/waagent.service', 'w') as f: f.write(contents) self.context.logger.log("waagent patched successfully") ================================================ FILE: VMEncryption/main/oscrypto/rhel_72/encryptstates/SelinuxState.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# # Requires Python 2.7+ # from OSEncryptionState import * class SelinuxState(OSEncryptionState): def __init__(self, context): super(SelinuxState, self).__init__('SelinuxState', context) def should_enter(self): self.context.logger.log("Verifying if machine should enter selinux state") if not super(SelinuxState, self).should_enter(): return False self.context.logger.log("Performing enter checks for selinux state") return True def enter(self): if not self.should_enter(): return self.context.logger.log("Entering selinux state") se_linux_status = self.context.encryption_environment.get_se_linux() if se_linux_status.lower() == 'enforcing': self.context.logger.log("SELinux is in enforcing mode, disabling") self.context.encryption_environment.disable_se_linux() def should_exit(self): self.context.logger.log("Verifying if machine should exit selinux state") return super(SelinuxState, self).should_exit() ================================================ FILE: VMEncryption/main/oscrypto/rhel_72/encryptstates/StripdownState.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#
# Requires Python 2.7+
#

import os
import sys

from OSEncryptionState import *


class StripdownState(OSEncryptionState):
    """Builds a minimal tmpfs root at /tmp/tmproot, copies just enough of the
    OS into it, and pivot_roots into it so the real root filesystem can later
    be unmounted and encrypted in place.

    should_exit() implements a two-phase protocol: the first invocation drops
    the state marker and exits the process (to be restarted inside the
    stripped-down root); the restarted process sees the marker and continues.
    """

    def __init__(self, context):
        super(StripdownState, self).__init__('StripdownState', context)

    def should_enter(self):
        self.context.logger.log("Verifying if machine should enter stripdown state")

        if not super(StripdownState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for stripdown state")

        # A leftover /oldroot or /tmp/tmproot would mean a previous attempt
        # did not clean up; bail out of the transition in that case.
        self.command_executor.Execute('rm -rf /tmp/tmproot', True)
        self.command_executor.ExecuteInBash('! [ -e "/oldroot" ]', True)

        return True

    def enter(self):
        if not self.should_enter():
            return

        self.context.logger.log("Entering stripdown state")

        # Best-effort unmount of everything unmountable (failures tolerated).
        self.command_executor.Execute('umount -a')

        # Assemble the skeleton of the temporary root on tmpfs.
        self.command_executor.Execute('mkdir /tmp/tmproot', True)
        self.command_executor.Execute('mount -t tmpfs none /tmp/tmproot', True)
        self.command_executor.ExecuteInBash('for i in proc sys dev run usr var tmp root oldroot boot; do mkdir /tmp/tmproot/$i; done', True)
        # Copy the minimal userland needed to keep running from tmpfs.
        self.command_executor.ExecuteInBash('for i in bin etc mnt sbin lib lib64 root; do cp -ax /$i /tmp/tmproot/; done', True)
        self.command_executor.ExecuteInBash('for i in bin sbin libexec lib lib64 share; do cp -ax /usr/$i /tmp/tmproot/usr/; done', True)
        self.command_executor.ExecuteInBash('for i in lib local lock opt run spool tmp; do cp -ax /var/$i /tmp/tmproot/var/; done', True)
        self.command_executor.ExecuteInBash('mkdir /tmp/tmproot/var/log', True)
        self.command_executor.ExecuteInBash('cp -ax /var/log/azure /tmp/tmproot/var/log/', True)

        self.command_executor.Execute('mount --make-rprivate /', True)

        # The request queue must have made it into the copy, or the restarted
        # extension inside tmpfs could not resume the operation.
        self.command_executor.ExecuteInBash('[ -e "/tmp/tmproot/var/lib/azure_disk_encryption_config/azure_crypt_request_queue.ini" ]', True)

        self.command_executor.Execute('systemctl stop waagent', True)
        # After this pivot the tmpfs copy is '/' and the real root is /oldroot.
        self.command_executor.Execute('pivot_root /tmp/tmproot /tmp/tmproot/oldroot', True)
        self.command_executor.ExecuteInBash('for i in dev proc sys run; do mount --move /oldroot/$i /$i; done', True)

    def should_exit(self):
        self.context.logger.log("Verifying if machine should exit stripdown state")

        if not os.path.exists(self.state_marker):
            self.context.logger.log("First call to stripdown state (pid={0}), restarting process".format(os.getpid()))

            # create the marker, but do not advance the state machine
            super(StripdownState, self).should_exit()

            # the restarted process shall see the marker and advance the state machine
            self.command_executor.ExecuteInBash('sleep 30 && systemctl start waagent &', True)
            self.context.hutil.do_exit(exit_code=0,
                                       operation='EnableEncryptionOSVolume',
                                       status=CommonVariables.extension_success_status,
                                       code=str(CommonVariables.success),
                                       message="Restarted extension from stripped down OS")
        else:
            self.context.logger.log("Second call to stripdown state (pid={0}), continuing process".format(os.getpid()))

        return True


================================================
FILE: VMEncryption/main/oscrypto/rhel_72/encryptstates/UnmountOldrootState.py
================================================
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

import os
import re
import sys

from time import sleep

from OSEncryptionState import *


class UnmountOldrootState(OSEncryptionState):
    """Releases every handle on the original root filesystem (/oldroot after
    the stripdown pivot) so it can be unmounted and handed to the encryption
    state. This involves restarting services off the old root and killing the
    processes that still hold it open.
    """

    def __init__(self, context):
        super(UnmountOldrootState, self).__init__('UnmountOldrootState', context)

    def should_enter(self):
        self.context.logger.log("Verifying if machine should enter unmount_oldroot state")

        if not super(UnmountOldrootState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for unmount_oldroot state")

        self.command_executor.ExecuteInBash('[ -e "/oldroot" ]', True)

        # Nothing to do when /oldroot is not a mountpoint (already unmounted).
        if self.command_executor.Execute('mountpoint /oldroot') != 0:
            return False

        return True

    def restart_systemd_services(self):
        """Restart every running systemd service (except waagent/sshd) so its
        binaries are re-executed from the tmpfs root instead of /oldroot."""
        proc_comm = ProcessCommunicator()

        self.command_executor.Execute(command_to_execute="systemctl list-units",
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        for line in proc_comm.stdout.split('\n'):
            if not "running" in line:
                continue

            # waagent must keep supervising this extension; sshd keeps the
            # operator's session alive.
            if "waagent.service" in line or "sshd.service" in line:
                continue

            match = re.search(r'\s(\S*?\.service)', line)
            if match:
                service = match.groups()[0]
                self.command_executor.Execute('systemctl restart {0}'.format(service))

    def enter(self):
        if not self.should_enter():
            return

        self.context.logger.log("Entering unmount_oldroot state")

        self.command_executor.ExecuteInBash('mkdir -p /var/empty/sshd', True)
        self.command_executor.ExecuteInBash('systemctl restart sshd.service')
        self.command_executor.ExecuteInBash('dhclient')

        self.restart_systemd_services()

        self.command_executor.Execute('swapoff -a', True)

        if os.path.exists("/oldroot/mnt/resource"):
            self.command_executor.Execute('umount /oldroot/mnt/resource')

        proc_comm = ProcessCommunicator()
        # fuser lists PIDs of processes holding /oldroot open.
        self.command_executor.Execute(command_to_execute="fuser -vm /oldroot",
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        self.context.logger.log("Processes using oldroot:\n{0}".format(proc_comm.stdout))

        # NOTE: Python 2 filter/map return lists here; highest PIDs are killed
        # first (reversed sort).
        procs_to_kill = filter(lambda p: p.isdigit(), proc_comm.stdout.split())
        procs_to_kill = reversed(sorted(procs_to_kill))

        for victim in procs_to_kill:
            if int(victim) == os.getpid():
                self.context.logger.log("Restarting WALA in 30 seconds before committing suicide")

                # This is a workaround for the bug on CentOS/RHEL 7.2 where systemd-udevd
                # needs to be restarted and the drive mounted/unmounted.
                # Otherwise the dir becomes inaccessible, fuse says: Transport endpoint is not connected
                self.command_executor.Execute('systemctl restart systemd-udevd', True)
                self.bek_util.umount_azure_passhprase(self.encryption_config, force=True)
                self.command_executor.Execute('systemctl restart systemd-udevd', True)
                self.bek_util.get_bek_passphrase_file(self.encryption_config)
                self.bek_util.umount_azure_passhprase(self.encryption_config, force=True)
                self.command_executor.Execute('systemctl restart systemd-udevd', True)

                # Schedule waagent to come back after this process is killed below.
                self.command_executor.ExecuteInBash('sleep 30 && systemctl start waagent &', True)

            if int(victim) == 1:
                self.context.logger.log("Skipping init")
                continue

            self.command_executor.Execute('kill -9 {0}'.format(victim))

        # Re-execute systemd, get pid 1 to use the new root
        self.command_executor.Execute('telinit u', True)
        sleep(3)

        self.command_executor.Execute('umount /oldroot', True)

        self.restart_systemd_services()

        #
        # With the recent release of 7.4 it was found that even after unmounting
        # oldroot, there were some open handles to the root file system block device.
        # The below logic tries to find the offending mount by grepping /proc/*/task/*/mountinfo
        # and kill the respective processes so that encryption can proceed
        #
        proc_comm = ProcessCommunicator()

        # Example: grep for /dev/sda2 in the files /proc/*task/*/mountinfo and remove results of the grep process itself.
        # If grep -v grep is not applied, then the command throws an exception
        self.command_executor.ExecuteInBash(
            command_to_execute="grep {0} /proc/*/task/*/mountinfo | grep -v grep".format(self.rootfs_sdx_path),
            raise_exception_on_failure=False,
            communicator=proc_comm)

        # Extract the PID component from matching /proc/<pid>/task/... paths.
        procs_to_kill = filter(lambda path: path.startswith('/proc/'), proc_comm.stdout.split())
        procs_to_kill = map(lambda path: int(path.split('/')[2]), procs_to_kill)
        procs_to_kill = list(reversed(sorted(procs_to_kill)))

        self.context.logger.log("Processes with tasks using {0}:\n{1}".format(self.rootfs_sdx_path, procs_to_kill))

        for victim in procs_to_kill:
            if int(victim) == os.getpid():
                self.context.logger.log("This extension is holding on to {0}. "
                                        "This is not expected...".format(self.rootfs_sdx_path))
                continue
            if int(victim) == 1:
                self.context.logger.log("Skipping init")
                continue
            self.command_executor.Execute('kill -9 {0}'.format(victim))

        sleep(3)

        # Wait (up to 10 udev restarts) for the rootfs block device node to
        # reappear after the unmount/kill churn.
        attempt = 1

        while True:
            if attempt > 10:
                raise Exception("Block device {0} did not appear in 10 restart attempts".format(self.rootfs_block_device))

            self.context.logger.log("Attempt #{0} for restarting systemd-udevd".format(attempt))
            self.command_executor.Execute('systemctl restart systemd-udevd')
            sleep(10)

            if self.command_executor.ExecuteInBash('[ -b {0} ]'.format(self.rootfs_block_device), False) == 0:
                break

            attempt += 1

        sleep(3)

        # Verify/repair the (XFS) root filesystem before it is encrypted.
        self.command_executor.Execute('xfs_repair {0}'.format(self.rootfs_block_device), True)

    def should_exit(self):
        self.context.logger.log("Verifying if machine should exit unmount_oldroot state")

        if os.path.exists('/oldroot/bin'):
            self.context.logger.log("/oldroot was not unmounted")
            return False

        return super(UnmountOldrootState, self).should_exit()


================================================
FILE: VMEncryption/main/oscrypto/rhel_72/encryptstates/__init__.py
================================================
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this
file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Requires Python 2.7+ # import inspect import os import sys import traceback from time import sleep scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) oscryptodir = os.path.abspath(os.path.join(scriptdir, '../../')) sys.path.append(oscryptodir) from OSEncryptionState import * from PrereqState import * from SelinuxState import * from StripdownState import * from UnmountOldrootState import * from EncryptBlockDeviceState import * from PatchBootSystemState import * ================================================ FILE: VMEncryption/main/oscrypto/rhel_72_lvm/RHEL72LVMEncryptionStateMachine.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class RHEL72LVMEncryptionStateMachine(OSEncryptionStateMachine):
    """Drives in-place OS-disk encryption for RHEL/CentOS 7.x machines whose
    OS disk is on LVM (rootvg).

    The machine progresses linearly:

        uninitialized -> prereq -> selinux -> stripdown -> unmount_oldroot
                      -> encrypt_block_device -> patch_boot_system -> completed

    with a self-loop on ``unmount_oldroot`` for retries and a direct
    ``uninitialized -> completed`` shortcut when the disk is already
    encrypted.  Each named state is implemented by the state object
    registered in ``self.state_objs``.
    """

    # States known to the `transitions` Machine; states with the
    # on_enter hook delegate entry work to on_enter_state below.
    states = [
        State(name='uninitialized'),
        State(name='prereq', on_enter='on_enter_state'),
        State(name='selinux', on_enter='on_enter_state'),
        State(name='stripdown', on_enter='on_enter_state'),
        State(name='unmount_oldroot', on_enter='on_enter_state'),
        State(name='encrypt_block_device', on_enter='on_enter_state'),
        State(name='patch_boot_system', on_enter='on_enter_state'),
        State(name='completed'),
    ]

    # Allowed transitions; 'conditions': 'should_exit_previous_state' gates
    # each forward move on the source state reporting it is safe to leave.
    transitions = [
        {
            'trigger': 'skip_encryption',
            'source': 'uninitialized',
            'dest': 'completed'
        },
        {
            'trigger': 'enter_prereq',
            'source': 'uninitialized',
            'dest': 'prereq'
        },
        {
            'trigger': 'enter_selinux',
            'source': 'prereq',
            'dest': 'selinux',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_stripdown',
            'source': 'selinux',
            'dest': 'stripdown',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_unmount_oldroot',
            'source': 'stripdown',
            'dest': 'unmount_oldroot',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            # Retry self-loop: deliberately has no exit condition so a failed
            # unmount attempt can be repeated.
            'trigger': 'retry_unmount_oldroot',
            'source': 'unmount_oldroot',
            'dest': 'unmount_oldroot',
            'before': 'on_enter_state'
        },
        {
            'trigger': 'enter_encrypt_block_device',
            'source': 'unmount_oldroot',
            'dest': 'encrypt_block_device',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_patch_boot_system',
            'source': 'encrypt_block_device',
            'dest': 'patch_boot_system',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'stop_machine',
            'source': 'patch_boot_system',
            'dest': 'completed',
            'conditions': 'should_exit_previous_state'
        },
    ]

    def on_enter_state(self):
        """Delegate state entry to the base class, which runs the registered
        state object's enter() for the current state."""
        super(RHEL72LVMEncryptionStateMachine, self).on_enter_state()

    def should_exit_previous_state(self):
        """Ask the current (source) state object whether leaving it is safe."""
        # when this is called, self.state is still the "source" state in the transition
        return super(RHEL72LVMEncryptionStateMachine, self).should_exit_previous_state()

    def __init__(self, hutil, distro_patcher, logger, encryption_environment):
        """Build the state objects and wire them into a `transitions` Machine
        starting at 'uninitialized'."""
        super(RHEL72LVMEncryptionStateMachine, self).__init__(hutil, distro_patcher, logger, encryption_environment)

        # One state object per non-terminal state; the shared context carries
        # hutil/logger/executors into each of them.
        self.state_objs = {
            'prereq': PrereqState(self.context),
            'selinux': SelinuxState(self.context),
            'stripdown': StripdownState(self.context),
            'unmount_oldroot': UnmountOldrootState(self.context),
            'encrypt_block_device': EncryptBlockDeviceState(self.context),
            'patch_boot_system': PatchBootSystemState(self.context),
        }

        self.state_machine = Machine(model=self,
                                     states=RHEL72LVMEncryptionStateMachine.states,
                                     transitions=RHEL72LVMEncryptionStateMachine.transitions,
                                     initial='uninitialized')

    def start_encryption(self):
        """Run the full encryption sequence, rebooting the machine at the end.

        If pvdisplay already reports /dev/mapper/osencrypt as a physical
        volume AND the PatchBootSystemState marker file exists, encryption is
        considered complete: the machine jumps straight to 'completed' and no
        reboot is issued.
        """
        proc_comm = ProcessCommunicator()
        self.command_executor.Execute(command_to_execute="pvdisplay",
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        patch_boot_system_state_marker = os.path.join(self.encryption_environment.os_encryption_markers_path,
                                                      'PatchBootSystemState')

        if '/dev/mapper/osencrypt' in proc_comm.stdout and os.path.exists(patch_boot_system_state_marker):
            self.logger.log("OS volume is already encrypted")
            self.skip_encryption()
            self.log_machine_state()
            return

        self.log_machine_state()
        self.enter_prereq()
        self.log_machine_state()
        self.enter_selinux()
        self.log_machine_state()
        self.enter_stripdown()
        self.log_machine_state()

        # /oldroot can be held busy by lingering processes; retry the unmount
        # up to 10 times, reporting an error status between failed attempts.
        oldroot_unmounted_successfully = False
        attempt = 1

        while not oldroot_unmounted_successfully:
            self.logger.log("Attempt #{0} to unmount /oldroot".format(attempt))
            try:
                if attempt == 1:
                    self.enter_unmount_oldroot()
                elif attempt > 10:
                    raise Exception("Could not unmount /oldroot in 10 attempts")
                else:
                    self.retry_unmount_oldroot()
                self.log_machine_state()
            except Exception as e:
                message = "Attempt #{0} to unmount /oldroot failed with error: {1}, stack trace: {2}".format(attempt,
                                                                                                             e,
                                                                                                             traceback.format_exc())
                self.logger.log(msg=message)
                self.hutil.do_status_report(operation='EnableEncryptionOSVolume',
                                            status=CommonVariables.extension_error_status,
                                            status_code=str(CommonVariables.unmount_oldroot_error),
                                            message=message)
                sleep(10)
                if attempt > 10:
                    raise Exception(message)
            else:
                # try body succeeded: leave the retry loop.
                oldroot_unmounted_successfully = True
            finally:
                attempt += 1

        self.enter_encrypt_block_device()
        self.log_machine_state()

        self.enter_patch_boot_system()
        self.log_machine_state()

        self.stop_machine()
        self.log_machine_state()

        self._reboot()
class EncryptBlockDeviceState(OSEncryptionState):
    """State that LUKS-formats the root block device (with a detached header
    stored on /boot) and block-copies the plaintext rootfs contents into the
    opened /dev/mapper/osencrypt device."""

    def __init__(self, context):
        super(EncryptBlockDeviceState, self).__init__('EncryptBlockDeviceState', context)

    def should_enter(self):
        """Return True when the generic state-machine checks allow entering
        the encrypt_block_device state."""
        self.context.logger.log("Verifying if machine should enter encrypt_block_device state")

        if not super(EncryptBlockDeviceState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for encrypt_block_device state")

        return True

    def enter(self):
        """Format the root device as LUKS, open it as 'osencrypt', report
        progress, then dd the plaintext contents onto the mapper device."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering encrypt_block_device state")

        # /boot hosts the detached LUKS header; mount failure is tolerated
        # (second argument False) since it may already be mounted.
        self.command_executor.Execute('mount /boot', False)

        # Debug helper, intentionally left disabled: would write the
        # passphrase bytes into the log.
        # self._find_bek_and_execute_action('_dump_passphrase')
        self._find_bek_and_execute_action('_luks_format')
        self._find_bek_and_execute_action('_luks_open')

        self.context.hutil.do_status_report(operation='EnableEncryptionDataVolumes',
                                            status=CommonVariables.extension_success_status,
                                            status_code=str(CommonVariables.success),
                                            message='OS disk encryption started')

        # Copy the old plaintext filesystem through the crypt mapping;
        # conv=sparse skips writing all-zero blocks.
        self.command_executor.Execute('dd if={0} of=/dev/mapper/osencrypt conv=sparse bs=64K'.format(self.rootfs_block_device),
                                      True)

    def should_exit(self):
        """Re-open the LUKS device if the mapper node is missing, then defer
        to the base class exit check."""
        self.context.logger.log("Verifying if machine should exit encrypt_block_device state")

        if not os.path.exists('/dev/mapper/osencrypt'):
            self._find_bek_and_execute_action('_luks_open')

        return super(EncryptBlockDeviceState, self).should_exit()

    def _luks_format(self, bek_path):
        """luksFormat the root device using bek_path as key file and a
        32 MiB (33554432-byte) detached header at /boot/luks/osluksheader."""
        self.command_executor.Execute('mkdir /boot/luks', True)
        self.command_executor.Execute('dd if=/dev/zero of=/boot/luks/osluksheader bs=33554432 count=1', True)
        self.command_executor.Execute('cryptsetup luksFormat --header /boot/luks/osluksheader -d {0} {1} -q'.format(bek_path,
                                                                                                                   self.rootfs_block_device),
                                      raise_exception_on_failure=True)

    def _luks_open(self, bek_path):
        """Open the formatted root device as /dev/mapper/osencrypt."""
        self.command_executor.Execute('cryptsetup luksOpen --header /boot/luks/osluksheader {0} osencrypt -d {1}'.format(self.rootfs_block_device,
                                                                                                                        bek_path),
                                      raise_exception_on_failure=True)

    def _dump_passphrase(self, bek_path):
        """Debug-only: od-dump the passphrase file into the log."""
        proc_comm = ProcessCommunicator()

        self.command_executor.Execute(command_to_execute="od -c {0}".format(bek_path),
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        self.context.logger.log("Passphrase:")
        self.context.logger.log(proc_comm.stdout)

    def _find_bek_and_execute_action(self, callback_method_name):
        """Resolve the BEK passphrase file and invoke the named method with
        its path; raises if the name does not resolve to a method."""
        callback_method = getattr(self, callback_method_name)
        if not ismethod(callback_method):
            raise Exception("{0} is not a method".format(callback_method_name))

        bek_path = self.bek_util.get_bek_passphrase_file(self.encryption_config)
        callback_method(bek_path)
class PatchBootSystemState(OSEncryptionState):
    """State that restores the LVM layout on top of the encrypted device,
    pivots into the freshly mounted root, patches grub/dracut so the machine
    can boot from /dev/mapper/osencrypt, and pivots back out.

    Fix over previous revision: unmount() now kills PIDs in descending
    *numeric* order; they were previously sorted as strings, which ordered
    e.g. "9" after "10".
    """

    def __init__(self, context):
        super(PatchBootSystemState, self).__init__('PatchBootSystemState', context)

    def should_enter(self):
        """Enter only when the generic checks pass and the encrypted mapper
        device /dev/mapper/osencrypt exists."""
        self.context.logger.log("Verifying if machine should enter patch_boot_system state")

        if not super(PatchBootSystemState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for patch_boot_system state")

        if not os.path.exists('/dev/mapper/osencrypt'):
            return False

        return True

    def enter(self):
        """Recreate rootvg from the backed-up metadata, mount its logical
        volumes under /oldroot, pivot into it, patch the boot system, then
        pivot back (also on failure, re-raising the original error)."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering patch_boot_system state")

        # Recreate the volume group on the encrypted device from the metadata
        # saved to /volumes.lvm by the unmount_oldroot state.
        self.command_executor.Execute('systemctl restart lvm2-lvmetad', True)
        self.command_executor.Execute('pvscan', True)
        self.command_executor.Execute('vgcfgrestore -f /volumes.lvm rootvg', True)
        self.command_executor.Execute('cryptsetup luksClose osencrypt', True)
        self._find_bek_and_execute_action('_luks_open')
        self.unmount_lvm_volumes()

        # Mount the restored logical volumes where the pivoted root expects them.
        self.command_executor.Execute('mount /dev/rootvg/rootlv /oldroot', True)
        self.command_executor.Execute('mount /dev/rootvg/varlv /oldroot/var', True)
        self.command_executor.Execute('mount /dev/rootvg/usrlv /oldroot/usr', True)
        self.command_executor.Execute('mount /dev/rootvg/tmplv /oldroot/tmp', True)
        self.command_executor.Execute('mount /dev/rootvg/homelv /oldroot/home', True)
        self.command_executor.Execute('mount /dev/rootvg/optlv /oldroot/opt', True)
        self.command_executor.Execute('mount /boot', False)
        self.command_executor.Execute('mount --make-rprivate /', True)
        self.command_executor.Execute('mkdir /oldroot/memroot', True)
        self.command_executor.Execute('pivot_root /oldroot /oldroot/memroot', True)
        self.command_executor.ExecuteInBash('for i in dev proc sys boot; do mount --move /memroot/$i /$i; done', True)
        # Sanity check: the detached LUKS header directory must be visible.
        self.command_executor.ExecuteInBash('[ -e "/boot/luks" ]', True)

        try:
            self._modify_pivoted_oldroot()
        except Exception:
            # Pivot back before propagating, so the machine is left in the
            # in-memory root it started from.
            self.command_executor.Execute('mount --make-rprivate /')
            self.command_executor.Execute('pivot_root /memroot /memroot/oldroot')
            self.command_executor.Execute('rmdir /oldroot/memroot')
            self.command_executor.ExecuteInBash('for i in dev proc sys boot; do mount --move /oldroot/$i /$i; done')

            raise
        else:
            self.command_executor.Execute('mount --make-rprivate /')
            self.command_executor.Execute('pivot_root /memroot /memroot/oldroot')
            self.command_executor.Execute('rmdir /oldroot/memroot')
            self.command_executor.ExecuteInBash('for i in dev proc sys boot; do mount --move /oldroot/$i /$i; done')

            # Preserve extension logs and state markers on the encrypted disk.
            extension_full_name = 'Microsoft.Azure.Security.' + CommonVariables.extension_name
            self.command_executor.Execute('/bin/cp -ax' +
                                          ' /var/log/azure/{0}'.format(extension_full_name) +
                                          ' /oldroot/var/log/azure/{0}.Stripdown'.format(extension_full_name))
            self.command_executor.ExecuteInBash('/bin/cp -ax' +
                                                ' /var/lib/azure_disk_encryption_config/os_encryption_markers/*' +
                                                ' /oldroot/var/lib/azure_disk_encryption_config/os_encryption_markers/', True)
            self.command_executor.Execute('touch /oldroot/var/lib/azure_disk_encryption_config/os_encryption_markers/PatchBootSystemState', True)
            self.command_executor.Execute('umount /boot')
            self.command_executor.Execute('umount /oldroot')
            self.command_executor.Execute('systemctl restart waagent')

            self.context.logger.log("Pivoted back into memroot successfully")

            self.unmount_lvm_volumes()

    def should_exit(self):
        """Defer to the generic exit check."""
        self.context.logger.log("Verifying if machine should exit patch_boot_system state")
        return super(PatchBootSystemState, self).should_exit()

    def unmount_lvm_volumes(self):
        """Unmount every LVM mountpoint, both under /oldroot and at top
        level, finishing with /var (which needs special handling)."""
        self.command_executor.Execute('swapoff -a', True)
        self.command_executor.Execute('umount -a')

        for mountpoint in ['/var', '/opt', '/tmp', '/home', '/usr']:
            if self.command_executor.Execute('mountpoint /oldroot' + mountpoint) == 0:
                self.unmount('/oldroot' + mountpoint)
            if self.command_executor.Execute('mountpoint ' + mountpoint) == 0:
                self.unmount(mountpoint)

        self.unmount_var()

    def unmount_var(self):
        """Stop services that commonly hold /var open, then unmount it;
        loop until 'mountpoint /var' reports it is no longer mounted."""
        unmounted = False

        while not unmounted:
            self.command_executor.Execute('systemctl stop NetworkManager')
            self.command_executor.Execute('systemctl stop rsyslog')
            self.command_executor.Execute('systemctl stop systemd-udevd')
            self.command_executor.Execute('systemctl stop systemd-journald')
            self.command_executor.Execute('systemctl stop systemd-hostnamed')
            self.command_executor.Execute('systemctl stop atd')
            self.command_executor.Execute('systemctl stop postfix')

            self.unmount('/var')

            sleep(3)

            # Execute returns non-zero once /var is no longer a mountpoint.
            if self.command_executor.Execute('mountpoint /var'):
                unmounted = True

    def unmount(self, mountpoint):
        """Force-unmount `mountpoint`, killing the processes fuser reports
        as holding it open (skipping init; scheduling a WALA restart via atd
        before this process kills itself)."""
        if mountpoint != '/var':
            self.unmount_var()

        # Non-zero means `mountpoint` is not mounted: nothing to do.
        if self.command_executor.Execute("mountpoint " + mountpoint):
            return

        proc_comm = ProcessCommunicator()
        self.command_executor.Execute(command_to_execute="fuser -vm " + mountpoint,
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        self.context.logger.log("Processes using {0}:\n{1}".format(mountpoint, proc_comm.stdout))

        # fuser output mixes PIDs with other tokens; keep only the numeric
        # tokens and kill in descending numeric PID order.  (Previously the
        # PIDs were sorted as strings, ordering e.g. "9" after "10".)
        procs_to_kill = sorted((int(p) for p in proc_comm.stdout.split() if p.isdigit()),
                               reverse=True)

        for victim in procs_to_kill:
            if victim == os.getpid():
                self.context.logger.log("Restarting WALA before committing suicide")
                self.context.logger.log("Current executable path: " + sys.executable)
                self.context.logger.log("Current executable arguments: " + " ".join(sys.argv))

                # Kill any other daemons that are blocked and would be executed after this process commits
                # suicide
                self.command_executor.Execute('systemctl restart atd')

                os.chdir('/')
                with open("/delete-lock.sh", "w") as f:
                    f.write("rm -f /var/lib/azure_disk_encryption_config/daemon_lock_file.lck\n")

                self.command_executor.Execute('at -f /delete-lock.sh now + 1 minutes', True)
                self.command_executor.Execute('at -f /restart-wala.sh now + 2 minutes', True)
                self.command_executor.ExecuteInBash('pkill -f .*ForLinux.*handle.py.*daemon.*', True)

            if victim == 1:
                self.context.logger.log("Skipping init")
                continue

            self.command_executor.Execute('kill -9 {0}'.format(victim))

        self.command_executor.Execute('telinit u', True)

        sleep(3)

        self.command_executor.Execute('umount ' + mountpoint, True)

    def _append_contents_to_file(self, contents, path):
        """Append `contents` to the file at `path`."""
        with open(path, 'a') as f:
            f.write(contents)

    def _modify_pivoted_oldroot(self):
        """Inside the pivoted root: install the 91ade dracut module, point
        its udev rule at the right partition, regenerate the initramfs, and
        reinstall grub on the root disk."""
        self.context.logger.log("Pivoted into oldroot successfully")

        scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
        ademoduledir = os.path.join(scriptdir, '../../91ade')
        dracutmodulesdir = '/lib/dracut/modules.d'
        udevaderulepath = os.path.join(dracutmodulesdir, '91ade/50-udev-ade.rules')

        proc_comm = ProcessCommunicator()

        self.command_executor.Execute('cp -r {0} /lib/dracut/modules.d/'.format(ademoduledir), True)

        # Discover the partition number of the root device to substitute into
        # the udev rule template.
        udevadm_cmd = "udevadm info --attribute-walk --name={0}".format(self.rootfs_block_device)
        self.command_executor.Execute(command_to_execute=udevadm_cmd,
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        matches = re.findall(r'ATTR{partition}=="(.*)"', proc_comm.stdout)
        if not matches:
            raise Exception("Could not parse ATTR{partition} from udevadm info")
        partition = matches[0]

        sed_cmd = 'sed -i.bak s/ENCRYPTED_DISK_PARTITION/{0}/ "{1}"'.format(partition, udevaderulepath)
        self.command_executor.Execute(command_to_execute=sed_cmd, raise_exception_on_failure=True)

        self._append_contents_to_file('\nGRUB_CMDLINE_LINUX+=" rd.debug"\n',
                                      '/etc/default/grub')
        self._append_contents_to_file('\nadd_drivers+=" fuse vfat nls_cp437 nls_iso8859-1"\n',
                                      '/etc/dracut.conf')
        self._append_contents_to_file('\nadd_dracutmodules+=" crypt"\n',
                                      '/etc/dracut.conf')

        self.command_executor.ExecuteInBash("/usr/sbin/dracut -f -v --kver `grubby --default-kernel | sed 's|/boot/vmlinuz-||g'`", True)
        self.command_executor.Execute('grub2-install --recheck --force {0}'.format(self.rootfs_disk), True)
        self.command_executor.Execute('grub2-mkconfig -o /boot/grub2/grub.cfg', True)

    def _luks_open(self, bek_path):
        """Open the LUKS device (detached header on /boot) as 'osencrypt'."""
        self.command_executor.Execute('mount /boot')
        self.command_executor.Execute('cryptsetup luksOpen --header /boot/luks/osluksheader {0} osencrypt -d {1}'.format(self.rootfs_block_device,
                                                                                                                        bek_path),
                                      raise_exception_on_failure=True)

    def _find_bek_and_execute_action(self, callback_method_name):
        """Resolve the BEK passphrase file and pass its path to the named
        method; raises if the name does not resolve to a method."""
        callback_method = getattr(self, callback_method_name)
        if not ismethod(callback_method):
            raise Exception("{0} is not a method".format(callback_method_name))

        bek_path = self.bek_util.get_bek_passphrase_file(self.encryption_config)
        callback_method(bek_path)
class PrereqState(OSEncryptionState):
    """First state of the RHEL 7.2+ LVM OS encryption sequence: verifies the
    distro is supported, installs required packages, and patches the waagent
    systemd unit."""

    def __init__(self, context):
        super(PrereqState, self).__init__('PrereqState', context)

    def should_enter(self):
        """Gate entry on the generic state-machine checks."""
        self.context.logger.log("Verifying if machine should enter prereq state")

        if not super(PrereqState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for prereq state")
        return True

    def enter(self):
        """Validate the distro/LVM prerequisites, install extra packages, and
        patch waagent; raises when the machine is not a supported target."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering prereq state")

        distro_info = self.context.distro_patcher.distro_info
        self.context.logger.log("Distro info: {0}".format(distro_info))

        distro_name = distro_info[0]
        distro_version = distro_info[1]

        # Supported targets: CentOS 7.3.1611 and 7.4-7.7, RHEL 7.3-7.7.
        if distro_name == 'centos':
            distro_ok = (distro_version == '7.3.1611' or
                         distro_version.startswith(('7.4', '7.5', '7.6', '7.7')))
        elif distro_name == 'redhat':
            distro_ok = distro_version in ('7.3', '7.4', '7.5', '7.6', '7.7')
        else:
            distro_ok = False

        # The OS disk must also be on LVM for this state machine to apply.
        if distro_ok and self.disk_util.is_os_disk_lvm():
            self.context.logger.log("Enabling OS volume encryption on {0} {1}".format(distro_info[0],
                                                                                     distro_info[1]))
        else:
            raise Exception("RHEL72LVMEncryptionStateMachine called for distro {0} {1}".format(distro_info[0],
                                                                                              distro_info[1]))

        self.context.distro_patcher.install_extras()

        self._patch_waagent()
        self.command_executor.Execute('systemctl daemon-reload', True)

    def should_exit(self):
        """Defer to the generic exit check."""
        self.context.logger.log("Verifying if machine should exit prereq state")
        return super(PrereqState, self).should_exit()

    def _patch_waagent(self):
        """Rewrite the waagent systemd unit, inserting ``KillMode=process``
        immediately after the ``[Service]`` section header."""
        self.context.logger.log("Patching waagent")

        unit_path = '/usr/lib/systemd/system/waagent.service'

        with open(unit_path, 'r') as unit_file:
            unit_text = unit_file.read()

        unit_text = re.sub(r'\[Service\]\n', '[Service]\nKillMode=process\n', unit_text)

        with open(unit_path, 'w') as unit_file:
            unit_file.write(unit_text)

        self.context.logger.log("waagent patched successfully")
class SelinuxState(OSEncryptionState):
    """State that turns off SELinux enforcement (when active) ahead of the
    stripdown/pivot steps."""

    def __init__(self, context):
        super(SelinuxState, self).__init__('SelinuxState', context)

    def should_enter(self):
        """Gate entry on the generic state-machine checks."""
        self.context.logger.log("Verifying if machine should enter selinux state")

        if not super(SelinuxState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for selinux state")
        return True

    def enter(self):
        """Disable SELinux when it is currently enforcing; no-op otherwise."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering selinux state")

        enforcement = self.context.encryption_environment.get_se_linux()
        if enforcement.lower() != 'enforcing':
            return

        self.context.logger.log("SELinux is in enforcing mode, disabling")
        self.context.encryption_environment.disable_se_linux()

    def should_exit(self):
        """Defer to the generic exit check."""
        self.context.logger.log("Verifying if machine should exit selinux state")
        return super(SelinuxState, self).should_exit()
class StripdownState(OSEncryptionState):
    """State that builds a minimal in-memory copy of the OS in a tmpfs at
    /usr/tmproot and pivot_roots into it, so the real root filesystem can
    later be unmounted and encrypted.

    This state deliberately spans two process lifetimes: the first call to
    should_exit() drops a marker, schedules a waagent restart via atd, and
    exits the process; the restarted process sees the marker and proceeds.
    """

    def __init__(self, context):
        super(StripdownState, self).__init__('StripdownState', context)

    def should_enter(self):
        """Allow entry when generic checks pass; also clears any stale
        /usr/tmproot and asserts /oldroot does not exist yet."""
        self.context.logger.log("Verifying if machine should enter stripdown state")

        if not super(StripdownState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for stripdown state")

        self.command_executor.Execute('rm -rf /usr/tmproot', True)
        self.command_executor.ExecuteInBash('! [ -e "/oldroot" ]', True)

        return True

    def enter(self):
        """Populate a tmpfs with a minimal OS image and pivot_root into it,
        leaving the original root reachable at /oldroot."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering stripdown state")

        self.command_executor.Execute('swapoff -a')
        self.command_executor.Execute('umount -a')

        # Build the tmpfs skeleton and copy in the directories needed for a
        # functional minimal system (cp -ax preserves attrs, stays on one fs).
        self.command_executor.Execute('mkdir /usr/tmproot', True)
        self.command_executor.Execute('mount -t tmpfs none /usr/tmproot', True)
        self.command_executor.ExecuteInBash('for i in proc sys dev run usr var tmp root oldroot boot; do mkdir /usr/tmproot/$i; done', True)
        self.command_executor.ExecuteInBash('for i in bin etc mnt sbin lib lib64 root; do cp -ax /$i /usr/tmproot/; done', True)
        self.command_executor.ExecuteInBash('for i in bin sbin libexec lib lib64 share; do cp -ax /usr/$i /usr/tmproot/usr/; done', True)
        self.command_executor.ExecuteInBash('for i in lib local lock opt run spool tmp; do cp -ax /var/$i /usr/tmproot/var/; done', True)
        self.command_executor.ExecuteInBash('mkdir /usr/tmproot/var/log', True)
        self.command_executor.ExecuteInBash('cp -ax /var/log/azure /usr/tmproot/var/log/', True)
        self.command_executor.Execute('mount --make-rprivate /', True)
        # Sanity check: the encryption request queue must have made it into
        # the tmpfs copy before we pivot.
        self.command_executor.ExecuteInBash('[ -e "/usr/tmproot/var/lib/azure_disk_encryption_config/azure_crypt_request_queue.ini" ]', True)
        self.command_executor.Execute('systemctl stop waagent', True)
        self.command_executor.Execute('pivot_root /usr/tmproot /usr/tmproot/oldroot', True)
        self.command_executor.ExecuteInBash('for i in dev proc sys run; do mount --move /oldroot/$i /$i; done', True)

    def should_exit(self):
        """Two-phase exit: on the first call (no marker yet) prepare the
        stripped-down environment and terminate this process; the restarted
        process finds the marker and continues the state machine."""
        self.context.logger.log("Verifying if machine should exit stripdown state")

        if not os.path.exists(self.state_marker):
            self.context.logger.log("First call to stripdown state (pid={0}), restarting process".format(os.getpid()))

            # create the marker, but do not advance the state machine
            super(StripdownState, self).should_exit()

            # Drop generated mount units and rootvg fstab entries so systemd
            # does not remount the old volumes, then re-exec init.
            self.command_executor.ExecuteInBash('rm -f /run/systemd/generator/*.mount', True)
            self.command_executor.ExecuteInBash('rm -f /run/systemd/generator/local-fs.target.requires/*.mount', True)
            self.command_executor.Execute("sed -i.bak '/rootvg/d' /etc/fstab", True)
            self.command_executor.Execute('telinit u', True)

            sleep(10)

            if self.command_executor.Execute('mountpoint /var') == 0:
                self.command_executor.Execute('umount /var', True)

            # the restarted process shall see the marker and advance the state machine
            self.command_executor.Execute('systemctl restart atd', True)

            os.chdir('/')
            with open("/restart-wala.sh", "w") as f:
                f.write("systemctl restart waagent\n")

            # Schedule the waagent restart out-of-band, then exit cleanly.
            self.command_executor.Execute('at -f /restart-wala.sh now + 1 minutes', True)

            self.context.hutil.do_exit(exit_code=0,
                                       operation='EnableEncryptionOSVolume',
                                       status=CommonVariables.extension_success_status,
                                       code=str(CommonVariables.success),
                                       message="Restarted extension from stripped down OS")
        else:
            self.context.logger.log("Second call to stripdown state (pid={0}), continuing process".format(os.getpid()))

        return True
class UnmountOldrootState(OSEncryptionState):
    """State that releases /oldroot (the original root filesystem) from the
    stripped-down in-memory OS: stops/kills whatever holds it open, unmounts
    it, then backs up and removes the rootvg LVM layout so the underlying
    block device can be encrypted.

    Fix over previous revision: unmount() now kills PIDs in descending
    *numeric* order; they were previously sorted as strings, which ordered
    e.g. "9" after "10".
    """

    def __init__(self, context):
        super(UnmountOldrootState, self).__init__('UnmountOldrootState', context)

    def should_enter(self):
        """Enter only when generic checks pass, /oldroot exists, and it is
        currently a mountpoint."""
        self.context.logger.log("Verifying if machine should enter unmount_oldroot state")

        if not super(UnmountOldrootState, self).should_enter():
            return False

        self.context.logger.log("Performing enter checks for unmount_oldroot state")

        self.command_executor.ExecuteInBash('[ -e "/oldroot" ]', True)

        if self.command_executor.Execute('mountpoint /oldroot') != 0:
            return False

        return True

    def enter(self):
        """Restart running services off the old root, unmount everything under
        /oldroot, wait for the root block device to reappear, then back up and
        tear down the rootvg LVM metadata."""
        if not self.should_enter():
            return

        self.context.logger.log("Entering unmount_oldroot state")

        self.unmount_var()

        # Keep ssh/networking usable from the in-memory root.
        self.command_executor.ExecuteInBash('mkdir -p /var/empty/sshd', True)
        self.command_executor.ExecuteInBash('systemctl restart sshd.service')
        self.command_executor.ExecuteInBash('dhclient')

        proc_comm = ProcessCommunicator()
        self.command_executor.Execute(command_to_execute="systemctl list-units",
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        # Restart every running service (except waagent/sshd/journald) so
        # they reopen their files from the tmpfs copy instead of /oldroot.
        for line in proc_comm.stdout.split('\n'):
            if not "running" in line:
                continue

            if "waagent.service" in line or "sshd.service" in line or "journald.service" in line:
                continue

            match = re.search(r'\s(\S*?\.service)', line)
            if match:
                service = match.groups()[0]
                self.command_executor.Execute('systemctl restart {0}'.format(service))

        self.command_executor.Execute('swapoff -a', True)

        if os.path.exists("/oldroot/mnt/resource"):
            self.command_executor.Execute('umount /oldroot/mnt/resource')

        sleep(3)

        self.unmount('/oldroot/opt')
        self.unmount('/oldroot/var')
        self.unmount('/oldroot/usr')
        self.unmount('/oldroot')

        # After unmounting, the raw root block device may take a few udev
        # cycles to reappear; retry up to 10 times.
        attempt = 1

        while True:
            if attempt > 10:
                raise Exception("Block device {0} did not appear in 10 restart attempts".format(self.rootfs_block_device))

            self.context.logger.log("Attempt #{0} for restarting systemd-udevd".format(attempt))
            self.command_executor.Execute('systemctl restart systemd-udevd')

            sleep(10)

            if self.command_executor.ExecuteInBash('[ -b {0} ]'.format(self.rootfs_block_device), False) == 0:
                break

            attempt += 1

        self.unmount_var()

        sleep(3)

        # Save the LVM metadata (rewritten to point at the future encrypted
        # device) and remove the volume group so the device can be encrypted.
        self.command_executor.Execute('vgcfgbackup -f /volumes.lvm rootvg', True)
        self.command_executor.Execute('sed -i.bak \'s/sda2/mapper\/osencrypt/g\' /volumes.lvm', True)
        self.command_executor.Execute('lvremove -f rootvg', True)
        self.command_executor.Execute('vgremove rootvg', True)

    def unmount_var(self):
        """Stop services that commonly hold /var open, then unmount it;
        loop until 'mountpoint /var' reports it is no longer mounted."""
        unmounted = False

        while not unmounted:
            self.command_executor.Execute('systemctl stop NetworkManager')
            self.command_executor.Execute('systemctl stop rsyslog')
            self.command_executor.Execute('systemctl stop systemd-udevd')
            self.command_executor.Execute('systemctl stop systemd-journald')
            self.command_executor.Execute('systemctl stop systemd-hostnamed')
            self.command_executor.Execute('systemctl stop atd')
            self.command_executor.Execute('systemctl stop postfix')

            self.unmount('/var')

            sleep(3)

            # Execute returns non-zero once /var is no longer a mountpoint.
            if self.command_executor.Execute('mountpoint /var'):
                unmounted = True

    def unmount(self, mountpoint, call_unmount_var=True):
        """Force-unmount `mountpoint`: kill holders reported by fuser, re-exec
        init, stop systemd mount units, and finally umount.

        call_unmount_var is accepted for backward compatibility with existing
        callers but is not consulted; /var handling is keyed off `mountpoint`.
        """
        if mountpoint != '/var':
            self.unmount_var()

        # Non-zero means `mountpoint` is not mounted: nothing to do.
        if self.command_executor.Execute("mountpoint " + mountpoint):
            return

        proc_comm = ProcessCommunicator()
        self.command_executor.Execute(command_to_execute="fuser -vm " + mountpoint,
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        self.context.logger.log("Processes using {0}:\n{1}".format(mountpoint, proc_comm.stdout))

        # fuser output mixes PIDs with other tokens; keep only the numeric
        # tokens and kill in descending numeric PID order.  (Previously the
        # PIDs were sorted as strings, ordering e.g. "9" after "10".)
        procs_to_kill = sorted((int(p) for p in proc_comm.stdout.split() if p.isdigit()),
                               reverse=True)

        for victim in procs_to_kill:
            if victim == os.getpid():
                self.context.logger.log("Restarting WALA before committing suicide")
                self.context.logger.log("Current executable path: " + sys.executable)
                self.context.logger.log("Current executable arguments: " + " ".join(sys.argv))

                # Kill any other daemons that are blocked and would be executed after this process commits
                # suicide
                self.command_executor.Execute('systemctl restart atd')

                os.chdir('/')
                with open("/delete-lock.sh", "w") as f:
                    f.write("rm -f /var/lib/azure_disk_encryption_config/daemon_lock_file.lck\n")

                self.command_executor.Execute('at -f /delete-lock.sh now + 1 minutes', True)
                self.command_executor.Execute('at -f /restart-wala.sh now + 2 minutes', True)
                self.command_executor.ExecuteInBash('pkill -f .*ForLinux.*handle.py.*daemon.*', True)

            if victim == 1:
                self.context.logger.log("Skipping init")
                continue

            self.command_executor.Execute('kill -9 {0}'.format(victim))

        self.command_executor.Execute('telinit u', True)

        sleep(10)

        # Stop any systemd .mount units that could re-mount what we just freed.
        proc_comm = ProcessCommunicator()
        self.command_executor.Execute(command_to_execute="systemctl list-units",
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        for line in proc_comm.stdout.split('\n'):
            match = re.search(r'\s(\S*?\.mount)', line)
            if match:
                mount = match.groups()[0]
                self.command_executor.Execute('systemctl stop {0}'.format(mount))
                continue

        sleep(10)

        if self.command_executor.Execute('mountpoint /var') == 0:
            self.command_executor.Execute('umount /var', True)

        sleep(3)

        if self.command_executor.Execute('mountpoint ' + mountpoint) == 0:
            self.command_executor.Execute('umount ' + mountpoint, True)

    def should_exit(self):
        """Exit only once /oldroot is actually gone (its /bin no longer
        visible), then defer to the generic exit check."""
        self.context.logger.log("Verifying if machine should exit unmount_oldroot state")

        if os.path.exists('/oldroot/bin'):
            self.context.logger.log("/oldroot was not unmounted")
            return False

        return super(UnmountOldrootState, self).should_exit()
VMEncryption/main/oscrypto/rhel_72_lvm/encryptstates/__init__.py ================================================ # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Requires Python 2.7+ # import inspect import os import sys import traceback from time import sleep scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) oscryptodir = os.path.abspath(os.path.join(scriptdir, '../../')) sys.path.append(oscryptodir) from OSEncryptionState import * from PrereqState import * from SelinuxState import * from StripdownState import * from UnmountOldrootState import * from EncryptBlockDeviceState import * from PatchBootSystemState import * ================================================ FILE: VMEncryption/main/oscrypto/ubuntu_1404/Ubuntu1404EncryptionStateMachine.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#

import inspect
import os
import sys
import traceback

from time import sleep

# Make the oscrypto root and the bundled `transitions` library importable.
scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
maindir = os.path.abspath(os.path.join(scriptdir, '../../'))
sys.path.append(maindir)
transitionsdir = os.path.abspath(os.path.join(scriptdir, '../../transitions'))
sys.path.append(transitionsdir)

from oscrypto import *
from encryptstates import *
from Common import *
from CommandExecutor import *
from DiskUtil import *
from transitions import *


class Ubuntu1404EncryptionStateMachine(OSEncryptionStateMachine):
    """State machine driving OS-disk encryption on Ubuntu 14.04.

    Walks linearly through prereq -> stripdown -> unmount_oldroot ->
    split_root_partition -> encrypt_block_device -> patch_boot_system,
    with a self-loop on unmount_oldroot for retries, and reboots when done.
    """

    # Linear sequence of states; 'uninitialized' and 'completed' have no
    # on_enter hook.
    states = [
        State(name='uninitialized'),
        State(name='prereq', on_enter='on_enter_state'),
        State(name='stripdown', on_enter='on_enter_state'),
        State(name='unmount_oldroot', on_enter='on_enter_state'),
        State(name='split_root_partition', on_enter='on_enter_state'),
        State(name='encrypt_block_device', on_enter='on_enter_state'),
        State(name='patch_boot_system', on_enter='on_enter_state'),
        State(name='completed'),
    ]

    # Each forward transition is guarded by should_exit_previous_state;
    # retry_unmount_oldroot loops unmount_oldroot back onto itself.
    transitions = [
        {
            'trigger': 'skip_encryption',
            'source': 'uninitialized',
            'dest': 'completed'
        },
        {
            'trigger': 'enter_prereq',
            'source': 'uninitialized',
            'dest': 'prereq'
        },
        {
            'trigger': 'enter_stripdown',
            'source': 'prereq',
            'dest': 'stripdown',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_unmount_oldroot',
            'source': 'stripdown',
            'dest': 'unmount_oldroot',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'retry_unmount_oldroot',
            'source': 'unmount_oldroot',
            'dest': 'unmount_oldroot',
            'before': 'on_enter_state'
        },
        {
            'trigger': 'enter_split_root_partition',
            'source': 'unmount_oldroot',
            'dest': 'split_root_partition',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_encrypt_block_device',
            'source': 'split_root_partition',
            'dest': 'encrypt_block_device',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_patch_boot_system',
            'source': 'encrypt_block_device',
            'dest': 'patch_boot_system',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'stop_machine',
            'source': 'patch_boot_system',
            'dest': 'completed',
            'conditions': 'should_exit_previous_state'
        },
    ]

    def on_enter_state(self):
        """Delegate enter handling to the base state machine."""
        super(Ubuntu1404EncryptionStateMachine, self).on_enter_state()

    def should_exit_previous_state(self):
        # when this is called, self.state is still the "source" state in the transition
        return super(Ubuntu1404EncryptionStateMachine, self).should_exit_previous_state()

    def __init__(self, hutil, distro_patcher, logger, encryption_environment):
        super(Ubuntu1404EncryptionStateMachine, self).__init__(hutil, distro_patcher, logger, encryption_environment)

        # One state object per non-terminal state, all sharing this context.
        self.state_objs = {
            'prereq': PrereqState(self.context),
            'stripdown': StripdownState(self.context),
            'unmount_oldroot': UnmountOldrootState(self.context),
            'split_root_partition': SplitRootPartitionState(self.context),
            'encrypt_block_device': EncryptBlockDeviceState(self.context),
            'patch_boot_system': PatchBootSystemState(self.context),
        }

        self.state_machine = Machine(model=self,
                                     states=Ubuntu1404EncryptionStateMachine.states,
                                     transitions=Ubuntu1404EncryptionStateMachine.transitions,
                                     initial='uninitialized')

    def start_encryption(self):
        """Run the full encryption sequence, then reboot.

        Skips entirely if the osencrypt mapper device is already mounted.
        The unmount_oldroot step is retried up to 10 times, reporting an
        error status on each failed attempt.
        """
        proc_comm = ProcessCommunicator()
        self.command_executor.Execute(command_to_execute="mount",
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        # Already encrypted -> jump straight to 'completed'.
        if '/dev/mapper/osencrypt' in proc_comm.stdout:
            self.logger.log("OS volume is already encrypted")
            self.skip_encryption()
            self.log_machine_state()
            return

        self.log_machine_state()

        self.enter_prereq()
        self.log_machine_state()

        self.enter_stripdown()
        self.log_machine_state()

        oldroot_unmounted_successfully = False
        attempt = 1

        while not oldroot_unmounted_successfully:
            self.logger.log("Attempt #{0} to unmount /oldroot".format(attempt))

            try:
                if attempt == 1:
                    self.enter_unmount_oldroot()
                elif attempt > 10:
                    raise Exception("Could not unmount /oldroot in 10 attempts")
                else:
                    self.retry_unmount_oldroot()

                self.log_machine_state()
            except Exception as e:
                message = "Attempt #{0} to unmount /oldroot failed with error: {1}, stack trace: {2}".format(attempt, e, traceback.format_exc())
                self.logger.log(msg=message)
                self.hutil.do_status_report(operation='EnableEncryptionOSVolume',
                                            status=CommonVariables.extension_error_status,
                                            status_code=str(CommonVariables.unmount_oldroot_error),
                                            message=message)
                sleep(10)

                if attempt > 10:
                    raise Exception(message)
            else:
                oldroot_unmounted_successfully = True
            finally:
                attempt += 1

        self.enter_split_root_partition()
        self.log_machine_state()

        self.enter_encrypt_block_device()
        self.log_machine_state()

        self.enter_patch_boot_system()
        self.log_machine_state()

        self.stop_machine()
        self.log_machine_state()

        self._reboot()

================================================ FILE: VMEncryption/main/oscrypto/ubuntu_1404/__init__.py ================================================
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from Ubuntu1404EncryptionStateMachine import * ================================================ FILE: VMEncryption/main/oscrypto/ubuntu_1404/encryptpatches/ubuntu_1404_initramfs.patch ================================================ diff -Naur initramfs-tools.orig/hooks/cryptroot initramfs-tools/hooks/cryptroot --- initramfs-tools.orig/hooks/cryptroot 2016-10-27 20:26:44.920064500 +0000 +++ initramfs-tools/hooks/cryptroot 2016-10-27 20:27:15.922161900 +0000 @@ -511,10 +511,7 @@ # Find the root and resume device(s) if [ -r /etc/crypttab ]; then - rootdev=$(get_root_device) - if [ -z "$rootdev" ]; then - echo "cryptsetup: WARNING: could not determine root device from /etc/fstab" >&2 - fi + rootdev="osencrypt" resumedevs=$(get_resume_devices) fi diff -Naur initramfs-tools.orig/scripts/local-top/cryptroot initramfs-tools/scripts/local-top/cryptroot --- initramfs-tools.orig/scripts/local-top/cryptroot 2016-10-27 20:26:44.916064500 +0000 +++ initramfs-tools/scripts/local-top/cryptroot 2016-10-27 20:28:01.621309300 +0000 @@ -229,11 +229,7 @@ if [ "$cryptdiscard" = "yes" ]; then cryptcreate="$cryptcreate --allow-discards" fi - if /sbin/cryptsetup isLuks $cryptsource >/dev/null 2>&1; then - cryptcreate="$cryptcreate luksOpen $cryptsource $crypttarget" - else - cryptcreate="$cryptcreate -c $cryptcipher -s $cryptsize -h $crypthash create $crypttarget $cryptsource" - fi + cryptcreate="$cryptcreate luksOpen $cryptsource $crypttarget --header=/boot/luks/osluksheader" cryptremove="/sbin/cryptsetup remove $crypttarget" NEWROOT="/dev/mapper/$crypttarget" ================================================ FILE: VMEncryption/main/oscrypto/ubuntu_1404/encryptscripts/azure_crypt_key.sh ================================================ #!/bin/sh MountPoint=/tmp-keydisk-mount KeyFileName=LinuxPassPhraseFileName echo "Trying to get the key from disks ..." 
>&2 mkdir -p $MountPoint modprobe nls_utf8 >/dev/null 2>&1 modprobe nls_cp437 >/dev/null 2>&1 modprobe vfat >/dev/null 2>&1 sleep 2 OPENED=0 cd /sys/block for DEV in sd*; do echo "> Trying device: $DEV ..." >&2 mount -t vfat -r /dev/${DEV}1 $MountPoint >&2 2>&1 if [ -f $MountPoint/$KeyFileName ]; then cat $MountPoint/$KeyFileName && echo "Success loading keyfile!" >&2 umount $MountPoint 2>/dev/null OPENED=1 break fi umount $MountPoint 2>/dev/null done if [ $OPENED -eq 0 ]; then echo "FAILED to find suitable passphrase file ..." >&2 echo -n "Try to enter your password: " >&2 read -r A 10: raise Exception("Block device {0} did not appear in 10 restart attempts".format(self.rootfs_block_device)) self.context.logger.log("Restarting udev") self.command_executor.Execute('service udev restart') sleep(10) if self.command_executor.ExecuteInBash('[ -b {0} ]'.format(self.rootfs_block_device), False) == 0: break attempt += 1 self.command_executor.Execute('e2fsck -yf {0}'.format(self.rootfs_block_device), True) def should_exit(self): self.context.logger.log("Verifying if machine should exit unmount_oldroot state") if os.path.exists('/oldroot/bin'): self.context.logger.log("/oldroot was not unmounted") return False return super(UnmountOldrootState, self).should_exit() ================================================ FILE: VMEncryption/main/oscrypto/ubuntu_1404/encryptstates/__init__.py ================================================ # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # # Requires Python 2.7+ # import inspect import os import sys import traceback from time import sleep scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) oscryptodir = os.path.abspath(os.path.join(scriptdir, '../../')) sys.path.append(oscryptodir) from OSEncryptionState import * from PrereqState import * from StripdownState import * from UnmountOldrootState import * from SplitRootPartitionState import * from EncryptBlockDeviceState import * from PatchBootSystemState import * ================================================ FILE: VMEncryption/main/oscrypto/ubuntu_1604/Ubuntu1604EncryptionStateMachine.py ================================================ #!/usr/bin/env python # # VM Backup extension # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#
# Requires Python 2.7+
#

import inspect
import os
import sys
import traceback

from time import sleep

# Make the oscrypto root and the bundled `transitions` library importable.
scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
maindir = os.path.abspath(os.path.join(scriptdir, '../../'))
sys.path.append(maindir)
transitionsdir = os.path.abspath(os.path.join(scriptdir, '../../transitions'))
sys.path.append(transitionsdir)

from oscrypto import *
from encryptstates import *
from Common import *
from CommandExecutor import *
from DiskUtil import *
from transitions import *


class Ubuntu1604EncryptionStateMachine(OSEncryptionStateMachine):
    """State machine driving OS-disk encryption on Ubuntu 16.04.

    Walks linearly through prereq -> stripdown -> unmount_oldroot ->
    split_root_partition -> encrypt_block_device -> patch_boot_system,
    with a self-loop on unmount_oldroot for retries, and reboots when done.
    """

    # Linear sequence of states; 'uninitialized' and 'completed' have no
    # on_enter hook.
    states = [
        State(name='uninitialized'),
        State(name='prereq', on_enter='on_enter_state'),
        State(name='stripdown', on_enter='on_enter_state'),
        State(name='unmount_oldroot', on_enter='on_enter_state'),
        State(name='split_root_partition', on_enter='on_enter_state'),
        State(name='encrypt_block_device', on_enter='on_enter_state'),
        State(name='patch_boot_system', on_enter='on_enter_state'),
        State(name='completed'),
    ]

    # Each forward transition is guarded by should_exit_previous_state;
    # retry_unmount_oldroot loops unmount_oldroot back onto itself.
    transitions = [
        {
            'trigger': 'skip_encryption',
            'source': 'uninitialized',
            'dest': 'completed'
        },
        {
            'trigger': 'enter_prereq',
            'source': 'uninitialized',
            'dest': 'prereq'
        },
        {
            'trigger': 'enter_stripdown',
            'source': 'prereq',
            'dest': 'stripdown',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_unmount_oldroot',
            'source': 'stripdown',
            'dest': 'unmount_oldroot',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'retry_unmount_oldroot',
            'source': 'unmount_oldroot',
            'dest': 'unmount_oldroot',
            'before': 'on_enter_state'
        },
        {
            'trigger': 'enter_split_root_partition',
            'source': 'unmount_oldroot',
            'dest': 'split_root_partition',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_encrypt_block_device',
            'source': 'split_root_partition',
            'dest': 'encrypt_block_device',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'enter_patch_boot_system',
            'source': 'encrypt_block_device',
            'dest': 'patch_boot_system',
            'before': 'on_enter_state',
            'conditions': 'should_exit_previous_state'
        },
        {
            'trigger': 'stop_machine',
            'source': 'patch_boot_system',
            'dest': 'completed',
            'conditions': 'should_exit_previous_state'
        },
    ]

    def on_enter_state(self):
        """Delegate enter handling to the base state machine."""
        super(Ubuntu1604EncryptionStateMachine, self).on_enter_state()

    def should_exit_previous_state(self):
        # when this is called, self.state is still the "source" state in the transition
        return super(Ubuntu1604EncryptionStateMachine, self).should_exit_previous_state()

    def __init__(self, hutil, distro_patcher, logger, encryption_environment):
        super(Ubuntu1604EncryptionStateMachine, self).__init__(hutil, distro_patcher, logger, encryption_environment)

        # One state object per non-terminal state, all sharing this context.
        self.state_objs = {
            'prereq': PrereqState(self.context),
            'stripdown': StripdownState(self.context),
            'unmount_oldroot': UnmountOldrootState(self.context),
            'split_root_partition': SplitRootPartitionState(self.context),
            'encrypt_block_device': EncryptBlockDeviceState(self.context),
            'patch_boot_system': PatchBootSystemState(self.context),
        }

        self.state_machine = Machine(model=self,
                                     states=Ubuntu1604EncryptionStateMachine.states,
                                     transitions=Ubuntu1604EncryptionStateMachine.transitions,
                                     initial='uninitialized')

    def start_encryption(self):
        """Run the full encryption sequence, then reboot.

        Skips entirely if the osencrypt mapper device is already mounted.
        The unmount_oldroot step is retried up to 10 times, reporting an
        error status on each failed attempt.
        """
        proc_comm = ProcessCommunicator()
        self.command_executor.Execute(command_to_execute="mount",
                                      raise_exception_on_failure=True,
                                      communicator=proc_comm)

        # Already encrypted -> jump straight to 'completed'.
        if '/dev/mapper/osencrypt' in proc_comm.stdout:
            self.logger.log("OS volume is already encrypted")
            self.skip_encryption()
            self.log_machine_state()
            return

        self.log_machine_state()

        self.enter_prereq()
        self.log_machine_state()

        self.enter_stripdown()
        self.log_machine_state()

        oldroot_unmounted_successfully = False
        attempt = 1

        while not oldroot_unmounted_successfully:
            self.logger.log("Attempt #{0} to unmount /oldroot".format(attempt))

            try:
                if attempt == 1:
                    self.enter_unmount_oldroot()
                elif attempt > 10:
                    raise Exception("Could not unmount /oldroot in 10 attempts")
                else:
                    self.retry_unmount_oldroot()

                self.log_machine_state()
            except Exception as e:
                message = "Attempt #{0} to unmount /oldroot failed with error: {1}, stack trace: {2}".format(attempt, e, traceback.format_exc())
                self.logger.log(msg=message)
                self.hutil.do_status_report(operation='EnableEncryptionOSVolume',
                                            status=CommonVariables.extension_error_status,
                                            status_code=str(CommonVariables.unmount_oldroot_error),
                                            message=message)
                sleep(10)

                if attempt > 10:
                    raise Exception(message)
            else:
                oldroot_unmounted_successfully = True
            finally:
                attempt += 1

        self.enter_split_root_partition()
        self.log_machine_state()

        self.enter_encrypt_block_device()
        self.log_machine_state()

        self.enter_patch_boot_system()
        self.log_machine_state()

        self.stop_machine()
        self.log_machine_state()

        self._reboot()

================================================ FILE: VMEncryption/main/oscrypto/ubuntu_1604/__init__.py ================================================
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from Ubuntu1604EncryptionStateMachine import * ================================================ FILE: VMEncryption/main/oscrypto/ubuntu_1604/encryptpatches/ubuntu_1604_initramfs.patch ================================================ diff -Naur hooks.orig/cryptroot hooks/cryptroot --- hooks.orig/cryptroot 2016-07-24 04:00:35.707468106 +0000 +++ hooks/cryptroot 2016-07-24 04:00:58.251574341 +0000 @@ -521,14 +521,14 @@ mkdir -p "$DESTDIR/conf/conf.d/cryptheader" fi - #if [ -e "$CONFDIR/conf.d/cryptheader/$CRYPTHEADER" ]; then - # copy_exec "$CONFDIR/conf.d/cryptheader/$CRYPTHEADER" /conf/conf.d/cryptheader >&2 - #elif [ -e "$CRYPTHEADER" ]; then - # copy_exec "$CRYPTHEADER" /conf/conf.d/cryptheader >&2 - #else - # echo "cryptsetup: WARNING: failed to find LUKS header $CRYPTHEADER" >&2 - # continue - #fi + if [ -e "$CONFDIR/conf.d/cryptheader/$CRYPTHEADER" ]; then + copy_exec "$CONFDIR/conf.d/cryptheader/$CRYPTHEADER" /conf/conf.d/cryptheader >&2 + elif [ -e "$CRYPTHEADER" ]; then + copy_exec "$CRYPTHEADER" /conf/conf.d/cryptheader >&2 + else + echo "cryptsetup: WARNING: failed to find LUKS header $CRYPTHEADER" >&2 + continue + fi fi @@ -627,6 +627,9 @@ if [ -z "$rootdevs" ]; then echo "cryptsetup: WARNING: could not determine root device from /etc/fstab" >&2 fi + if ! echo "$rootdevs" | grep -q "osencrypt"; then + rootdevs="$rootdevs osencrypt" + fi usrdevs=$(get_fs_devices /usr) resumedevs=$(get_resume_devices) initramfsdevs=$(get_initramfs_devices) ================================================ FILE: VMEncryption/main/oscrypto/ubuntu_1604/encryptscripts/azure_crypt_key.sh ================================================ #!/bin/sh MountPoint=/tmp-keydisk-mount KeyFileName=LinuxPassPhraseFileName echo "Trying to get the key from disks ..." >&2 mkdir -p $MountPoint modprobe nls_utf8 >/dev/null 2>&1 modprobe nls_cp437 >/dev/null 2>&1 modprobe vfat >/dev/null 2>&1 sleep 2 OPENED=0 cd /sys/block for DEV in sd*; do echo "> Trying device: $DEV ..." 
>&2 mount -t vfat -r /dev/${DEV}1 $MountPoint >&2 2>&1 if [ -f $MountPoint/$KeyFileName ]; then cat $MountPoint/$KeyFileName && echo "Success loading keyfile!" >&2 umount $MountPoint 2>/dev/null OPENED=1 break fi umount $MountPoint 2>/dev/null done if [ $OPENED -eq 0 ]; then echo "FAILED to find suitable passphrase file ..." >&2 echo -n "Try to enter your password: " >&2 read -r A 10: raise Exception("Block device {0} did not appear in 10 restart attempts".format(self.rootfs_block_device)) self.context.logger.log("Restarting systemd-udevd") self.command_executor.Execute('systemctl restart systemd-udevd') self.context.logger.log("Restarting systemd-timesyncd") self.command_executor.Execute('systemctl restart systemd-timesyncd') self.context.logger.log("Restarting systemd-networkd") self.command_executor.Execute('systemctl restart systemd-networkd') sleep(10) if self.command_executor.ExecuteInBash('[ -b {0} ]'.format(self.rootfs_block_device), False) == 0: break attempt += 1 self.command_executor.Execute('e2fsck -yf {0}'.format(self.rootfs_block_device), True) def should_exit(self): self.context.logger.log("Verifying if machine should exit unmount_oldroot state") if os.path.exists('/oldroot/bin'): self.context.logger.log("/oldroot was not unmounted") return False return super(UnmountOldrootState, self).should_exit() ================================================ FILE: VMEncryption/main/oscrypto/ubuntu_1604/encryptstates/__init__.py ================================================ # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.7+
#
"""Package init for the Ubuntu 16.04 encryption states.

Extends sys.path so the shared OSEncryptionState base (two directories up,
under oscrypto/) is importable, then re-exports every state class used by
the Ubuntu 16.04 encryption state machine.
"""

import inspect
import os
import sys
import traceback

from time import sleep

# Resolve this package's directory and add the oscrypto root to sys.path so
# the state modules below can import OSEncryptionState.
scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
oscryptodir = os.path.abspath(os.path.join(scriptdir, '../../'))
sys.path.append(oscryptodir)

from OSEncryptionState import *
from PrereqState import *
from StripdownState import *
from UnmountOldrootState import *
from SplitRootPartitionState import *
from EncryptBlockDeviceState import *
from PatchBootSystemState import *

================================================ FILE: VMEncryption/main/patch/AbstractPatching.py ================================================
#!/usr/bin/python
#
# AbstractPatching is the base patching class of all the linux distros
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.4+

import os
import sys
import imp
import base64
import re
import json
import platform
import shutil
import time
import traceback
import datetime
import subprocess


class AbstractPatching(object):
    """
    AbstractPatching defines a skeleton necessary for a concrete Patching class.

    It holds default absolute paths for the external tools the encryption
    code shells out to; distro-specific subclasses override these paths and
    the install_* hooks as needed.
    """
    def __init__(self, distro_info):
        # distro_info is stored as provided by the caller (name/version list).
        self.distro_info = distro_info

        # Default tool locations; subclasses override where distros differ.
        self.base64_path = '/usr/bin/base64'
        self.bash_path = '/bin/bash'
        self.blkid_path = '/usr/bin/blkid'
        self.cat_path = '/bin/cat'
        self.cryptsetup_path = '/usr/sbin/cryptsetup'
        self.dd_path = '/usr/bin/dd'
        self.e2fsck_path = '/sbin/e2fsck'
        self.echo_path = '/usr/bin/echo'
        self.lsblk_path = '/usr/bin/lsblk'
        self.lsscsi_path = '/usr/bin/lsscsi'
        self.mkdir_path = '/usr/bin/mkdir'
        self.mount_path = '/usr/bin/mount'
        self.openssl_path = '/usr/bin/openssl'
        self.resize2fs_path = '/sbin/resize2fs'
        self.umount_path = '/usr/bin/umount'

        # Running kernel version string (uname -r equivalent).
        self.kernel_version = platform.release()

    def install_adal(self):
        """Install the ADAL auth library; no-op in the base class."""
        pass

    def install_extras(self):
        """Install distro-specific prerequisite packages; no-op in the base class."""
        pass

    def update_prereq(self):
        """Update prerequisites; no-op in the base class."""
        pass

================================================ FILE: VMEncryption/main/patch/SuSEPatching.py ================================================
#!/usr/bin/python
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.4+

import os
import sys
import imp
import base64
import re
import json
import platform
import shutil
import time
import traceback
import datetime
import subprocess

from AbstractPatching import AbstractPatching
from Common import *
from CommandExecutor import *


class SuSEPatching(AbstractPatching):
    """Patching implementation for SLES; tool paths differ between SLES 11
    and later releases, so __init__ branches on the version string."""

    def __init__(self, logger, distro_info):
        super(SuSEPatching, self).__init__(distro_info)
        self.distro_info = distro_info
        self.command_executor = CommandExecutor(logger)
        if distro_info[1] == "11":
            # SLES 11: most core tools live under /bin and /sbin.
            self.logger = logger
            self.base64_path = '/usr/bin/base64'
            self.bash_path = '/bin/bash'
            self.blkid_path = '/sbin/blkid'
            self.cryptsetup_path = '/sbin/cryptsetup'
            self.cat_path = '/bin/cat'
            self.dd_path = '/bin/dd'
            self.e2fsck_path = '/sbin/e2fsck'
            self.echo_path = '/bin/echo'
            self.lsblk_path = '/bin/lsblk'
            self.lsscsi_path = '/usr/bin/lsscsi'
            self.mkdir_path = '/bin/mkdir'
            self.mount_path = '/bin/mount'
            self.openssl_path = '/usr/bin/openssl'
            self.resize2fs_path = '/sbin/resize2fs'
            self.umount_path = '/bin/umount'
            self.blockdev_path = '/sbin/blockdev'
        else:
            # SLES 12+: tools largely under /usr/bin and /usr/sbin.
            self.logger = logger
            self.base64_path = '/usr/bin/base64'
            self.bash_path = '/bin/bash'
            self.blkid_path = '/usr/bin/blkid'
            self.cat_path = '/bin/cat'
            self.cryptsetup_path = '/usr/sbin/cryptsetup'
            self.dd_path = '/usr/bin/dd'
            self.e2fsck_path = '/sbin/e2fsck'
            self.echo_path = '/usr/bin/echo'
            self.lsblk_path = '/usr/bin/lsblk'
            self.lsscsi_path = '/usr/bin/lsscsi'
            self.mkdir_path = '/usr/bin/mkdir'
            self.mount_path = '/usr/bin/mount'
            self.openssl_path = '/usr/bin/openssl'
            self.resize2fs_path = '/sbin/resize2fs'
            self.umount_path = '/usr/bin/umount'

    def install_adal(self):
        """Ensure the adal Python package is available.

        On SLES 11 only verifies adal is already installed (raises if not);
        on later releases installs pip via zypper and then adal via pip.
        """
        if self.distro_info[1] == "11":
            try:
                self.command_executor.ExecuteInBash('pip list | grep -F adal', raise_exception_on_failure=True)
            except:
                raise Exception('SLES 11 environment is missing python-pip and adal')
        else:
            self.command_executor.Execute('zypper --gpg-auto-import-keys install -l -y python-pip')
            self.command_executor.Execute('python -m pip install --upgrade pip')
            self.command_executor.Execute('python -m pip install adal')

    def install_extras(self):
        """Install prerequisite packages (cryptsetup, lsscsi) via zypper."""
        packages = ['cryptsetup',
                    'lsscsi']

        cmd = " ".join((['zypper', 'install', '-l', '-y'] + packages))
        self.command_executor.Execute(cmd)

    def update_prereq(self):
        """No prerequisite updates needed on SLES."""
        pass

================================================ FILE: VMEncryption/main/patch/UbuntuPatching.py ================================================
#!/usr/bin/python
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.4+

import os
import sys
import imp
import base64
import re
import json
import platform
import shutil
import time
import traceback
import datetime
import subprocess

from AbstractPatching import AbstractPatching
from Common import *
from CommandExecutor import *


class UbuntuPatching(AbstractPatching):
    """Patching implementation for Ubuntu; installs prerequisites via apt-get
    with a one-shot `apt-get update` retry when the first install fails."""

    def __init__(self, logger, distro_info):
        super(UbuntuPatching, self).__init__(distro_info)
        self.logger = logger
        self.command_executor = CommandExecutor(logger)
        # Ubuntu tool locations (mostly /bin and /sbin).
        self.base64_path = '/usr/bin/base64'
        self.bash_path = '/bin/bash'
        self.blkid_path = '/sbin/blkid'
        self.cat_path = '/bin/cat'
        self.cryptsetup_path = '/sbin/cryptsetup'
        self.dd_path = '/bin/dd'
        self.e2fsck_path = '/sbin/e2fsck'
        self.echo_path = '/bin/echo'
        self.lsblk_path = '/bin/lsblk'
        self.lsscsi_path = '/usr/bin/lsscsi'
        self.mkdir_path = '/bin/mkdir'
        self.mount_path = '/bin/mount'
        self.openssl_path = '/usr/bin/openssl'
        self.resize2fs_path = '/sbin/resize2fs'
        self.umount_path = '/bin/umount'
        self.touch_path = '/usr/bin/touch'

    def install_adal(self):
        """Install pip (retrying after apt-get update on failure), then adal.

        Raises if the apt-get update itself times out (return code -9 from
        the 30-second timeout).
        """
        return_code = self.command_executor.Execute('apt-get install -y python-pip')
        # If install fails, try running apt-get update and then try install again
        if return_code != 0:
            self.logger.log('python-pip installation failed. Retrying installation after running update')
            return_code = self.command_executor.Execute('apt-get -o Acquire::ForceIPv4=true -y update', timeout=30)
            # Fail early if apt-get update times out.
            if return_code == -9:
                msg = "Command: apt-get -o Acquire::ForceIPv4=true -y update timed out. Make sure apt-get is configured correctly."
                raise Exception(msg)
            self.command_executor.Execute('apt-get install -y python-pip')
        self.command_executor.Execute('python -m pip install --upgrade pip')
        self.command_executor.Execute('python -m pip install --upgrade setuptools')
        self.command_executor.Execute('python -m pip install adal')

    def install_extras(self):
        """
        Install prerequisite packages, including sg_dd, because the default
        dd does not support sparse writes.
        """
        packages = ['at',
                    'cryptsetup-bin',
                    'lsscsi',
                    'python-parted',
                    'python-six',
                    'procps',
                    'psmisc']

        cmd = " ".join(['apt-get', 'install', '-y'] + packages)
        return_code = self.command_executor.Execute(cmd)
        # If install fails, try running apt-get update and then try install again
        if return_code != 0:
            self.logger.log('prereq packages installation failed. Retrying installation after running update')
            return_code = self.command_executor.Execute('apt-get -o Acquire::ForceIPv4=true -y update')
            # Fail early if apt-get update times out.
            if return_code == -9:
                msg = "Command: apt-get -o Acquire::ForceIPv4=true -y update timed out. Make sure apt-get is configured correctly."
                raise Exception(msg)
            cmd = " ".join(['apt-get', 'install', '-y'] + packages)
            self.command_executor.Execute(cmd)

    def update_prereq(self):
        """No prerequisite updates needed on Ubuntu."""
        pass

================================================ FILE: VMEncryption/main/patch/__init__.py ================================================
#!/usr/bin/python
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# # Requires Python 2.4+ import os import re import platform from UbuntuPatching import UbuntuPatching from debianPatching import debianPatching from redhatPatching import redhatPatching from centosPatching import centosPatching from SuSEPatching import SuSEPatching from oraclePatching import oraclePatching # Define the function in case waagent(<2.0.4) doesn't have DistInfo() def DistInfo(): if 'FreeBSD' in platform.system(): release = re.sub('\-.*\Z', '', str(platform.release())) distinfo = ['FreeBSD', release] return distinfo if 'linux_distribution' in dir(platform): distinfo = list(platform.linux_distribution(full_distribution_name=0)) # remove trailing whitespace in distro name distinfo[0] = distinfo[0].strip() return distinfo else: return platform.dist() def GetDistroPatcher(logger): """ Return DistroPatcher object. NOTE: Logging is not initialized at this point. """ dist_info = DistInfo() if 'Linux' in platform.system(): Distro = dist_info[0] else: # I know this is not Linux! if 'FreeBSD' in platform.system(): Distro = platform.system() Distro = Distro.strip('"') Distro = Distro.strip(' ') patching_class_name = Distro + 'Patching' if not globals().has_key(patching_class_name): logger.log('{0} is not a supported distribution.'.format(Distro)) return None patchingInstance = globals()[patching_class_name](logger, dist_info) return patchingInstance ================================================ FILE: VMEncryption/main/patch/centosPatching.py ================================================ #!/usr/bin/python # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.4+

import os
import os.path
import sys
import imp
import base64
import re
import json
import platform
import shutil
import time
import traceback
import datetime
import subprocess

from redhatPatching import redhatPatching
from Common import *
from CommandExecutor import *


class centosPatching(redhatPatching):
    """CentOS-specific patching; inherits most behavior from redhatPatching."""

    def __init__(self, logger, distro_info):
        super(centosPatching, self).__init__(logger, distro_info)
        self.logger = logger
        self.command_executor = CommandExecutor(logger)
        # CentOS 6.x places core utilities under /bin and /sbin; later
        # releases merged them under /usr.
        if distro_info[1] in ["6.9", "6.8", "6.7", "6.6", "6.5"]:
            self.base64_path = '/usr/bin/base64'
            self.bash_path = '/bin/bash'
            self.blkid_path = '/sbin/blkid'
            self.cat_path = '/bin/cat'
            self.cryptsetup_path = '/sbin/cryptsetup'
            self.dd_path = '/bin/dd'
            self.e2fsck_path = '/sbin/e2fsck'
            self.echo_path = '/bin/echo'
            self.lsblk_path = '/bin/lsblk'
            self.lsscsi_path = '/usr/bin/lsscsi'
            self.mkdir_path = '/bin/mkdir'
            self.mount_path = '/bin/mount'
            self.openssl_path = '/usr/bin/openssl'
            self.resize2fs_path = '/sbin/resize2fs'
            self.umount_path = '/bin/umount'
        else:
            self.base64_path = '/usr/bin/base64'
            self.bash_path = '/usr/bin/bash'
            self.blkid_path = '/usr/bin/blkid'
            self.cat_path = '/bin/cat'
            self.cryptsetup_path = '/usr/sbin/cryptsetup'
            self.dd_path = '/usr/bin/dd'
            self.e2fsck_path = '/sbin/e2fsck'
            self.echo_path = '/usr/bin/echo'
            self.lsblk_path = '/usr/bin/lsblk'
            self.lsscsi_path = '/usr/bin/lsscsi'
            self.mkdir_path = '/usr/bin/mkdir'
            self.mount_path = '/usr/bin/mount'
            self.openssl_path = '/usr/bin/openssl'
            self.resize2fs_path = '/sbin/resize2fs'
            self.umount_path = '/usr/bin/umount'

    def install_adal(self):
        """Install the 'adal' Python package via pip.

        epel-release and python-pip >= version 8.1 are adal prerequisites:
        https://github.com/AzureAD/azure-activedirectory-library-for-python/
        """
        self.command_executor.Execute("yum install -y epel-release")
        self.command_executor.Execute("yum install -y python-pip")
        self.command_executor.Execute("python -m pip install --upgrade pip")
        self.command_executor.Execute("python -m pip install adal")

    def install_extras(self):
        """Install prerequisite packages, adjusting the set for 6.x releases.

        yum is only invoked when 'rpm -q' reports at least one of the
        packages as missing (non-zero return code).
        """
        packages = ['cryptsetup',
                    'lsscsi',
                    'psmisc',
                    'cryptsetup-reencrypt',
                    'lvm2',
                    'uuid',
                    'at',
                    'patch',
                    'procps-ng',
                    'util-linux',
                    'pyparted']

        if self.distro_info[1].startswith("6."):
            # BUGFIX: 'packages' is a list, so the original packages.add(...)
            # raised AttributeError at runtime; append() is the list operation.
            packages.append('python-six')
            packages.remove('cryptsetup')
            packages.remove('procps-ng')
            packages.remove('util-linux')

        if self.command_executor.Execute("rpm -q " + " ".join(packages)):
            self.command_executor.Execute("yum install -y " + " ".join(packages))

    def update_prereq(self):
        """On CentOS 7.x, normalize the dracut modules used for OS-disk
        encryption and regenerate the initrd if anything changed."""
        if (self.distro_info[1].startswith('7.')):
            dracut_repack_needed = False

            if os.path.exists("/lib/dracut/modules.d/91lvm/"):
                # If 90lvm already exists 91lvm will cause problems, so remove it.
                if os.path.exists("/lib/dracut/modules.d/90lvm/"):
                    shutil.rmtree("/lib/dracut/modules.d/91lvm/")
                else:
                    os.rename("/lib/dracut/modules.d/91lvm/", "/lib/dracut/modules.d/90lvm/")
                dracut_repack_needed = True

            if redhatPatching.is_old_patching_system():
                redhatPatching.remove_old_patching_system(self.logger, self.command_executor)
                dracut_repack_needed = True

            if os.path.exists("/lib/dracut/modules.d/91ade/"):
                shutil.rmtree("/lib/dracut/modules.d/91ade/")
                dracut_repack_needed = True

            if os.path.exists("/dev/mapper/osencrypt"):
                # TODO: only do this if needed (if code and existing module are different)
                redhatPatching.add_91_ade_dracut_module(self.command_executor)
                dracut_repack_needed = True

            if dracut_repack_needed:
                self.command_executor.ExecuteInBash("/usr/sbin/dracut -f -v --kver `grubby --default-kernel | sed 's|/boot/vmlinuz-||g'`", True)
# # Requires Python 2.4+ import os import sys import imp import base64 import re import json import platform import shutil import time import traceback import datetime import subprocess from AbstractPatching import AbstractPatching from Common import * class debianPatching(AbstractPatching): def __init__(self, logger, distro_info): super(debianPatching, self).__init__(distro_info) self.logger = logger self.base64_path = '/usr/bin/base64' self.bash_path = '/bin/bash' self.blkid_path = '/sbin/blkid' self.cat_path = '/bin/cat' self.cryptsetup_path = '/sbin/cryptsetup' self.dd_path = '/bin/dd' self.e2fsck_path = '/sbin/e2fsck' self.echo_path = '/bin/echo' self.lsblk_path = '/bin/lsblk' self.lsscsi_path = '/usr/bin/lsscsi' self.mkdir_path = '/bin/mkdir' self.mount_path = '/bin/mount' self.openssl_path = '/usr/bin/openssl' self.resize2fs_path = '/sbin/resize2fs' self.umount_path = '/bin/umount' def install_adal(self): pass def install_extras(self): pass def update_prereq(self): pass ================================================ FILE: VMEncryption/main/patch/oraclePatching.py ================================================ #!/usr/bin/python # # Copyright 2015 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# # Requires Python 2.4+ import os import sys import imp import base64 import re import json import platform import shutil import time import traceback import datetime import subprocess from redhatPatching import redhatPatching from Common import * class oraclePatching(redhatPatching): def __init__(self,logger,distro_info): super(oraclePatching,self).__init__(logger,distro_info) self.logger = logger if(distro_info is not None and len(distro_info) > 0 and distro_info[1].startswith("6.")): self.base64_path = '/usr/bin/base64' self.bash_path = '/bin/bash' self.blkid_path = '/sbin/blkid' self.cat_path = '/bin/cat' self.cryptsetup_path = '/sbin/cryptsetup' self.dd_path = '/bin/dd' self.e2fsck_path = '/sbin/e2fsck' self.echo_path = '/bin/echo' self.getenforce_path = '/usr/sbin/getenforce' self.setenforce_path = '/usr/sbin/setenforce' self.lsblk_path = '/bin/lsblk' self.lsscsi_path = '/usr/bin/lsscsi' self.mkdir_path = '/bin/mkdir' self.mount_path = '/bin/mount' self.openssl_path = '/usr/bin/openssl' self.resize2fs_path = '/sbin/resize2fs' self.umount_path = '/bin/umount' else: self.base64_path = '/usr/bin/base64' self.bash_path = '/usr/bin/bash' self.blkid_path = '/usr/bin/blkid' self.cat_path = '/bin/cat' self.cryptsetup_path = '/usr/sbin/cryptsetup' self.dd_path = '/usr/bin/dd' self.e2fsck_path = '/sbin/e2fsck' self.echo_path = '/usr/bin/echo' self.getenforce_path = '/usr/sbin/getenforce' self.setenforce_path = '/usr/sbin/setenforce' self.lsblk_path = '/usr/bin/lsblk' self.lsscsi_path = '/usr/bin/lsscsi' self.mkdir_path = '/usr/bin/mkdir' self.mount_path = '/usr/bin/mount' self.openssl_path = '/usr/bin/openssl' self.resize2fs_path = '/sbin/resize2fs' self.umount_path = '/usr/bin/umount' def install_adal(self): pass def install_extras(self): common_extras = ['cryptsetup','lsscsi'] for extra in common_extras: self.logger.log("installation for " + extra + 'result is ' + str(subprocess.call(['yum', 'install','-y', extra]))) def update_prereq(self): pass 
#!/usr/bin/python
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.4+

import os
import os.path
import sys
import imp
import base64
import re
import json
import platform
import shutil
import time
import traceback
import datetime
import subprocess
import inspect

from AbstractPatching import AbstractPatching
from Common import *
from CommandExecutor import *


class redhatPatching(AbstractPatching):
    """RHEL-family patching logic: tool paths, adal/prereq package install,
    and dracut module maintenance for OS-disk encryption."""

    def __init__(self, logger, distro_info):
        # distro_info is [name, version]; the version string selects the
        # correct utility locations (6.x keeps tools in /bin and /sbin,
        # later releases merged them under /usr).
        super(redhatPatching, self).__init__(distro_info)
        self.logger = logger
        self.command_executor = CommandExecutor(logger)
        self.distro_info = distro_info
        if distro_info[1].startswith("6."):
            self.base64_path = '/usr/bin/base64'
            self.bash_path = '/bin/bash'
            self.blkid_path = '/sbin/blkid'
            self.cat_path = '/bin/cat'
            self.cryptsetup_path = '/sbin/cryptsetup'
            self.dd_path = '/bin/dd'
            self.e2fsck_path = '/sbin/e2fsck'
            self.echo_path = '/bin/echo'
            self.getenforce_path = '/usr/sbin/getenforce'
            self.setenforce_path = '/usr/sbin/setenforce'
            self.lsblk_path = '/bin/lsblk'
            self.lsscsi_path = '/usr/bin/lsscsi'
            self.mkdir_path = '/bin/mkdir'
            self.mount_path = '/bin/mount'
            self.openssl_path = '/usr/bin/openssl'
            self.resize2fs_path = '/sbin/resize2fs'
            self.touch_path = '/bin/touch'
            self.umount_path = '/bin/umount'
        else:
            self.base64_path = '/usr/bin/base64'
            self.bash_path = '/usr/bin/bash'
            self.blkid_path = '/usr/bin/blkid'
            self.cat_path = '/bin/cat'
            self.cryptsetup_path = '/usr/sbin/cryptsetup'
            self.dd_path = '/usr/bin/dd'
            self.e2fsck_path = '/sbin/e2fsck'
            self.echo_path = '/usr/bin/echo'
            self.getenforce_path = '/usr/sbin/getenforce'
            self.setenforce_path = '/usr/sbin/setenforce'
            self.lsblk_path = '/usr/bin/lsblk'
            self.lsscsi_path = '/usr/bin/lsscsi'
            self.mkdir_path = '/usr/bin/mkdir'
            self.mount_path = '/usr/bin/mount'
            self.openssl_path = '/usr/bin/openssl'
            self.resize2fs_path = '/sbin/resize2fs'
            self.touch_path = '/usr/bin/touch'
            self.umount_path = '/usr/bin/umount'

    def install_adal(self):
        # On RHEL, RHSCL pip >= version 8.1 is the supported mechanism to install adal
        # https://access.redhat.com/solutions/1519803
        self.command_executor.Execute('yum install -y python27-python-pip')
        self.command_executor.Execute('scl enable python27 "pip install --upgrade pip"')
        self.command_executor.Execute('scl enable python27 "pip install adal"')

    def install_extras(self):
        # Prerequisite packages; a reduced set is used on 6.x where some of
        # these either ship by default or are unavailable under these names.
        packages = ['cryptsetup',
                    'lsscsi',
                    'psmisc',
                    'cryptsetup-reencrypt',
                    'lvm2',
                    'uuid',
                    'at',
                    'patch',
                    'procps-ng',
                    'util-linux']

        if self.distro_info[1].startswith("6."):
            packages.remove('cryptsetup')
            packages.remove('procps-ng')
            packages.remove('util-linux')

        # 'rpm -q' returns non-zero when any queried package is missing;
        # only then is yum invoked.
        if self.command_executor.Execute("rpm -q " + " ".join(packages)):
            self.command_executor.Execute("yum install -y " + " ".join(packages))

    def update_prereq(self):
        # On 7.x, normalize the dracut modules used for OS-disk encryption
        # and regenerate the initrd when anything was changed.
        if (self.distro_info[1].startswith('7.')):
            dracut_repack_needed = False

            if os.path.exists("/lib/dracut/modules.d/91lvm/"):
                # If 90lvm already exists 91lvm will cause problems, so remove it.
                if os.path.exists("/lib/dracut/modules.d/90lvm/"):
                    shutil.rmtree("/lib/dracut/modules.d/91lvm/")
                else:
                    os.rename("/lib/dracut/modules.d/91lvm/","/lib/dracut/modules.d/90lvm/")
                dracut_repack_needed = True

            if redhatPatching.is_old_patching_system():
                redhatPatching.remove_old_patching_system(self.logger, self.command_executor)
                dracut_repack_needed = True

            if os.path.exists("/lib/dracut/modules.d/91ade/"):
                # Remove any stale copy of the ADE module; it is re-added
                # below if the OS disk is encrypted.
                shutil.rmtree("/lib/dracut/modules.d/91ade/")
                dracut_repack_needed = True

            if os.path.exists("/dev/mapper/osencrypt"):
                #TODO: only do this if needed (if code and existing module are different)
                redhatPatching.add_91_ade_dracut_module(self.command_executor)
                dracut_repack_needed = True

            if dracut_repack_needed:
                # Rebuild the initrd for the default kernel.
                self.command_executor.ExecuteInBash("/usr/sbin/dracut -f -v --kver `grubby --default-kernel | sed 's|/boot/vmlinuz-||g'`", True)

    @staticmethod
    def is_old_patching_system():
        # Execute unpatching commands only if all the three patch files are present.
        if os.path.exists("/lib/dracut/modules.d/90crypt/cryptroot-ask.sh.orig"):
            if os.path.exists("/lib/dracut/modules.d/90crypt/module-setup.sh.orig"):
                if os.path.exists("/lib/dracut/modules.d/90crypt/parse-crypt.sh.orig"):
                    return True
        return False

    @staticmethod
    def append_contents_to_file(contents, path):
        # Append 'contents' verbatim to the file at 'path'.
        with open(path, 'a') as f:
            f.write(contents)

    @staticmethod
    def add_91_ade_dracut_module(command_executor):
        # Copy the bundled 91ade dracut module into place and patch its udev
        # rule with the partition number of the encrypted OS disk.
        scriptdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
        ademoduledir = os.path.join(scriptdir, '../oscrypto/91ade')
        dracutmodulesdir = '/lib/dracut/modules.d'
        udevaderulepath = os.path.join(dracutmodulesdir, '91ade/50-udev-ade.rules')

        proc_comm = ProcessCommunicator()
        command_executor.Execute('cp -r {0} /lib/dracut/modules.d/'.format(ademoduledir), True)

        # Discover the backing device of the 'osencrypt' mapping.
        crypt_cmd = "cryptsetup status osencrypt | grep device:"
        command_executor.ExecuteInBash(crypt_cmd, communicator=proc_comm, suppress_logging=True)
        matches = re.findall(r'device:(.*)', proc_comm.stdout)
        if not matches:
            raise Exception("Could not find device in cryptsetup output")
        root_device = matches[0].strip()

        # Walk the udev attributes of that device to find its partition number.
        udevadm_cmd = "udevadm info --attribute-walk --name={0}".format(root_device)
        command_executor.Execute(command_to_execute=udevadm_cmd, raise_exception_on_failure=True, communicator=proc_comm)
        matches = re.findall(r'ATTR{partition}=="(.*)"', proc_comm.stdout)
        if not matches:
            raise Exception("Could not parse ATTR{partition} from udevadm info")
        partition = matches[0]

        # Substitute the discovered partition into the packaged udev rule.
        sed_cmd = 'sed -i.bak s/ENCRYPTED_DISK_PARTITION/{0}/ "{1}"'.format(partition, udevaderulepath)
        command_executor.Execute(command_to_execute=sed_cmd, raise_exception_on_failure=True)

        # Drop the osencrypt-locked entry from /etc/crypttab.
        sed_grub_cmd = "sed -i.bak '/osencrypt-locked/d' /etc/crypttab"
        command_executor.Execute(command_to_execute=sed_grub_cmd, raise_exception_on_failure=True)

    @staticmethod
    def remove_old_patching_system(logger, command_executor):
        # Restore the pristine 90crypt scripts, clean grub defaults, and
        # regenerate the grub configuration.
        logger.log("Removing patches and recreating initrd image")

        command_executor.Execute('mv /lib/dracut/modules.d/90crypt/cryptroot-ask.sh.orig /lib/dracut/modules.d/90crypt/cryptroot-ask.sh', False)
        command_executor.Execute('mv /lib/dracut/modules.d/90crypt/module-setup.sh.orig /lib/dracut/modules.d/90crypt/module-setup.sh', False)
        command_executor.Execute('mv /lib/dracut/modules.d/90crypt/parse-crypt.sh.orig /lib/dracut/modules.d/90crypt/parse-crypt.sh', False)

        sed_grub_cmd = "sed -i.bak '/rd.luks.uuid=osencrypt/d' /etc/default/grub"
        command_executor.Execute(sed_grub_cmd)

        redhatPatching.append_contents_to_file('\nGRUB_CMDLINE_LINUX+=" rd.debug"\n',
                                               '/etc/default/grub')

        command_executor.Execute('grub2-mkconfig -o /boot/grub2/grub.cfg', True)
#!/usr/bin/env python
#
# VM Backup extension
#
# Copyright 2015 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To build:
# python setup.py sdist
#
# To install:
# python setup.py install
#
# To register (only needed once):
# python setup.py register
#
# To upload:
# python setup.py sdist upload

import codecs
import json
import os
import subprocess
from distutils.core import setup
from zipfile import ZipFile
from shutil import copy2
from main.Common import CommonVariables

# Build the list of Python packages that go into the extension bundle.
packages_array = []
main_folder = 'main'
main_entry = main_folder + '/handle.py'
packages_array.append(main_folder)

patch_folder = main_folder + '/patch'
packages_array.append(patch_folder)

oscrypto_folder = main_folder + '/oscrypto'
packages_array.append(oscrypto_folder)
packages_array.append(oscrypto_folder + '/91ade')
packages_array.append(oscrypto_folder + '/rhel_72_lvm')
packages_array.append(oscrypto_folder + '/rhel_72_lvm/encryptstates')
packages_array.append(oscrypto_folder + '/rhel_72')
packages_array.append(oscrypto_folder + '/rhel_72/encryptstates')
packages_array.append(oscrypto_folder + '/rhel_68')
packages_array.append(oscrypto_folder + '/rhel_68/encryptstates')
packages_array.append(oscrypto_folder + '/centos_68')
packages_array.append(oscrypto_folder + '/centos_68/encryptstates')
packages_array.append(oscrypto_folder + '/ubuntu_1604')
packages_array.append(oscrypto_folder + '/ubuntu_1604/encryptstates')
packages_array.append(oscrypto_folder + '/ubuntu_1404')
packages_array.append(oscrypto_folder + '/ubuntu_1404/encryptstates')

transitions_folder = 'transitions/transitions'
packages_array.append(transitions_folder)

""" copy the dependency to the local """
""" copy the utils lib to local """
target_utils_path = main_folder + '/' + CommonVariables.utils_path_name
packages_array.append(target_utils_path)

""" generate the HandlerManifest.json file. """
manifest_obj = [{
    "name": CommonVariables.extension_name,
    "version": "1.0",
    "handlerManifest": {
        "installCommand": "extension_shim.sh -c {0} --install".format(main_entry),
        "uninstallCommand": "extension_shim.sh -c {0} --uninstall".format(main_entry),
        "updateCommand": "extension_shim.sh -c {0} --update".format(main_entry),
        "enableCommand": "extension_shim.sh -c {0} --enable".format(main_entry),
        "disableCommand": "extension_shim.sh -c {0} --disable".format(main_entry),
        "rebootAfterInstall": False,
        "reportHeartbeat": False
    }
}]
manifest_str = json.dumps(manifest_obj, sort_keys = True, indent = 4)
manifest_file = open("HandlerManifest.json", "w")
manifest_file.write(manifest_str)
manifest_file.close()

""" generate the extension xml file """
# NOTE(review): the XML markup in this template literal appears to have been
# stripped in this copy of the file (the angle-bracket tags are missing and
# only 3 '%s' placeholders remain for 4 format arguments) — verify this
# literal against the repository before relying on it.
extension_xml_file_content = """ Microsoft.Azure.Security %s %s VmRole %s true https://azure.microsoft.com/en-us/support/legal/ https://azure.microsoft.com/en-us/support/legal/ https://github.com/Azure/azure-linux-extensions true Linux Microsoft """ % (CommonVariables.extension_type, CommonVariables.extension_version, CommonVariables.extension_label, CommonVariables.extension_description)

extension_xml_file = open('manifest.xml', 'w')
extension_xml_file.write(extension_xml_file_content)
extension_xml_file.close()

""" setup script, to package the files up """
setup(name = CommonVariables.extension_name,
      version = CommonVariables.extension_version,
      description=CommonVariables.extension_description,
      license='Apache License 2.0',
      author='Microsoft Corporation',
      author_email='andliu@microsoft.com',
      url='https://github.com/Azure/azure-linux-extensions',
      classifiers = ['Development Status :: 5 - Production/Stable',
                     'Programming Language :: Python',
                     'Programming Language :: Python :: 2',
                     'Programming Language :: Python :: 2.7',
                     'Programming Language :: Python :: 3',
                     'Programming Language :: Python :: 3.3',
                     'Programming Language :: Python :: 3.4',
                     'License :: OSI Approved :: Apache Software License'],
      packages = packages_array)

""" unzip the package files and re-package it. """
target_zip_file_location = './dist/'
target_folder_name = CommonVariables.extension_name + '-' + str(CommonVariables.extension_version)
target_zip_file_path = target_zip_file_location + target_folder_name + '.zip'
# NOTE(review): this ZipFile is never closed; harmless in a short-lived
# build script but worth tidying.
target_zip_file = ZipFile(target_zip_file_path)
target_zip_file.extractall(target_zip_file_location)

def dos2unix(src):
    # Convert CRLF line endings to LF in place via the dos2unix tool,
    # discarding its output.
    args = ["dos2unix", src]
    devnull = open(os.devnull, 'w')
    child = subprocess.Popen(args, stdout=devnull, stderr=devnull)
    print('dos2unix %s ' % (src))
    child.wait()

def remove_utf8_bom(src):
    # Strip a UTF-8 byte-order mark from the start of 'src', rewriting the
    # file as plain UTF-8.
    print('removing utf-8 bom from %s ' % (src))
    contents = None
    with open(src, "r+b") as fp:
        bincontents = fp.read()
        if bincontents[:len(codecs.BOM_UTF8)] == codecs.BOM_UTF8:
            contents = bincontents.decode('utf-8-sig')
        elif bincontents[:3] == '\xef\x00\x00':
            # NOTE(review): '\xef\x00\x00' is not the UTF-8 BOM (that is
            # \xef\xbb\xbf, already handled above), and on Python 3 this
            # compares bytes to str (always False) — confirm the intent.
            contents = bincontents[3:].decode('utf-8')
        else:
            contents = bincontents.decode('utf8')
    with open(src, "wb") as fp:
        fp.write(contents.encode('utf-8'))

def zip(src, dst):
    # Re-zip the tree at 'src' into 'dst', normalizing line endings and
    # removing BOMs from every file first.
    zf = ZipFile("%s" % (dst), "w")
    abs_src = os.path.abspath(src)
    for dirname, subdirs, files in os.walk(src):
        for filename in files:
            absname = os.path.abspath(os.path.join(dirname, filename))
            dos2unix(absname)
            remove_utf8_bom(absname)
            # Archive name is the path relative to the source root.
            arcname = absname[len(abs_src) + 1:]
            print('zipping %s as %s' % (os.path.join(dirname, filename), arcname))
            zf.write(absname, arcname)
    zf.close()

final_folder_path = target_zip_file_location + target_folder_name
# Manually add SupportedOS.json file as setup seems to only copy py file
copy2(main_folder+'/SupportedOS.json', final_folder_path+'/'+main_folder )
zip(final_folder_path, target_zip_file_path)
# # ********************************************************* import os import string import json class HandlerContext: def __init__(self, name): self._name = name self._version = '0.0' return class ConsoleLogger(object): def __init__(self): self.current_process_id = os.getpid() self._context = HandlerContext("test") self._context._config = json.loads('{"runtimeSettings": [{"handlerSettings": {"publicSettings": {"EncryptionOperation": "EnableEncryptionFormatAll"}}}]}') def log(self, msg, level='Info'): """ simple logging mechanism to print to stdout """ log_msg = "{0}: [{1}] {2}".format(self.current_process_id, level, msg) print(log_msg) def error(self, msg): log(msg,'Error') ================================================ FILE: VMEncryption/test/test_check_util.py ================================================ import unittest import mock from main.check_util import CheckUtil from main.Common import CommonVariables from StringIO import StringIO from console_logger import ConsoleLogger class MockDistroPatcher: def __init__(self, name, version, kernel): self.distro_info = [None] * 2 self.distro_info[0] = name self.distro_info[1] = version self.kernel_version = kernel class TestCheckUtil(unittest.TestCase): """ unit tests for functions in the check_util module """ def setUp(self): self.logger = ConsoleLogger() self.cutil = CheckUtil(self.logger) def get_mock_filestream(self, somestring): stream = StringIO() stream.write(somestring) stream.seek(0) return stream @mock.patch('os.path.isfile', return_value=False) @mock.patch('os.path.isdir', return_value=False) def test_appcompat(self, os_path_isdir, os_path_isfile): self.assertFalse(self.cutil.is_app_compat_issue_detected()) @mock.patch('os.popen') def test_memory(self, os_popen): output = "8000000" os_popen.return_value = self.get_mock_filestream(output) self.assertFalse(self.cutil.is_insufficient_memory()) @mock.patch('os.popen') def test_memory_low_memory(self, os_popen): output = "6000000" os_popen.return_value = 
self.get_mock_filestream(output) self.assertTrue(self.cutil.is_insufficient_memory()) def test_is_kv_url(self): dns_suffix_list = ["vault.azure.net", "vault.azure.cn", "vault.usgovcloudapi.net", "vault.microsoftazure.de"] for dns_suffix in dns_suffix_list: self.cutil.check_kv_url("https://testkv." + dns_suffix + "/", "") self.cutil.check_kv_url("https://test-kv2." + dns_suffix + "/", "") self.cutil.check_kv_url("https://test-kv2." + dns_suffix + ":443/", "") self.cutil.check_kv_url("https://test-kv2." + dns_suffix + ":443/keys/kekname/kekversion", "") self.assertRaises(Exception, self.cutil.check_kv_url, "http://testkv." + dns_suffix + "/", "") # self.assertRaises(Exception, self.cutil.check_kv_url, "https://https://testkv." + dns_suffix + "/", "") # self.assertRaises(Exception, self.cutil.check_kv_url, "https://testkv.testkv." + dns_suffix + "/", "") # self.assertRaises(Exception, self.cutil.check_kv_url, "https://testkv.vault.azure.com/", "") self.assertRaises(Exception, self.cutil.check_kv_url, "https://", "") def test_validate_volume_type(self): self.cutil.validate_volume_type({CommonVariables.VolumeTypeKey: "DATA"}) self.cutil.validate_volume_type({CommonVariables.VolumeTypeKey: "ALL"}) self.cutil.validate_volume_type({CommonVariables.VolumeTypeKey: "all"}) self.cutil.validate_volume_type({CommonVariables.VolumeTypeKey: "Os"}) self.cutil.validate_volume_type({CommonVariables.VolumeTypeKey: "OS"}) self.cutil.validate_volume_type({CommonVariables.VolumeTypeKey: "os"}) self.cutil.validate_volume_type({CommonVariables.VolumeTypeKey: "Data"}) self.cutil.validate_volume_type({CommonVariables.VolumeTypeKey: "data"}) for vt in CommonVariables.SupportedVolumeTypes: self.cutil.validate_volume_type({CommonVariables.VolumeTypeKey: vt}) self.assertRaises(Exception, self.cutil.validate_volume_type, {CommonVariables.VolumeTypeKey: "NON-OS"}) self.assertRaises(Exception, self.cutil.validate_volume_type, {CommonVariables.VolumeTypeKey: ""}) self.assertRaises(Exception, 
self.cutil.validate_volume_type, {CommonVariables.VolumeTypeKey: "123"}) @mock.patch('main.check_util.CheckUtil.validate_memory_os_encryption') @mock.patch('main.CommandExecutor.CommandExecutor.Execute', return_value=0) def test_fatal_checks(self, mock_exec, mock_validate_memory): mock_distro_patcher = MockDistroPatcher('Ubuntu', '14.04', '4.15') self.cutil.precheck_for_fatal_failures({ CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.QueryEncryptionStatus }, { "os": "NotEncrypted" }, mock_distro_patcher) self.cutil.precheck_for_fatal_failures({ CommonVariables.VolumeTypeKey: "DATA", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.DisableEncryption }, { "os": "NotEncrypted" }, mock_distro_patcher) self.cutil.precheck_for_fatal_failures({ CommonVariables.VolumeTypeKey: "ALL", CommonVariables.KeyVaultURLKey: "https://vaultname.vault.azure.net/", CommonVariables.AADClientIDKey: "00000000-0000-0000-0000-000000000000", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryption }, { "os": "NotEncrypted" }, mock_distro_patcher) self.cutil.precheck_for_fatal_failures({ CommonVariables.VolumeTypeKey: "ALL", CommonVariables.KeyVaultURLKey: "https://vaultname.vault.azure.net/", CommonVariables.KeyEncryptionKeyURLKey: "https://vaultname.vault.azure.net/keys/keyname/ver", CommonVariables.AADClientIDKey: "00000000-0000-0000-0000-000000000000", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryptionFormat }, { "os": "NotEncrypted" }, mock_distro_patcher) self.cutil.precheck_for_fatal_failures({ CommonVariables.VolumeTypeKey: "ALL", CommonVariables.KeyVaultURLKey: "https://vaultname.vault.azure.net/", CommonVariables.KeyEncryptionKeyURLKey: "https://vaultname.vault.azure.net/keys/keyname/ver", CommonVariables.KeyEncryptionAlgorithmKey: 'rsa-OAEP-256', CommonVariables.AADClientIDKey: "00000000-0000-0000-0000-000000000000", CommonVariables.EncryptionEncryptionOperationKey: 
CommonVariables.EnableEncryptionFormatAll }, { "os": "NotEncrypted" }, mock_distro_patcher) self.assertRaises(Exception, self.cutil.precheck_for_fatal_failures, {}) self.assertRaises(Exception, self.cutil.precheck_for_fatal_failures, { CommonVariables.VolumeTypeKey: "ALL", CommonVariables.KeyVaultURLKey: "https://vaultname.vault.azure.net/", CommonVariables.KeyEncryptionKeyURLKey: "https://vaultname.vault.azure.net/keys/keyname/ver", CommonVariables.KeyEncryptionAlgorithmKey: 'rsa-OAEP-256', CommonVariables.AADClientIDKey: "INVALIDKEY", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryptionFormatAll }, mock_distro_patcher) self.assertRaises(Exception, self.cutil.precheck_for_fatal_failures, { CommonVariables.VolumeTypeKey: "123", CommonVariables.AADClientIDKey: "00000000-0000-0000-0000-000000000000", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryption }, { "os": "NotEncrypted" }, mock_distro_patcher) self.assertRaises(Exception, self.cutil.precheck_for_fatal_failures, { CommonVariables.VolumeTypeKey: "ALL", CommonVariables.KeyVaultURLKey: "https://vaultname.vault.azure.net/", CommonVariables.KeyEncryptionKeyURLKey: "https://vaultname.vault.azure.net/keys/keyname/ver", CommonVariables.KeyEncryptionAlgorithmKey: 'rsa-OAEP-25600', CommonVariables.AADClientIDKey: "00000000-0000-0000-0000-000000000000", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryptionFormatAll }, { "os": "NotEncrypted" }, mock_distro_patcher) mock_distro_patcher = MockDistroPatcher('Ubuntu', '14.04', '4.4') self.assertRaises(Exception, self.cutil.precheck_for_fatal_failures, { CommonVariables.VolumeTypeKey: "ALL" }, { "os": "NotEncrypted" }, mock_distro_patcher) def test_mount_scheme(self): proc_mounts_output = """ sysfs /sys sysfs rw,nosuid,nodev,noexec,relatime 0 0 proc /proc proc rw,nosuid,nodev,noexec,relatime 0 0 udev /dev devtmpfs rw,relatime,size=4070564k,nr_inodes=1017641,mode=755 0 0 devpts /dev/pts 
devpts rw,nosuid,noexec,relatime,gid=5,mode=620,ptmxmode=000 0 0 tmpfs /run tmpfs rw,nosuid,noexec,relatime,size=815720k,mode=755 0 0 /dev/sda1 / ext4 rw,relatime,discard,data=ordered 0 0 none /sys/fs/cgroup tmpfs rw,relatime,size=4k,mode=755 0 0 none /sys/fs/fuse/connections fusectl rw,relatime 0 0 none /sys/kernel/debug debugfs rw,relatime 0 0 none /sys/kernel/security securityfs rw,relatime 0 0 none /run/lock tmpfs rw,nosuid,nodev,noexec,relatime,size=5120k 0 0 none /run/shm tmpfs rw,nosuid,nodev,relatime 0 0 none /run/user tmpfs rw,nosuid,nodev,noexec,relatime,size=102400k,mode=755 0 0 none /sys/fs/pstore pstore rw,relatime 0 0 systemd /sys/fs/cgroup/systemd cgroup rw,nosuid,nodev,noexec,relatime,name=systemd 0 0 /dev/mapper/fee16d98-9c18-4e7d-af70-afd7f3dfb2d9 /mnt/resource ext4 rw,relatime,data=ordered 0 0 /dev/mapper/vg0-lv0 /data ext4 rw,relatime,discard,data=ordered 0 0 """ with mock.patch("__builtin__.open", mock.mock_open(read_data=proc_mounts_output)): self.assertFalse(self.cutil.is_unsupported_mount_scheme()) # Skip LVM OS validation when OS volume is not being targeted def test_skip_lvm_os_check_if_data_only_enable(self): # skip lvm detection if data only self.cutil.validate_lvm_os({CommonVariables.VolumeTypeKey: "DATA", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryption}) def test_skip_lvm_os_check_if_data_only_ef(self): # skip lvm detection if data only self.cutil.validate_lvm_os({CommonVariables.VolumeTypeKey: "DATA", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryptionFormat}) def test_skip_lvm_os_check_if_data_only_efa(self): # skip lvm detection if data only self.cutil.validate_lvm_os({CommonVariables.VolumeTypeKey: "DATA", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryptionFormatAll}) def test_skip_lvm_os_check_if_data_only_disable(self): # skip lvm detection if data only self.cutil.validate_lvm_os({CommonVariables.VolumeTypeKey: "DATA", 
CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.DisableEncryption}) def test_skip_lvm_os_check_if_query(self): # skip lvm detection if query status operation is invoked without volume type self.cutil.validate_lvm_os({CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.QueryEncryptionStatus}) def test_skip_lvm_no_encryption_operation(self): # skip lvm detection if no encryption operation self.cutil.validate_lvm_os({CommonVariables.VolumeTypeKey: "ALL"}) def test_skip_lvm_no_volume_type(self): # skip lvm detection if no volume type specified self.cutil.validate_lvm_os({CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryptionFormatAll}) @mock.patch("os.system", return_value=-1) def test_no_lvm_no_config(self, os_system): # simulate no LVM OS, no config self.cutil.validate_lvm_os({}) @mock.patch("os.system", return_value=0) def test_lvm_no_config(self, os_system): # simulate valid LVM OS, no config self.cutil.validate_lvm_os({}) @mock.patch("os.system", side_effect=[0, -1]) def test_invalid_lvm_no_config(self, os_system): # simulate invalid LVM naming scheme, but no config setting to encrypt OS self.cutil.validate_lvm_os({}) @mock.patch("os.system", return_value=-1) def test_lvm_os_lvm_absent(self, os_system): # using patched return value of -1, simulate no LVM OS self.cutil.validate_lvm_os({CommonVariables.VolumeTypeKey: "ALL", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryption}) @mock.patch("os.system", return_value=0) def test_lvm_os_valid(self, os_system): # simulate a valid LVM OS and a valid naming scheme by always returning 0 self.cutil.validate_lvm_os({CommonVariables.VolumeTypeKey: "ALL", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryption}) @mock.patch("os.system", side_effect=[0, -1]) def test_lvm_os_lv_missing_expected_name(self, os_system): # using patched side effects, first simulate LVM OS present, then simulate not finding the 
expected LV name self.assertRaises(Exception, self.cutil.validate_lvm_os, {CommonVariables.VolumeTypeKey: "ALL", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryption}) @mock.patch("main.CommandExecutor.CommandExecutor.Execute", return_value=0) def test_vfat(self, os_system): # simulate call to modprobe vfat that succeeds and returns cleanly from execute self.cutil.validate_vfat() @mock.patch("main.CommandExecutor.CommandExecutor.Execute", side_effect=Exception("Test")) def test_no_vfat(self, os_system): # simulate call to modprobe vfat that fails and raises exception from execute self.assertRaises(Exception, self.cutil.validate_vfat) def test_validate_aad(self): # positive tests test_settings = {} test_settings[CommonVariables.AADClientIDKey] = "00000000-0000-0000-0000-000000000000" test_settings[CommonVariables.EncryptionEncryptionOperationKey] = CommonVariables.EnableEncryption self.cutil.validate_aad(test_settings) test_settings = {} test_settings[CommonVariables.AADClientIDKey] = "00000000-0000-aaaa-0000-000000000000" test_settings[CommonVariables.EncryptionEncryptionOperationKey] = CommonVariables.EnableEncryptionFormat self.cutil.validate_aad(test_settings) test_settings = {} test_settings[CommonVariables.AADClientIDKey] = "00000000-0000-AAAA-0000-000000000000" test_settings[CommonVariables.EncryptionEncryptionOperationKey] = CommonVariables.EnableEncryptionFormatAll self.cutil.validate_aad(test_settings) test_settings = {} test_settings[CommonVariables.EncryptionEncryptionOperationKey] = CommonVariables.DisableEncryption self.cutil.validate_aad(test_settings) test_settings = {} test_settings[CommonVariables.EncryptionEncryptionOperationKey] = CommonVariables.QueryEncryptionStatus self.cutil.validate_aad(test_settings) # negative tests # settings file that does not include AAD client ID field test_settings = {} test_settings[CommonVariables.EncryptionEncryptionOperationKey] = CommonVariables.EnableEncryption 
self.assertRaises(Exception, self.cutil.validate_aad, test_settings) # invalid characters in the client ID test_settings = {} test_settings[CommonVariables.AADClientIDKey] = "BORKED" test_settings[CommonVariables.EncryptionEncryptionOperationKey] = CommonVariables.EnableEncryption self.assertRaises(Exception, self.cutil.validate_aad, test_settings) # empty string test_settings = {} test_settings[CommonVariables.AADClientIDKey] = "" test_settings[CommonVariables.EncryptionEncryptionOperationKey] = CommonVariables.EnableEncryption self.assertRaises(Exception, self.cutil.validate_aad, test_settings) # unicode left and right double quotes (simulating a copy-paste error) test_settings = {} test_settings[CommonVariables.AADClientIDKey] = u'\u201c' + "00000000-0000-0000-0000-000000000000" + u'\u201d' test_settings[CommonVariables.EncryptionEncryptionOperationKey] = CommonVariables.EnableEncryption self.assertRaises(Exception, self.cutil.validate_aad, test_settings) @mock.patch('os.popen') def test_minimum_memory(self, os_popen): output = "6000000" os_popen.return_value = self.get_mock_filestream(output) self.assertRaises(Exception, self.cutil.validate_memory_os_encryption, { CommonVariables.VolumeTypeKey: "ALL", CommonVariables.KeyVaultURLKey: "https://vaultname.vault.azure.net/", CommonVariables.KeyEncryptionKeyURLKey: "https://vaultname.vault.azure.net/keys/keyname/ver", CommonVariables.KeyEncryptionAlgorithmKey: 'rsa-OAEP-25600', CommonVariables.AADClientIDKey: "00000000-0000-0000-0000-000000000000", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryptionFormatAll }, { "os": "NotEncrypted" }) try: self.cutil.validate_memory_os_encryption( { CommonVariables.VolumeTypeKey: "ALL", CommonVariables.KeyVaultURLKey: "https://vaultname.vault.azure.net/", CommonVariables.KeyEncryptionKeyURLKey: "https://vaultname.vault.azure.net/keys/keyname/ver", CommonVariables.KeyEncryptionAlgorithmKey: 'rsa-OAEP-25600', CommonVariables.AADClientIDKey: 
"00000000-0000-0000-0000-000000000000", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryptionFormatAll }, { "os": "Encrypted" }) except Exception: self.fail("validate_memory_os_encryption threw unexpected exception\nException message was:\n" + str(e)) try: output = "8000000" os_popen.return_value = self.get_mock_filestream(output) self.cutil.validate_memory_os_encryption( { CommonVariables.VolumeTypeKey: "ALL", CommonVariables.KeyVaultURLKey: "https://vaultname.vault.azure.net/", CommonVariables.KeyEncryptionKeyURLKey: "https://vaultname.vault.azure.net/keys/keyname/ver", CommonVariables.KeyEncryptionAlgorithmKey: 'rsa-OAEP-25600', CommonVariables.AADClientIDKey: "00000000-0000-0000-0000-000000000000", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryptionFormatAll }, { "os": "Encrypted" }) except Exception: self.fail("validate_memory_os_encryption threw unexpected exception\nException message was:\n" + str(e)) try: output = "8000000" os_popen.return_value = self.get_mock_filestream(output) self.cutil.validate_memory_os_encryption( { CommonVariables.VolumeTypeKey: "ALL", CommonVariables.KeyVaultURLKey: "https://vaultname.vault.azure.net/", CommonVariables.KeyEncryptionKeyURLKey: "https://vaultname.vault.azure.net/keys/keyname/ver", CommonVariables.KeyEncryptionAlgorithmKey: 'rsa-OAEP-25600', CommonVariables.AADClientIDKey: "00000000-0000-0000-0000-000000000000", CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryptionFormatAll }, { "os": "NotEncrypted" }) except Exception: self.fail("validate_memory_os_encryption threw unexpected exception\nException message was:\n" + str(e)) def test_supported_os(self): # test exception is raised for Ubuntu 14.04 kernel version self.assertRaises(Exception, self.cutil.is_supported_os, { CommonVariables.VolumeTypeKey: "ALL" }, MockDistroPatcher('Ubuntu', '14.04', '4.4'), {"os" : "NotEncrypted"}) # test exception is not raised for Ubuntu 14.04 kernel 
version 4.15 try: self.cutil.is_supported_os( { CommonVariables.VolumeTypeKey: "ALL" }, MockDistroPatcher('Ubuntu', '14.04', '4.15'), {"os" : "NotEncrypted"}) except Exception as e: self.fail("is_unsupported_os threw unexpected exception.\nException message was:\n" + str(e)) # test exception is not raised for already encrypted OS volume try: self.cutil.is_supported_os( { CommonVariables.VolumeTypeKey: "ALL" }, MockDistroPatcher('Ubuntu', '14.04', '4.4'), {"os" : "Encrypted"}) except Exception as e: self.fail("is_unsupported_os threw unexpected exception.\nException message was:\n" + str(e)) # test exception is raised for unsupported OS self.assertRaises(Exception, self.cutil.is_supported_os, { CommonVariables.VolumeTypeKey: "ALL" }, MockDistroPatcher('Ubuntu', '12.04', ''), {"os" : "NotEncrypted"}) self.assertRaises(Exception, self.cutil.is_supported_os, { CommonVariables.VolumeTypeKey: "ALL" }, MockDistroPatcher('redhat', '6.7', ''), {"os" : "NotEncrypted"}) self.assertRaises(Exception, self.cutil.is_supported_os, { CommonVariables.VolumeTypeKey: "ALL" }, MockDistroPatcher('centos', '7.9', ''), {"os" : "NotEncrypted"}) # test exception is not raised for supported OS try: self.cutil.is_supported_os( { CommonVariables.VolumeTypeKey: "ALL" }, MockDistroPatcher('Ubuntu', '18.04', ''), {"os" : "NotEncrypted"}) except Exception as e: self.fail("is_unsupported_os threw unexpected exception.\nException message was:\n" + str(e)) try: self.cutil.is_supported_os( { CommonVariables.VolumeTypeKey: "ALL" }, MockDistroPatcher('centos', '7.2.1511', ''), {"os" : "NotEncrypted"}) except Exception as e: self.fail("is_unsupported_os threw unexpected exception.\nException message was:\n" + str(e)) # test exception is not raised for DATA volume try: self.cutil.is_supported_os( { CommonVariables.VolumeTypeKey: "DATA" }, MockDistroPatcher('SuSE', '12.4', ''), {"os" : "NotEncrypted"}) except Exception as e: self.fail("is_unsupported_os threw unexpected exception.\nException message 
was:\n" + str(e)) ================================================ FILE: VMEncryption/test/test_command_executor.py ================================================ import unittest from main.CommandExecutor import CommandExecutor from console_logger import ConsoleLogger class TestCommandExecutor(unittest.TestCase): """ unit tests for functions in the CommandExecutor module """ def setUp(self): self.logger = ConsoleLogger() self.cmd_executor = CommandExecutor(self.logger) def test_command_timeout(self): return_code = self.cmd_executor.Execute('sleep 15', timeout=10) self.assertEqual(return_code, -9, msg="The command didn't timeout as expected") def test_command_no_timeout(self): return_code = self.cmd_executor.Execute('sleep 5', timeout=10) self.assertEqual(return_code, 0, msg="The command should have completed successfully") ================================================ FILE: VMEncryption/test/test_disk_util.py ================================================ import unittest import mock from main.Common import CryptItem from main.EncryptionEnvironment import EncryptionEnvironment from main.DiskUtil import DiskUtil from console_logger import ConsoleLogger from test_utils import MockDistroPatcher class TestDiskUtil(unittest.TestCase): """ unit tests for functions in the CryptMountConfig module """ def setUp(self): self.logger = ConsoleLogger() self.disk_util = DiskUtil(None, MockDistroPatcher('Ubuntu', '14.04', '4.15'), self.logger, EncryptionEnvironment(None, self.logger)) def _mock_open_with_read_data_dict(self, open_mock, read_data_dict): open_mock.content_dict = read_data_dict def _open_side_effect(filename, mode, *args, **kwargs): read_data = open_mock.content_dict.get(filename) mock_obj = mock.mock_open(read_data=read_data) handle = mock_obj.return_value def write_handle(data, *args, **kwargs): if 'a' in mode: open_mock.content_dict[filename] += data else: open_mock.content_dict[filename] = data def write_lines_handle(data, *args, **kwargs): if 'a' in mode: 
open_mock.content_dict[filename] += "".join(data) else: open_mock.content_dict[filename] = "".join(data) handle.write.side_effect = write_handle handle.writelines.side_effect = write_lines_handle return handle open_mock.side_effect = _open_side_effect def _create_expected_crypt_item(self, mapper_name=None, dev_path=None, uses_cleartext_key=None, luks_header_path=None, mount_point=None, file_system=None, current_luks_slot=None): crypt_item = CryptItem() crypt_item.mapper_name = mapper_name crypt_item.dev_path = dev_path crypt_item.uses_cleartext_key = uses_cleartext_key crypt_item.luks_header_path = luks_header_path crypt_item.mount_point = mount_point crypt_item.file_system = file_system crypt_item.current_luks_slot = current_luks_slot return crypt_item def test_parse_crypttab_line(self): # empty line line = "" crypt_item = self.disk_util.parse_crypttab_line(line) self.assertEquals(None, crypt_item) # line with not enough entries line = "mapper_name dev_path" crypt_item = self.disk_util.parse_crypttab_line(line) self.assertEquals(None, crypt_item) # commented out line line = "# mapper_name dev_path" crypt_item = self.disk_util.parse_crypttab_line(line) self.assertEquals(None, crypt_item) # An unfamiliar key_file_path implies that we shouln't be processing this crypttab line line = "mapper_name /dev/dev_path /non_managed_key_file_path" crypt_item = self.disk_util.parse_crypttab_line(line) self.assertEquals(None, crypt_item) # a bare bones crypttab line line = "mapper_name /dev/dev_path /mnt/azure_bek_disk/LinuxPassPhraseFileName luks" expected_crypt_item = self._create_expected_crypt_item(mapper_name="mapper_name", dev_path="/dev/dev_path") crypt_item = self.disk_util.parse_crypttab_line(line) self.assertEquals(str(expected_crypt_item), str(crypt_item)) # a line that implies a cleartext key line = "mapper_name /dev/dev_path /var/lib/azure_disk_encryption_config/cleartext_key_mapper_name luks" expected_crypt_item = 
self._create_expected_crypt_item(mapper_name="mapper_name", dev_path="/dev/dev_path", uses_cleartext_key=True) crypt_item = self.disk_util.parse_crypttab_line(line) self.assertEquals(str(expected_crypt_item), str(crypt_item)) # a line that implies a luks header line = "mapper_name /dev/dev_path /var/lib/azure_disk_encryption_config/cleartext_key_mapper_name luks,header=headerfile" expected_crypt_item = self._create_expected_crypt_item(mapper_name="mapper_name", dev_path="/dev/dev_path", uses_cleartext_key=True, luks_header_path="headerfile") crypt_item = self.disk_util.parse_crypttab_line(line) self.assertEquals(str(expected_crypt_item), str(crypt_item)) @mock.patch('__builtin__.open') @mock.patch('os.path.exists', return_value=True) def test_should_use_azure_crypt_mount(self, exists_mock, open_mock): # if the acm file exists and has only a root disk acm_contents = """ osencrypt /dev/dev_path None / ext4 False 0 """ mock.mock_open(open_mock, acm_contents) self.assertFalse(self.disk_util.should_use_azure_crypt_mount()) # if the acm file exists and has a data disk acm_contents = """ mapper_name /dev/dev_path None /mnt/point ext4 False 0 mapper_name2 /dev/dev_path2 None /mnt/point2 ext4 False 0 """ mock.mock_open(open_mock, acm_contents) self.assertTrue(self.disk_util.should_use_azure_crypt_mount()) # empty file mock.mock_open(open_mock, "") self.assertFalse(self.disk_util.should_use_azure_crypt_mount()) # no file exists_mock.return_value = False open_mock.reset_mock() self.assertFalse(self.disk_util.should_use_azure_crypt_mount()) open_mock.assert_not_called() @mock.patch('os.path.exists', return_value=True) @mock.patch('main.DiskUtil.ProcessCommunicator') @mock.patch('main.CommandExecutor.CommandExecutor', autospec=True) @mock.patch('__builtin__.open') @mock.patch('main.DiskUtil.DiskUtil.should_use_azure_crypt_mount') @mock.patch('main.DiskUtil.DiskUtil.get_encryption_status') @mock.patch('main.DiskUtil.DiskUtil.get_mount_items') def test_get_crypt_items(self, 
get_mount_items_mock, get_enc_status_mock, use_acm_mock, open_mock, ce_mock, pc_mock, exists_mock): self.disk_util.command_executor = ce_mock use_acm_mock.return_value = True # Use the Azure_Crypt_Mount file get_enc_status_mock.return_value = "{\"os\" : \"Encrypted\"}" acm_contents = """ osencrypt /dev/dev_path None / ext4 True 0 """ mock.mock_open(open_mock, acm_contents) crypt_items = self.disk_util.get_crypt_items() self.assertListEqual([self._create_expected_crypt_item(mapper_name="osencrypt", dev_path="/dev/dev_path", uses_cleartext_key=True, mount_point="/", file_system="ext4", current_luks_slot=0)], crypt_items) ce_mock.ExecuteInBash.return_value = 0 # The grep on cryptsetup succeeds pc_mock.return_value.stdout = "osencrypt /dev/dev_path" # The grep find this line in there mock.mock_open(open_mock, "") # No content in the azure crypt mount file get_mount_items_mock.return_value = [{"src": "/dev/mapper/osencrypt", "dest": "/", "fs": "ext4"}] exists_mock.return_value = False # No luksheader file found crypt_items = self.disk_util.get_crypt_items() self.assertListEqual([self._create_expected_crypt_item(mapper_name="osencrypt", dev_path="/dev/dev_path", mount_point="/", file_system="ext4")], crypt_items) use_acm_mock.return_value = False # Now, use the /etc/crypttab file exists_mock.return_value = True # Crypttab file found self._mock_open_with_read_data_dict(open_mock, {"/etc/fstab": "/dev/mapper/osencrypt / ext4 defaults,nofail 0 0", "/etc/crypttab": "osencrypt /dev/sda1 /mnt/azure_bek_disk/LinuxPassPhraseFileName luks,discard"}) crypt_items = self.disk_util.get_crypt_items() self.assertListEqual([self._create_expected_crypt_item(mapper_name="osencrypt", dev_path="/dev/sda1", file_system=None, mount_point="/")], crypt_items) # if there was no crypttab entry for osencrypt exists_mock.side_effect = [True, False] # Crypttab file found but luksheader not found self._mock_open_with_read_data_dict(open_mock, {"/etc/fstab": "/dev/mapper/osencrypt / ext4 
defaults,nofail 0 0", "/etc/crypttab": ""}) ce_mock.ExecuteInBash.return_value = 0 # The grep on cryptsetup succeeds pc_mock.return_value.stdout = "osencrypt /dev/sda1" # The grep find this line in there crypt_items = self.disk_util.get_crypt_items() self.assertListEqual([self._create_expected_crypt_item(mapper_name="osencrypt", dev_path="/dev/sda1", file_system="ext4", mount_point="/")], crypt_items) exists_mock.side_effect = None # Crypttab file found exists_mock.return_value = True # Crypttab file found get_enc_status_mock.return_value = "{\"os\" : \"NotEncrypted\"}" self._mock_open_with_read_data_dict(open_mock, {"/etc/fstab": "", "/etc/crypttab": ""}) crypt_items = self.disk_util.get_crypt_items() self.assertListEqual([], crypt_items) self._mock_open_with_read_data_dict(open_mock, {"/etc/fstab": "/dev/mapper/encrypteddatadisk /mnt/datadisk auto defaults,nofail 0 0", "/etc/crypttab": "encrypteddatadisk /dev/disk/azure/scsi1/lun0 /someplainfile luks"}) crypt_items = self.disk_util.get_crypt_items() self.assertListEqual([], crypt_items) self._mock_open_with_read_data_dict(open_mock, {"/etc/fstab": "/dev/mapper/encrypteddatadisk /mnt/datadisk auto defaults,nofail 0 0", "/etc/crypttab": "encrypteddatadisk /dev/disk/azure/scsi1/lun0 /mnt/azure_bek_disk/LinuxPassPhraseFileName luks,discard,header=/headerfile"}) crypt_items = self.disk_util.get_crypt_items() self.assertListEqual([self._create_expected_crypt_item(mapper_name="encrypteddatadisk", dev_path="/dev/disk/azure/scsi1/lun0", file_system=None, luks_header_path="/headerfile", mount_point="/mnt/datadisk")], crypt_items) @mock.patch('shutil.copy2', return_value=True) @mock.patch('os.rename', return_value=True) @mock.patch('os.path.exists', return_value=True) @mock.patch('__builtin__.open') @mock.patch('main.DiskUtil.DiskUtil.should_use_azure_crypt_mount', return_value=True) @mock.patch('main.DiskUtil.DiskUtil.get_encryption_status') def test_migrate_crypt_items(self, get_enc_status_mock, use_acm_mock, open_mock, 
exists_mock, rename_mock, shutil_mock): def rename_side_effect(name1, name2): use_acm_mock.return_value = False return True rename_mock.side_effect = rename_side_effect get_enc_status_mock.return_value = "{\"os\" : \"NotEncrypted\"}" # Test 1: migrate an entry self._mock_open_with_read_data_dict(open_mock, {"/var/lib/azure_disk_encryption_config/azure_crypt_mount": "mapper_name /dev/dev_path None /mnt/point ext4 False 0", "/etc/fstab.azure.backup": "/dev/dev_path /mnt/point ext4 defaults,nofail 0 0", "/etc/fstab": "", "/etc/crypttab": ""}) self.disk_util.migrate_crypt_items("/test_passphrase_path") self.assertTrue("/dev/mapper/mapper_name /mnt/point" in open_mock.content_dict["/etc/fstab"]) self.assertTrue("mapper_name /dev/dev_path /test_passphrase_path" in open_mock.content_dict["/etc/crypttab"]) # Test 2: migrate no entry use_acm_mock.return_value = True self._mock_open_with_read_data_dict(open_mock, {"/var/lib/azure_disk_encryption_config/azure_crypt_mount": "", "/etc/fstab.azure.backup": "", "/etc/fstab": "", "/etc/crypttab": ""}) self.disk_util.migrate_crypt_items("/test_passphrase_path") self.assertTrue("" == open_mock.content_dict["/etc/fstab"].strip()) self.assertTrue("" == open_mock.content_dict["/etc/crypttab"].strip()) # Test 3: skip migrating the OS entry use_acm_mock.return_value = True self._mock_open_with_read_data_dict(open_mock, {"/var/lib/azure_disk_encryption_config/azure_crypt_mount": "osencrypt /dev/dev_path None / ext4 False 0", "/etc/fstab.azure.backup": "/dev/dev_path / ext4 defaults 0 0", "/etc/fstab": "", "/etc/crypttab": ""}) self.disk_util.migrate_crypt_items("/test_passphrase_path") self.assertTrue("" == open_mock.content_dict["/etc/fstab"].strip()) self.assertTrue("" == open_mock.content_dict["/etc/crypttab"].strip()) # Test 4: migrate many entries use_acm_mock.return_value = True acm_contents = """ mapper_name /dev/dev_path None /mnt/point ext4 False 0 mapper_name2 /dev/dev_path2 None /mnt/point2 ext4 False 0 """ 
fstab_backup_contents = """ /dev/dev_path /mnt/point ext4 defaults,nofail 0 0 /dev/dev_path2 /mnt/point2 ext4 defaults,nofail 0 0 """ self._mock_open_with_read_data_dict(open_mock, {"/var/lib/azure_disk_encryption_config/azure_crypt_mount": acm_contents, "/etc/fstab.azure.backup": fstab_backup_contents, "/etc/fstab": "", "/etc/crypttab": ""}) self.disk_util.migrate_crypt_items("/test_passphrase_path") self.assertTrue("/dev/mapper/mapper_name /mnt/point ext4 defaults,nofail 0 0\n" in open_mock.content_dict["/etc/fstab"]) self.assertTrue("\n/dev/mapper/mapper_name2 /mnt/point2 ext4 defaults,nofail 0 0" in open_mock.content_dict["/etc/fstab"]) self.assertTrue("\nmapper_name /dev/dev_path /test_passphrase_path" in open_mock.content_dict["/etc/crypttab"]) self.assertTrue("\nmapper_name2 /dev/dev_path2 /test_passphrase_path" in open_mock.content_dict["/etc/crypttab"]) ================================================ FILE: VMEncryption/test/test_handler_util.py ================================================ #!/usr/bin/env python # # ********************************************************* # Copyright (c) Microsoft. All rights reserved. # # Apache 2.0 License # # You may obtain a copy of the License at # http:#www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
#
# *********************************************************

""" Unit tests for the HandlerUtil module """

import unittest
import os
import console_logger
import patch
import glob
from Utils import HandlerUtil
from tempfile import mkstemp


class TestHandlerUtil(unittest.TestCase):
    def setUp(self):
        self.logger = console_logger.ConsoleLogger()
        self.distro_patcher = patch.GetDistroPatcher(self.logger)
        self.hutil = HandlerUtil.HandlerUtility(self.logger.log, self.logger.error, "AzureDiskEncryptionForLinux")
        self.hutil.patching = self.distro_patcher
        # invoke unit test from within main for setup (to avoid having to change dependencies)
        # then move cwd to parent to emulate calling convention of guest agent
        if os.getcwd().endswith('main'):
            os.chdir(os.path.dirname(os.getcwd()))
        else:
            self.logger.log(os.getcwd())

    def test_parse_config_sp(self):
        # test 0.1 sp config syntax
        test_sp = '{"runtimeSettings": [{"handlerSettings": {"protectedSettings": null, "publicSettings": {"VolumeType": "OS", "KeyEncryptionKeyURL": "", "KekVaultResourceId": "", "KeyEncryptionAlgorithm": "RSA-OAEP", "KeyVaultURL": "https://testkv.vault.azure.net/", "KeyVaultResourceId": "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/testrg/providers/Microsoft.KeyVault/vaults/testkv", "EncryptionOperation": "EnableEncryption"}, "protectedSettingsCertThumbprint": null} }]}'
        self.assertIsNotNone(self.hutil._parse_config(test_sp))

    def test_parse_config_dp_enable(self):
        # test 1.1 dp config syntax
        test_dp = '{"runtimeSettings": [{"handlerSettings": {"protectedSettings": "MIIB8AYJKoZIhvcNAQcDoIIB4TCCAd0CAQAxggFpMIIBZQIBADBNMDkxNzA1BgoJkiaJk/IsZAEZFidXaW5kb3dzIEF6dXJlIENSUCBDZXJ0aWZpY2F0ZSBHZW5lcmF0b3ICEG5XyHr6J9qxRLVe/RzaobIwDQYJKoZIhvcNAQEBBQAEggEASDt5QPp0i8R408Ho2JNs0gEAKmjo17qg7Wk+Ihy5I3krCHY4pGGzWAXafvZ3Y1rLh7m/k1+uwK94o3taI27NEvz4YAbCkzLdgiNZx3yZdn5KkRzSbakztnf1a/MTEXY0dYjEjK9ZN5H5XiS8OLhpXaOgayaz1ZFS5MnOufBFXWuL2qeYK/txfBXIJujBHru80b+YahwnHU7/nislCslYVxENn9Jp9VpKGEcCeDFo/KKi0BTbpkxPj3OScNcsPuSRUP9xgT/b96bARJKeLjrxHQa398gzp291OlDYTr4sKBPqGNk8wER0aSpOm6igE857YAc0tShKQhGI14jcEHUu2jBrBgkqhkiG9w0BBwEwFAYIKoZIhvcNAwcECPpjFE+mGCN7gEj0rWo00NbAoQ6VhMnzdnZ3MnKOCjdWr/NTOdTgHMXU732rfDL89dMHLmUnBHq4SyTqIAi0M6sPEJ38anxx/msIQl15/w8qmL8=", "publicSettings": {"AADClientID": "00000000-0000-0000-0000-000000000000", "VolumeType": "DATA", "KeyEncryptionKeyURL": "https://testkv.vault.azure.net/keys/adelpackek/a022ed2b1eba4befb0dc9dc07bf33578", "KeyEncryptionAlgorithm": "RSA-OAEP", "KeyVaultURL": "https://testkv.vault.azure.net", "SequenceVersion": "eec80fc4-e0a2-434e-9007-974a150c3407", "AADClientCertThumbprint": null, "EncryptionOperation": "EnableEncryption"}, "protectedSettingsCertThumbprint": "45E4EC25EECAD03EC81F8177CEF16CD3CAF6297A"} }]}'
        self.assertIsNotNone(self.hutil._parse_config(test_dp))

    def test_parse_config_dp_query(self):
        test_dpq = '{"runtimeSettings": [{"handlerSettings": {"protectedSettings": "MIIBsAYJKoZIhvcNAQcDoIIBoTCCAZ0CAQAxggFpMIIBZQIBADBNMDkxNzA1BgoJkiaJk/IsZAEZFidXaW5kb3dzIEF6dXJlIENSUCBDZXJ0aWZpY2F0ZSBHZW5lcmF0b3ICEG5XyHr6J9qxRLVe/RzaobIwDQYJKoZIhvcNAQEBBQAEggEAE92LccPctK0h52F+WOjKPWat5O3nxjQpsLKquMtwiKsc5BMot8dLEAE1h7V7SJJ8kiGRLS232mwvVbOA+nOs3l1lCUNDnckbzvvuu/rgz+if1sHvYIn0Xd/kXHSMNm9loh9lTLagGblEFxGupcBcsAEptcjL0f7zUG1NrlnKPVDGceOw7I3dQK6X8rPrMHJ8m6wiHpTvjpa/xmG0mrVyOGjJv7cEDnJ0A8pvRHUrZGGuqi/4WeGPGDKQzmVc6O5oGFfke3bAOd9GJxFWhLwZ1lb1XrKNImVDT2vnWWFiy2lKDwUvKSdqRpaqRNr6f7tZcDWiB+v+vZ6V4GC33kT0mDArBgkqhkiG9w0BBwEwFAYIKoZIhvcNAwcECJeXx+KpPZqdgAgiUsAz+Acz6A==", "publicSettings": {"SequenceVersion": "3838692e-4827-4175-8286-86828d199f85", "EncryptionOperation": "QueryEncryptionStatus"}, "protectedSettingsCertThumbprint": "45E4EC25EECAD03EC81F8177CEF16CD3CAF6297A"} }]}'
        self.assertIsNotNone(self.hutil._parse_config(test_dpq))

    def test_do_parse_context_install(self):
        self.assertIsNotNone(self.hutil.do_parse_context('Install'))

    def test_do_parse_context_enable(self):
        self.assertIsNotNone(self.hutil.do_parse_context('Enable'))

    def test_do_parse_context_enable_encryption(self):
        self.assertIsNotNone(self.hutil.do_parse_context('EnableEncryption'))

    def test_do_parse_context_disable(self):
        self.assertIsNotNone(self.hutil.do_parse_context('Disable'))

    def test_do_parse_context_disable_nosettings(self):
        # simulate missing settings file by adding .bak extension
        config_dir = os.path.join(os.getcwd(), 'config')
        settings_files = glob.glob(os.path.join(config_dir, '*.settings'))
        for settings_file in settings_files:
            os.rename(settings_file, settings_file + '.bak')
        try:
            # test to simulate disable when no settings are available
            self.hutil.do_parse_context('Disable')
            self.hutil.archive_old_configs()
        finally:
            # restore settings files back to original name
            for settings_file in settings_files:
                os.rename(settings_file + '.bak', settings_file)

    def test_do_parse_context_uninstall(self):
        self.assertIsNotNone(self.hutil.do_parse_context('Uninstall'))

    def test_do_parse_context_disable_encryption(self):
        self.assertIsNotNone(self.hutil.do_parse_context('DisableEncryption'))

    def test_do_parse_context_update_encryption_settings(self):
        self.assertIsNotNone(self.hutil.do_parse_context('UpdateEncryptionSettings'))

    def test_do_parse_context_update(self):
        self.assertIsNotNone(self.hutil.do_parse_context('Update'))

    def test_do_parse_context_executing(self):
        self.assertIsNotNone(self.hutil.do_parse_context('Executing'))

    def test_try_parse_context(self):
        self.assertIsNotNone(self.hutil.try_parse_context())

    def test_is_valid_nonquery_true(self):
        nonquery_settings = '{"runtimeSettings": [{"handlerSettings": {"protectedSettingsCertThumbprint": null, "publicSettings": {"VolumeType": "DATA", "KekVaultResourceId": "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/testrg/providers/Microsoft.KeyVault/vaults/testkv", "EncryptionOperation": "EnableEncryption", "KeyEncryptionAlgorithm": "RSA-OAEP", "KeyEncryptionKeyURL": "https://testkv.vault.azure.net/keys/testkek/805291e00028474a87e302ce507ed049", "KeyVaultURL": "https://testkv.vault.azure.net", "KeyVaultResourceId": "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/testrg/providers/Microsoft.KeyVault/vaults/testkv", "SequenceVersion": "c8608bb5-df18-43a7-9f0e-dbe09a57fd0b"}, "protectedSettings": null} }]}'
        # use a temp file path for this test, not the config folder
        tmp_fd, tmp_path = mkstemp(text=True)
        with os.fdopen(tmp_fd, 'w') as f:
            f.write(nonquery_settings)
        test_result = self.hutil.is_valid_nonquery(tmp_path)
        os.remove(tmp_path)
        # assert true, this is not a QueryEncryptionStatus operation
        self.assertTrue(test_result)

    def test_is_valid_nonquery_false(self):
        query_settings = '{"runtimeSettings": [{"handlerSettings": {"protectedSettings": "MIIBsAYJKoZIhvcNAQcDoIIBoTCCAZ0CAQAxggFpMIIBZQIBADBNMDkxNzA1BgoJkiaJk/IsZAEZFidXaW5kb3dzIEF6dXJlIENSUCBDZXJ0aWZpY2F0ZSBHZW5lcmF0b3ICEG5XyHr6J9qxRLVe/RzaobIwDQYJKoZIhvcNAQEBBQAEggEAE92LccPctK0h52F+WOjKPWat5O3nxjQpsLKquMtwiKsc5BMot8dLEAE1h7V7SJJ8kiGRLS232mwvVbOA+nOs3l1lCUNDnckbzvvuu/rgz+if1sHvYIn0Xd/kXHSMNm9loh9lTLagGblEFxGupcBcsAEptcjL0f7zUG1NrlnKPVDGceOw7I3dQK6X8rPrMHJ8m6wiHpTvjpa/xmG0mrVyOGjJv7cEDnJ0A8pvRHUrZGGuqi/4WeGPGDKQzmVc6O5oGFfke3bAOd9GJxFWhLwZ1lb1XrKNImVDT2vnWWFiy2lKDwUvKSdqRpaqRNr6f7tZcDWiB+v+vZ6V4GC33kT0mDArBgkqhkiG9w0BBwEwFAYIKoZIhvcNAwcECJeXx+KpPZqdgAgiUsAz+Acz6A==", "publicSettings": {"SequenceVersion": "3838692e-4827-4175-8286-86828d199f85", "EncryptionOperation": "QueryEncryptionStatus"}, "protectedSettingsCertThumbprint": "45E4EC25EECAD03EC81F8177CEF16CD3CAF6297A"} }]}'
        # use a temp file path for this test, not the
config folder tmp_fd, tmp_path = mkstemp(text=True) with os.fdopen(tmp_fd,'w') as f: f.write(query_settings) test_result = self.hutil.is_valid_nonquery(tmp_path) os.remove(tmp_path) # assert false, this is a QueryEncryptionStatus operation self.assertFalse(test_result) def test_get_last_nonquery_config_path(self): self.assertIsNotNone(self.hutil.do_parse_context('Enable')) self.assertIsNotNone(self.hutil.get_last_nonquery_config_path()) def test_get_last_config(self): self.assertIsNotNone(self.hutil.do_parse_context('Enable')) self.assertIsNotNone(self.hutil.get_last_config(nonquery=False)) def test_get_last_nonquery_config(self): self.assertIsNotNone(self.hutil.do_parse_context('Enable')) config = self.hutil.get_last_config(nonquery=True) self.assertIsNotNone(config) def test_get_handler_env(self): self.assertIsNotNone(self.hutil.get_handler_env()) def test_archive_old_configs(self): self.assertIsNotNone(self.hutil.do_parse_context('Enable')) self.hutil.archive_old_configs() def test_archive_old_configs_overwrite_lnq(self): self.assertIsNotNone(self.hutil.do_parse_context('Enable')) # this test ensures that the archive_old_configs method will properly overwrite an existing lnq.settings file # with any newer non query settings file that might exist on the system # stuff a bogus lnq.settings file in the archived settings folder # and backdate the file time to older than current settings prior to testing tmpstr = 'test_archive_old_configs_overwrite_lnq : the contents of this file are intended to be overwritten and never used' if not os.path.exists(self.hutil.config_archive_folder): os.makedirs(self.hutil.config_archive_folder) dest = os.path.join(self.hutil.config_archive_folder, 'lnq.settings') with open(dest,'w') as f: f.write(tmpstr) # backdate os.utime(dest,(0,0)) # run the test self.hutil.archive_old_configs() # ensure the new lnq.settings file in the folder has the expected content ================================================ FILE: 
VMEncryption/test/test_resource_disk_util.py ================================================ #!/usr/bin/env python # # ********************************************************* # Copyright (c) Microsoft. All rights reserved. # # Apache 2.0 License # # You may obtain a copy of the License at # http:#www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. # # ********************************************************* """ Unit tests for the ResourceDiskUtil module """ import unittest import mock from main.ResourceDiskUtil import ResourceDiskUtil from main.DiskUtil import DiskUtil from main.Common import CommonVariables from console_logger import ConsoleLogger class TestResourceDiskUtil(unittest.TestCase): def setUp(self): self.logger = ConsoleLogger() self.mock_disk_util = mock.create_autospec(DiskUtil) self.mock_passhprase_filename = "mock_passphrase_filename" mock_public_settings = {} self.resource_disk = ResourceDiskUtil(self.logger, self.mock_disk_util, self.mock_passhprase_filename, mock_public_settings, ["ubuntu", "16"]) def _test_resource_disk_partition_dependant_method(self, method, mock_partition_exists, mock_execute): """ A lot of methods have a common pattern [ if (partition_exists()): return execute_something() else return False ] This is a generic method which accepts the mock objects and the method pointer and tests the method. 
NOTE: make sure its a fresh instance of the mocked Executor (mock_execute) """ # case 1: partition doesn't exist mock_partition_exists.return_value = False self.assertEqual(method(), False) mock_execute.assert_not_called() # case 2: partition exists but call fails mock_partition_exists.return_value = True mock_execute.return_value = -1 # simulate that the internal execute call failed. self.assertEqual(method(), False) # case 3: partition exists and call succeeds mock_partition_exists.return_value = True mock_execute.return_value = CommonVariables.process_success # simulate that the internal execute call succeeded self.assertEqual(method(), True) @mock.patch('main.CommandExecutor.CommandExecutor.Execute') @mock.patch('main.ResourceDiskUtil.ResourceDiskUtil._resource_disk_partition_exists') def test_is_luks_device(self, mock_partition_exists, mock_execute): self._test_resource_disk_partition_dependant_method(self.resource_disk._is_luks_device, mock_partition_exists, mock_execute) @mock.patch('main.CommandExecutor.CommandExecutor.Execute') def test_configure_waagent(self, mock_execute): mock_execute.side_effect = [-1, 0, 0] self.assertEqual(self.resource_disk._configure_waagent(), False) mock_execute.assert_called_once() self.assertEqual(self.resource_disk._configure_waagent(), True) def test_is_plain_mounted(self): self.resource_disk.disk_util.get_mount_items.return_value = [] self.assertEqual(self.resource_disk._is_plain_mounted(), False) self.resource_disk.disk_util.get_mount_items.return_value = [{"src": "/dev/dm-0", "dest": "/mnt/resource"}] self.assertEqual(self.resource_disk._is_plain_mounted(), False) self.resource_disk.disk_util.get_mount_items.return_value = [{"src": "/dev/mapper/something", "dest": "/mnt/"}] self.assertEqual(self.resource_disk._is_plain_mounted(), False) self.resource_disk.disk_util.get_mount_items.return_value = [{"src": "/dev/sdcx", "dest": "/mnt/resource"}] self.assertEqual(self.resource_disk._is_plain_mounted(), True) 
self.resource_disk.disk_util.get_mount_items.return_value = [{"src": "/dev/sdb2", "dest": "/mnt/resource"}] self.assertEqual(self.resource_disk._is_plain_mounted(), True) def test_is_crypt_mounted(self): self.resource_disk.disk_util.get_mount_items.return_value = [] self.assertEqual(self.resource_disk._is_crypt_mounted(), False) self.resource_disk.disk_util.get_mount_items.return_value = [{"src": "/dev/dm-0", "dest": "/mnt/resource"}] self.assertEqual(self.resource_disk._is_crypt_mounted(), True) self.resource_disk.disk_util.get_mount_items.return_value = [{"src": "/dev/mapper/something", "dest": "/mnt/"}] self.assertEqual(self.resource_disk._is_crypt_mounted(), False) self.resource_disk.disk_util.get_mount_items.return_value = [{"src": "/dev/mapper/something", "dest": "/mnt/resource"}] self.assertEqual(self.resource_disk._is_crypt_mounted(), True) self.resource_disk.disk_util.get_mount_items.return_value = [{"src": "/dev/sdcx", "dest": "/mnt/resource"}] self.assertEqual(self.resource_disk._is_crypt_mounted(), False) self.resource_disk.disk_util.get_mount_items.return_value = [{"src": "/dev/sdb2", "dest": "/mnt/resource"}] self.assertEqual(self.resource_disk._is_crypt_mounted(), False) @mock.patch('main.ResourceDiskUtil.ResourceDiskUtil.add_resource_disk_to_crypttab') @mock.patch('main.ResourceDiskUtil.ResourceDiskUtil._resource_disk_partition_exists') @mock.patch('main.ResourceDiskUtil.ResourceDiskUtil._is_luks_device') @mock.patch('main.ResourceDiskUtil.ResourceDiskUtil._is_crypt_mounted') @mock.patch('main.ResourceDiskUtil.ResourceDiskUtil._is_plain_mounted') @mock.patch('main.ResourceDiskUtil.ResourceDiskUtil._mount_resource_disk') def test_try_remount(self, mock_mount, mock_plain_mounted, mock_crypt_mounted, mock_is_luks, mock_partition_exists, mock_add_rd_to_crypttab): # Case 1, when there is a passphrase and the resource disk is not already encrypted and mounted. 
mock_partition_exists.return_value = True mock_is_luks.return_value = False mock_crypt_mounted.return_value = False mock_mount.return_value = True self.resource_disk.passphrase_filename = self.mock_passhprase_filename self.assertEqual(self.resource_disk.try_remount(), False) mock_mount.assert_not_called() mock_add_rd_to_crypttab.assert_not_called() # Case 2, resource disk is encrypted but not mounted mock_is_luks.return_value = True self.assertEqual(self.resource_disk.try_remount(), True) mock_mount.assert_called_with(ResourceDiskUtil.RD_MAPPER_PATH) self.mock_disk_util.luks_open.assert_called_with(passphrase_file=self.mock_passhprase_filename, dev_path=ResourceDiskUtil.RD_DEV_PATH, mapper_name=ResourceDiskUtil.RD_MAPPER_NAME, header_file=None, uses_cleartext_key=False) mock_add_rd_to_crypttab.assert_called() # Case 2, when the resoure disk mount fails mock_mount.return_value = False self.assertEqual(self.resource_disk.try_remount(), False) mock_mount.assert_called_with(ResourceDiskUtil.RD_MAPPER_PATH) # Case 3, The RD is encyrpted and mounted. mock_crypt_mounted.return_value = True mock_mount.reset_mock() mock_add_rd_to_crypttab.reset_mock() mock_mount.return_value = True self.assertEqual(self.resource_disk.try_remount(), True) mock_mount.assert_not_called() mock_add_rd_to_crypttab.assert_not_called() # Case 4, The RD is plain mounted already and there is no passphrase mock_plain_mounted.return_value = True self.resource_disk.passphrase_filename = None self.assertEqual(self.resource_disk.try_remount(), True) # Case 5, The RD is not plain mounted but the mount fails for some reason. 
mock_mount.return_value = False mock_plain_mounted.return_value = False self.assertEqual(self.resource_disk.try_remount(), False) mock_mount.assert_called_once_with(ResourceDiskUtil.RD_DEV_PATH) # Case 6, The RD is not plain mounted and mount succeeds mock_mount.return_value = True self.assertEqual(self.resource_disk.try_remount(), True) mock_mount.assert_called_with(ResourceDiskUtil.RD_DEV_PATH) @mock.patch('main.ResourceDiskUtil.ResourceDiskUtil._is_crypt_mounted', return_value=False) @mock.patch('main.ResourceDiskUtil.ResourceDiskUtil._is_plain_mounted', return_value=True) @mock.patch('main.ResourceDiskUtil.ResourceDiskUtil.encrypt_format_mount') @mock.patch('main.ResourceDiskUtil.ResourceDiskUtil.try_remount') def test_automount(self, mock_try_remount, mock_encrypt_format_mount, mock_is_plain_mounted, mock_is_crypt_mounted): # Case 1: try_remount succeds mock_try_remount.return_value = True self.assertEqual(self.resource_disk.automount(), True) mock_try_remount.assert_called_once() # Case 2: try_remount fails and public settings is non-EFA: mock_try_remount.return_value = False # Case 2.x: these are basically gonna be a bunch of tests for "is_encrypt_format" self.resource_disk.public_settings = {} self.assertEqual(self.resource_disk.automount(), True) mock_encrypt_format_mount.assert_not_called() self.resource_disk.public_settings = { CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryption} self.assertEqual(self.resource_disk.automount(), True) mock_encrypt_format_mount.assert_not_called() self.resource_disk.public_settings = { CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.DisableEncryption} self.assertEqual(self.resource_disk.automount(), True) mock_encrypt_format_mount.assert_not_called() # Case 3: EFA case. A try remount failure should lead to a hard encrypt_format_mount. 
self.resource_disk.public_settings = { CommonVariables.EncryptionEncryptionOperationKey: CommonVariables.EnableEncryptionFormatAll} mock_encrypt_format_mount.return_value = True self.assertEqual(self.resource_disk.automount(), True) mock_encrypt_format_mount.assert_called_once() # case 4: EFA case, but EFA fails for some reason mock_encrypt_format_mount.reset_mock() mock_encrypt_format_mount.return_value = False self.assertEqual(self.resource_disk.automount(), False) mock_encrypt_format_mount.assert_called_once() ================================================ FILE: VMEncryption/test/test_utils.py ================================================ import os class MockDistroPatcher: def __init__(self, name, version, kernel): self.distro_info = [None] * 2 self.distro_info[0] = name self.distro_info[1] = version self.kernel_version = kernel def mock_dir_structure(artificial_dir_structure, isdir_mock, listdir_mock, exists_mock): """ Takes in an artificial directory structure dict and adds side_effects to mocks which are hooked to that directory example: artificial_dir_structure = { "/dev/disk/azure": ["root", "root-part1", "root-part2", "scsi1"], os.path.join("/dev/disk/azure", "scsi1"): ["lun0", "lun0-part1", "lun0-part2", "lun1-part1", "lun1"] } any string that has an entry in this dict is mocked as a directory. So, /dev/disk/azure and /dev/disk/azure/scsi1 are dicts. Everything else that is implied to exist is a file (e.g. /dev/disk/azure/root, /dev/disk/azure/root-part1, etc) For an example look at test_disk_util.test_get_controller_and_lun_numbers method NOTE: this method just modifies supplied mocks, it doesn't return anything. 
""" def mock_isdir(string): return string in artificial_dir_structure isdir_mock.side_effect = mock_isdir def mock_listdir(string): dir_content = artificial_dir_structure[string] return dir_content listdir_mock.side_effect = mock_listdir def mock_exists(string): if string in artificial_dir_structure: return True for dir in artificial_dir_structure: listing = artificial_dir_structure[dir] for entry in listing: entry_full_path = os.path.join(dir, entry) if string == entry_full_path: return True return string in artificial_dir_structure exists_mock.side_effect = mock_exists ================================================ FILE: docs/advanced-topics.md ================================================ # Advanced Topics ## Azure Templates Samples You can add a sample template into [Azure/azure-quickstart-templates](https://github.com/Azure/azure-quickstart-templates) to deploy your extension. ## Azure Powershell and CLI You can add the specific commands in Azure Powershell or CLI to install and enable your extension. ================================================ FILE: docs/contribution-guide.md ================================================ # Contribution Guide 3rd party partners are welcomed to contribute the Linux extensions. Before you make a contribution, you should read the following guide. 
## Table of Contents * [**Overview**](./overview.md) * [Terminology](./overview.md#terminology) * [Requirements](./overview.md#requirements) * [Architecture Overview](./overview.md#architecture-overview) * **Development** * [Design Details](./design-details.md) * [Handler Artifacts](./design-details.md#handler-artifacts) * [Handler Lifecycle Management](./design-details.md#handler-lifecycle-management) * [Report Status and Heartbeat](./design-details.md#report-status-and-heartbeat) * [Logging](./design-details.md#logging) * [Utils](./utils.md) * [Sample Extension](./sample-extension.md) * [Advanced Topics](./advanced-topics.md) * [Azure Templates Samples](./advanced-topics.md#azure-templates-samples) * [Azure Powershell and CLI](./advanced-topics.md#azure-powershell-and-cli) * [**Handler Registration**](./handler-registration.md) * [Package and upload your extension](./handler-registration.md#package-and-upload-your-extension) * [Register your extension](./handler-registration.md#register-your-extension) * [**Test**](./test.md) * Test Matrix * ASM or ARM * Azure Templates * Jenkins * [Document](./document.md) ================================================ FILE: docs/design-details.md ================================================ # Design Details This page describes the design details of the extension. You can write an extension from scratch following this page. ## Handler Artifacts An Azure Extension Handler is composed of the following artifacts: 1. **Handler Package**: This is the package that contains your Handler binary files and all standard static configuration files. This package is registered with the Azure ecosystem. 2. **Handler Environment**: This is the set of files and folders that WALA sets up for the Handlers to use at runtime. These files can be used for communicating with WALA (heartbeat and status) or for writing debugging information (logging).
The details of handler environment created by WALA is discussed in the section [Handler Environment](#handler-environment). 3. **Handler Configuration**: This is a configuration file that contains various settings needed to configure this Handler at runtime. Extension configuration is the input provided by the end user based on the schema provided by the handler publisher during registration. For example, a handler might get the client authentication details for writing logs to his storage account via the handler configuration. ### Handler Package The Handlers are packaged as simple zip files for being registered in the Azure ecosystem. The zip file is supposed to contain the following: * The handler binaries. * HandlerManifest.json file that is used by WALA to manage the handler. This HandlerManifest.json file should be located in the root folder of the zip file. The JSON file should be of the format: ``` [{ "version": 1.0, "handlerManifest": { "installCommand": "", "uninstallCommand": "", "updateCommand": "", "enableCommand": "", "disableCommand": "", "rebootAfterInstall": , "reportHeartbeat": } }] ``` The above JSON file provides a list of all commands that will be executed by WALA for managing various handlers on the VM. * **Version**: indicates the version of the protocol which should be used by WALA to deserialize this JSON. * **Install\Uninstall\Update\Enable\Disable** point to the command line that will be executed by WALA in various scenarios. The paths of the command line provided in HandlerManifest.json should be relative to the root directory of the handler. The current working directory of the handler is the path of the root folder of the handler. All these command lines are launched as LOCAL SYSTEM with administrative privileges. **Note**: It is valid for multiple commands in the HandlerManifest to point to the same command line. For e.g. the install and Update command might point to the same binary with same parameters. 
* **RebootAfterInstall** notifies WALA if a reboot is required to complete the installation of a handler. Handlers should not reboot the system independently to avoid interfering with each other. * **ReportHeartbeat** indicates WALA if the handler will be reporting heartbeat or not. The details of heartbeat and status are discussed in section [Heartbeat Reporting](#heartbeat_reporting). **Note:** All of the fields in the JSON specified above are required fields and registration of the handler with Azure will fail if one of these fields is not specified. The explanation of the meaning of various fields in the JSON with respect to WALA is provided in the below sections. An example of the directory structure of the zip file for a handler is: ``` SampleExtension.zip |-HandlerManifest.json |-install.py |-uninstall.py |-enable.py |-disable.py |-update.py ``` A sample HandlerManifest.json for the above sample handler would be: ``` [{ "version": 1.0, "handlerManifest": { "installCommand": "./install.py", "uninstallCommand": "./uninstall.py", "updateCommand": "./update.py", "enableCommand": "./enable.py", "disableCommand": "./disable.py", "rebootAfterInstall": false, "reportHeartbeat": true } }] ``` ### HandlerEnvironment When WALA installs a handler on the VM, it creates a bunch of files and folders that are needed by the handler at runtime for various purposes. The location of all these files and folders are communicated to the handler via the HandlerEnvironment.json file. HandlerEnvironment.json is the file that is created under the root directory where the handler is unpackaged. The structure of HandlerEnvironment.json is: ``` [{ "version": 1.0, "handlerEnvironment": { "logFolder": "", "configFolder": "", "statusFolder": "", "heartbeatFile": "", } }] ``` * **version** - contains the version of the protocol that WALA is abiding with. In the initial release the only supported version is 1.0.
* **handlerEnvironment** – This is the object that encapsulates all the properties of a handler defined in the version 1.0 of the protocol. * **logFolder** – contains the location where the handler should put its log files that might be needed to debug any customer issues. The advantage of putting log files under the folder directed by this location is that these files can be automatically retrieved from the customers VM by using a tool, without actually logging into the VM and copying them over manually. * **configFolder** – contains the location where the handler will get its configuration settings file. * **statusFolder** – contains the location where the handler is supposed to write back a file with a structured status of the current state of the work being done by the handler. * **heartbeatFile** - this is the file that is used to communicate the heartbeat of the handler back to WALA. Errors while reading HandlerEnvironment.json – In rare cases a handler might encounter errors when trying to read the HandlerEnvironment.json file, since WALA might be writing the file at the same time as well. The handler should be capable of handling such errors. Our recommendation for handler publishers would be to have a retry logic with some sort of backoff. ### Handler Configuration There are scenarios when a handler needs some user input parameters to configure its handler. All such user provided input is communicated from WALA to the handler via the configuration file. For e.g. a handler might require the user to provide the account name and the key of a user storage account where the logs will be saved. This account information can be passed by the user to the handler via the configuration file. #### Configuration File Structure The configuration file should be a valid JSON with the only property of the root object as `handlerSettings` and with two child objects `protectedSettings` and `publicSettings`. 
Apart from that the complete schema of the handler configuration file under the `publicSettings`\`protectedSettings` property is defined by the handler publisher during the registration process. When a call to add the handler to the VM is made, the user needs to provide a configuration that complies with the structure that the handler publisher had provided during registration. **Managing user secrets**: There may be parts of the handler configuration that contain user secrets (like passwords, storage keys, etc). These secrets in general should never be persisted in plain text to prevent accidently disclosure. To support this concept, the Azure Extension Handler publishers can allow users to store all or part of the handler configuration in a protected section of the config. All settings under this section are encrypted by an X509 certificate before being sent over to the VM. The WALA will persist the protected settings as encrypted only and will provide the thumbprint of the certificate that needs to be used for decrypting this information. To extract the setting, the handler will need to retrieve the certificate from the Local Machine store and decrypt the settings using the certificate private key. The publisher of the Azure Extension Handler decides what, if any, part of the configuration should be protected in this manner. A sample configuration file would look like: ``` { "handlerSettings": { "protectedSettings": { "storageaccountname": "MY SECRET STORAGE ACCOUNT NAME", "storageaccountkey": "MY SECRET STORAGE ACCOUNT KEY" }, "publicSettings": { "MyHandlerConfiguration": { "configurationChangePollInterval": " ", "overallQuotaInMB": 12 }, "MyHandlerInfrastructureLogs": { "scheduledTransferLogLevelFilter": "Verbose", "bufferQuotaInMB": "100", "scheduledTransferPeriod": "PT1M" } } } } ``` In the above example the storageaccountname and storageaccountkey are protected secrets. 
When these secrets are persisted on a file in the VM for consumption by the handler the protected section would be encrypted and base64 encoded. In the case of above settings, the configuration file for the above sample on the VM would look like: ``` { "handlerSettings": { "protectedSettingsCertThumbprint": "a811c3f4058542418abb", "protectedSettings": "ICB7DQogICAgInN0b3JhZ2VhY2NvdW50IiA6ICJbcGFyY W1ldGVycy5TdG9yYWdlQWNjb3VudF0iLA0KICB9LA0K", "publicSettings": { "DiagnosticMonitorConfiguration": { "configurationChangePollInterval": " ", "overallQuotaInMB": 12 }, "DiagnosticInfrastructureLogs": { "scheduledTransferLogLevelFilter": "Verbose", "bufferQuotaInMB": "100", "scheduledTransferPeriod": "PT1M" } } } } ``` #### Location of Handler Configuration The location where the configuration setting files will be written can be retrieved by the "configFolder" property in the HandlerEnvironment.json file. #### Handler Configuration Filename Whenever a new configuration is received, WALA will write the configuration settings file named .settings under the configFolder with the configuration provided by the user and launches [the enable command of the handler](#enable). The handler is expected to retrieve the last sequence number of the configuration file written by WALA bylooking under the configfolder directory for the highest sequence number. This sequence number can then be used to apply the latest user provided configuration settings to the handler. ## Handler Lifecycle management ### Add a new handler on the VM (Install and Enable) When a handler is requested on a VM by the user, WALA will do the following inside the VM: 1. Download the handler package zip from Azure repository to `/var/lib/waagent`. 2. Unzip the package under a unique location corresponding to the handler identity. The handler should not take any dependency on the location where the handler package is unpacked, since this location might change in future depending on future requirements. 
Currently, the unique location is formatted as `.-`. 3. Create the configuration, logging and status folders for the handler. 4. Create the .settings file with the initial configuration. 5. Creates the HandlerEnvironment.json file under the root folder where the handler is unpacked. 6. Parse the HandlerManifest.json file and execute the install command in a separate process. 7. The install command is executed in the process with the root privileges. 8. If there are multiple handlers that are being installed, WALA will download and unzip them in parallel but will invoke the install command sequentially only. 9. WALA will wait for the installation to complete and monitor the exit code of the install process. 10. If the install process exits **SUCCESSFULLY** (exit code 0), WALA maintains state that the handler was installed successfully and does not run the install command for the same handler again ever unless the handler has been uninstalled first. * WALA will wait for a maximum of 5 minutes before timing out the install process and considering the install to be failed. 11. If the install process exits **SUCCESSFULLY**, WALA will provide the handler configuration settings in the defined location and launch the `Enable` command in a separate process that runs with root privileges. 12. If the install process exits **UNSUCCESSFULLY**, WALA will retry to install the handler under two circumstances: * When WALA receives a new goal state triggered by a user action. (e.g. Adding\removing\updating any handler or updating handler configuration etc.) * When WALA restarts (which should only happen when the machine itself is rebooted). #### Install command In the install command, the handler is expected to install its processes and services on the system and create the necessary setup that is required for the handler to run at runtime. 
### Remove a handler from the VM (Disable and Uninstall) When a user explicitly requests to remove the handler from the VM, WALA will execute the following actions: 1. The disable command specified in the HandlerManifest.json will be executed in a separate process that runs with root privileges. The handler is expected to complete the pending tasks and then stop any processes or services related to the handler that have been running on the machine. * WALA will wait for a max of 5 minutes for the disable process to finish before timing out to the next steps. 2. The uninstall command will be invoked in a separate process that runs with root privileges. WALA will wait for a maximum of 5 mins for the uninstall process to finish. 3. WALA will remove all the package binaries and configuration files that were associated with the handler. The handler log files will be maintained on the machine for any future debugging purposes. ### Disable A user might explicitly request to disable a handler without uninstalling it. On disable WALA will execute the disable command in a separate process with root privileges. On the execution of the disable command the handler is expected to complete the pending tasks and then stop any processes or services related to the handler that have been running on the machine. WALA will wait a max of 5 mins for the disable process to finish before timing out to the next steps. ### Enable A user might explicitly request to enable a handler that has been previously disabled. On enable WALA will execute the enable command in a separate process with root privileges. The enable command will be invoked every time the machine reboots or the machine receives a new configuration settings file. **Note:** Unlike the install state, WALA will not maintain the enabled\disabled state of the handler. Every time the machine restarts (which in turn will restart WALA) or a new goal state is received, WALA will try to set the machine to the latest goal state. 
Thus it might invoke the enabled\disabled commands multiple times even if the handler is already enabled\disabled. So the enable and disable commands need to be idempotent i.e. if the handler is already enabled and the enable command is invoked again, the command should check if all the processes are running as expected, if yes, then the command should just exit with a success code. ### Update There are two scenarios when an update can happen: * The user triggers an explicit update of the handler. * The handler is updated on Azure repository and it automatically gets picked up by WALA. In both these cases WALA will identify that a handler with the same name and publisher and a lower version is already installed on the machine. 1. It will download the updated version of the handler from Azure repository, unpack it under the handler identity folder. 2. WALA will call disable on the existing handler with the lower version. 3. WALA will invoke the update command in the newly downloaded packages under a separate process with root privileges. During update the handler has an opportunity to transfer any state information from the previous handler. 4. WALA will invoke the uninstall command on the existing handler with lower version. 5. WALA will invoke the enable command on the newly downloaded package. ## Reporting Status and Heartbeat Microsoft Azure provides two facilities to report back the health of the handler and the status of the operations being performed by it. 1. **Heartbeat**: Heartbeat channel should be used to report the health of the handler itself. Providing heartbeat is an optional facility that the handler can opt into by setting the reportHeartbeat property to true in the HandlerManifest. Heartbeat is generally expected to be reported by long running services or processes. For e.g. an antivirus handler service might use the heartbeat channel to indicate if its service has stopped for some reason. 2.
**Configuration Status**: Status channel should be used to report the success or failures of any operations that were conducted when applying the new configuration provided by the user. For e.g. Diagnostics agent might report issues connecting to the storage account via this channel. The WALA collects the heartbeat and status information for all handlers and aggregates them into VM health which is returned to the user when he queries for it via the GetDeployment RDFE API call. ### Heartbeat reporting The handler that have opted into reporting heartbeat are supposed to report it via the file specified in the heartbeat property of the HandlerEnvironment file. The structure of the heartbeat file should be: ``` [{ "version": 1.0, "heartbeat" : { "status": "", "code": , "Message": "" } }] ``` Various fields in the above JSON document correspond to the following: * **Version** – This is the version of the protocol being used to communicate heartbeat to WALA. Currently the only version WALA understands is 1.0. * **Heartbeat** – This object encapsulates all the heartbeat related information for the handler. * **Status** – The current status of the handler. The only valid values are “ready” and “notready”. * **Code** – The status code the handler. This is an optional field. * **Message** – A human readable\actionable error message for the user. This is an optional field. Handlers can report successful heartbeat by setting the status to "ready". To report repeated successful heartbeats, the handler can just change the last modified timestamp of this file. The status field only needs to be changed to "notready" if the handler has encountered some error\exception condition while executing. For e.g. If after the handler is installed and before the first configuration settings file is processed, if there is an exception, it can be reported via the status section in the heartbeat file. 
WALA will read the heartbeat file once every 2 minutes to check if the plugin is running or not. If the last modified timestamp is within the last 1 minute and the status is set to "ready" then WALA will consider the plugin to be working properly. If the last modified timestamp is older than 10 minutes, WALA will consider the plugin handler to be unresponsive. If the last modified timestamp is between 1 minute and 10 minute, WALA will consider the plugin to be in "Unknown" state. If the status is set to "NotReady", the error code and the message will be returned back to the user in the next GetDeployment call. A sample heartbeat file would look like: ``` [{ "version": 1.0, "heartbeat" : { "status": "ready", "code": 0, "Message": "Sample Handler running. Waiting for a new configuration from user." } }] ``` Errors while writing to the HeartBeat file – In rare cases a handler might encounter errors when trying to write the heartbeat file, since WALA might be reading the file at the same time as well. The handler should be capable of handling such errors. Our recommendation for handler publishers would be to have a retry logic with some sort of exponential backoff. ### Status reporting The handler can report status back to WALA by writing to the status file ".status" under the status folder specified in the HandlerEnvironment. 
The status file structure supported by WALA is: ``` [{ "version": 1.0, "timestampUTC": "", "status" : { "name": "", "operation": "", "configurationAppliedTime": "", "status": "", "code": , "message": { "id": "id of the localized resource", "params": [ "MyParam0", "MyParam1" ] }, "formattedMessage": { "lang": "Lang[-locale]", "message": "formatted user message" }, "substatus": [{ "name": "", "status": "", "code": , "Message": { "id": "id of the localized resource", "params": [ "MyParam0", "MyParam1" ] }, "FormattedMessage": { "Lang": "Lang[-locale]", "Message": "formatted user message" }, }] } }] ``` * **version** – indicates the version of the protocol being used for communicating the status back to WALA. * **timestampUTC** – The current time in UTC during which this status structure is being created. * **status** – The object that encapsulates the top level status about the configuration corresponding to what the status is being reported. * **status\name** – This property is optional. This property can be used by handlers to point to the VM workload name that are being managed by the handler. * **status\operation** – This property is optional. This property can be used by handlers to indicate the current operation being performed to enable the VM workload on the machine. * **status\configurationappliedtime** – This property is optional. This property can be used by handlers to indicate the last time the configuration corresponding to the current sequence number was successfully applied on the VM. * **status\status** – This property indicates the current status of the operation being performed. The only acceptable values are: Transitioning, error, success and warning. * **status\code** – A valid integer status code for the current operation. * **status\message** – This is an optional localized message that will be passed back to the user on a GetDeployment call via RDFE. 
* **status\message\id** – This is the message identifier, to be used for lookup of a localized message. Treated as a string. A symbolic id is preferred for human interpretation, for example Error_CannotConnect. The file that contains all the localized strings corresponding to the id would be provided by the handler author to Azure during registration. * **status\message\params** - This is an Ordered list of parameter (placeholder) values to be filled into the message template corresponding to the message id. The first Param is used for placeholder “{0}” in the message template (from the provided language resources); the second for placeholder “{1}”, etc. * **status\formattedMessage\lang** - The language/locale of the preformatted message. * **status\formattedMessage\message** - The human readable message that will be returned to the user. * **substatus** – An array of nested substatus objects that can be used by the handler to pass the substatus of complicated operations. The fields in the substatus array are supposed to be used in the same manner as they are used in the parent status array. Every time a handler receives a new handler pack via a new configuration, it is expected to periodically report the status corresponding to that configuration in a file named `<SequenceNumber>.status`. The status should be reported at least once every 2 minutes for the time when the handler is in (transitioning\Warning) state. Once the handler reaches a terminal state (success\error) it can stop reporting the status messages for that sequence number. Each time the handler has new status to report, it should overwrite the file. The status provided in the status file should be an aggregate status (even if that status has been reported before) of all the operations performed for this configuration so far. If writing to the file fails, the handler should retry with backoff. The handler can write to the status file whenever it has something new to report. 
WALA will only read this status file after it has fed a new configuration to the handler and till the time the handler does not report status of a terminal state (success\error). During this time WALA will read the status file with a default frequency of 5 mins (configurable). A simple status report without localization from a handler would look like: ``` [{ "version": 1.0, "timestampUTC": "2013/11/13, 17:46:30.447", "status" : { "name": "enable wordpress", "operation": "installing wordpress", "status": "transitioning", "formattedMessage": { "Lang": "en", "Message": "Enable IIS on the VM." }, "substatus": [{ "name": "Wordpress plugin", "status": "success", "code": 0, "formattedMessage": { "Lang": "en", "Message": "Successfully downloaded wordpress plugin." } }, { "name": "Enable IIS", "status": "transitioning", "Message": "Turning windows feature for enabling IIS on." }] } }] ``` #### Localization Support To enable showing these messages in the user’s preferred language and, ideally, to enable multiple users to view the same captured execution status in different languages, we need to defer message resource lookup until the user queries for handler status. The current user’s preferred language would be retrieved from the HTTP header. Localization support is optional. If the handler does not wish to participate in localization they can just return the FormattedMessage strings in a default language which will be directly returned to the user. A localized status report would look like: ``` [{ "version": 1.0, "timestampUTC": "", "status" : { "name": "SharePointFrontEnd", "operation": "ResExtProvisioning", "status": "error", "code": 12, "Message": { "id": "1215", "params": [ "spo-sqldb.cloudapp.net", "JoeAdmin" ] } } }] ``` #### Localized message formatting As part of handler registration with Azure, a set of localization resources will be provided for looking up the status messages from the handler. 
A language/locale lookup sequence similar to the one for .NET resources will be applied, with the ultimate fallback being "en", a resource file for which must always be provided. The structure of the JSON resource files will be as follows. ``` [{ "version": 1.0, "lang": "lang[-locale]", "messages": [ { "id": "message id", "text": "Message text with {0}, {1} placeholder." }] }] ``` **Placeholder ordering** - The order of Status/Param values from the in-guest handler must be fixed (independent of language) and should correspond to the sequence of {n} placeholders in the English version of the message. If translation of a message in some language requires different order of the placeholders, the message template in the resource file for that language should have the placeholders reordered accordingly. To continue the earlier Status sample the message corresponding to id 1215, if in English we have: ``` Failed to establish connection to {0} as {1} ``` In German it might be: ``` {1} fehler beim Anschluss an {0} herzustellen ``` # Logging Handlers should use the folder provided in the "logfolder" property of the handler environment for writing logs required for debugging their handlers in lieu of any issues reported on a live customer VM. ================================================ FILE: docs/document.md ================================================ # Document A `README.md` is recommended in your extension directory. You can refer to [**README.md of CustomScript**](../CustomScript/README.md). 
What you should include in `README.md`: * Configuration schema (Public and Protected) * How to deploy the extension using Azure CLI or Azure Powershell * How to deploy the extension in ASM and ARM mode * How to deploy the extension using ARM templates * Configuration Examples * How to debug * Supported Linux Distributions ================================================ FILE: docs/handler-registration.md ================================================ # Handler Registration In this page, we will show you the steps to package and register your extensions to Azure repository. We assume that you have prepared your `SampleExtension` in `~/azure-linux-extensions/`. For registering a handler the following two components are required: * the handler package - The extension handler package needs to be uploaded to a storage location. * the definition xml file - This section gives an overview of some of the key elements that are required in the definition file. Also, the extension should be registered under the Publisher’s Azure Subscription. Prior to Registration, the subscription should be approved for publishing by Azure Runtime team. During the handler registration, you need specify the certificate of your Azure subscription. We provide some scripts to help package and register your extensions. ``` registration-scripts/ ├── api │   ├── add-extension.sh │   ├── check-request-status.sh │   ├── del-extension.sh │   ├── get-extension.sh │   ├── get-subscription.sh │   ├── list-extension.sh │   ├── params │   └── update-extension.sh ├── bin │   ├── add.sh │   ├── blob │   │   ├── list.sh │   │   └── upload.sh │   ├── check.sh │   ├── del.sh │   ├── get.sh │   ├── list.sh │   ├── subscription.sh │   └── update.sh ├── create_zip.sh ├── mooncake │   └── sample-extension-1.0.xml └── public └── sample-extension-1.0.xml ``` ## Package and upload your extension You can package your extension into a zip file using the following command. 
``` cd ~/azure-linux-extensions/ ./registration-scripts/create_zip.sh SampleExtension/ 1.0.0.0 ``` Then you will get `SampleExtension-1.0.0.0.zip` in `build` directory. You should upload your extension to a downloadable storage, for e.g. [Azure Blob Storage](https://azure.microsoft.com/en-us/services/storage/). ``` bin/blob/upload.sh ~/azure-linux-extensions/build/SampleExtension-1.0.0.0.zip ``` ## Register your extension ### Prepare your subscription for registration The extension should be registered under the Publisher’s Azure Subscription. ### How to use the publish scripts The following scripts are executed in `registration-scripts` directory. You can configure `api/params` to change the endpoint (Public Azure or Mooncake). ### Definition File For registration, the publisher would have to provide the definition file. | Property | Description | Requirements | |:---------------:|:----- |:----- | | ProviderNamespace | This has to be a unique namespace per each subscription. The namespace is a combination of company team, team name (optional) and product name. E.g.: Microsoft.Azure.RemoteAcccess | Namespace cannot be empty, should be less than 256 chars and underscores cannot be used. | | Type | Name of the Extension Handler. The type indicate the purpose of the extension | Type cannot be empty, should be less than 256 chars and underscores cannot be used. | | Version | Version number of the handler. The combination of namespace, type and version uniquely identifies an extension. | The version number needs to be changed for every release. The format of version number has to be `...` Eg: 1.0.1.1 | | Label | The label of the extension | | | HostingResource | This should be either WebRole or WorkerRole or VmRole depending on whether it’s targeted for PaaS or IaaS. | These values are case sensitive. | | MediaLink | The blob url which has the Extension Package. | MediaLink value must point to a URL(either Http or Https) in a blob storage and is downloadable. 
| | Description | The description of the extension | | | IsInternalExtension | If this is set to "true", the handler is not visible for public use. It can be still accessed by referring to the Namespace, Type & Version combo. | Possible values are case-sensitive true or false | | Eula | If the software requires any additional EULAs, a link to the EULA should be provided. | | | PrivacyUri | If the software collects any data and transfers out the VM, then a additional Privacy document might be needed. | | | HomepageUri | A public URL that has usage information and contact information for customer support. | | | IsJsonExtension | Whether the Extension configuration is json format | It should always be "true". | | SupportedOS | The supported OS | It should always be "Linux". | | CompanyName | The company name | | You can prepare your sample definition file `public/sample-extension-1.0.xml`. ``` Microsoft.Love.Linux SampleExtension 1.0.0.0 VmRole Storage blob location of the Zip file Microsoft loves Linux false https://github.com/Azure/azure-linux-extensions/blob/1.0/LICENSE-2_0.txt https://github.com/Azure/azure-linux-extensions/blob/1.0/LICENSE-2_0.txt https://github.com/Azure/azure-linux-extensions true Linux Microsoft ``` ### Register the new extension ``` bin/add.sh public/sample-extension-1.0.xml ``` The operation of registration and unregistration is asynchronous. You can check the status of the operation using the following command. ``` bin/check.sh ``` You can get `` from the output of the registration operation. ### Update your extension Once the extension is published, any changes to the handler can be published as newer versions, using the update API. ``` bin/update.sh public/sample-extension-1.0.xml ``` Here is an overview of updates are done: * **Hotfixes** - Publisher should release hotfixes by changing the revision number. Eg: If the current version is 1.0.0.0, then the hotfixed version would be 1.0.0.1. 
All hotfixes would be automatically installed on the VM. * **Minor Version Changes** - Any minor features can be released as a minor update. E.g.: If the current version is 1.0.0.0, then a minor version update would be 1.1.0.0. If the client opts in for auto upgrade, all minor version changes would be automatically applied. * **Major Version Change** - Any breaking changes in the handler should be released as a major version update. The client has to explicitly request the major version changes. ### List your extensions **NOTE:** After registration and updating, you need to wait some time to **replicate** your extension. The wait time depends on the workload of the replication system, from half an hour to one day. You can get the replication status of the extension using the following command: ``` bin/list.sh ``` ### Unregister your extension ``` bin/delete.sh ``` Sample: ``` bin/delete.sh Microsoft.Love.Linux SampleExtension 1.0.0.0 ``` **NOTE:** Unregistration is supported for internal extensions only. You need to update your extension from public into internal before unregistration. ================================================ FILE: docs/overview.md ================================================ # Overview In order to make the Microsoft Azure IaaS VMs customizable, Microsoft Azure is releasing a set of capabilities which will enable users to automate software deployment and configuration on IaaS VMs. As a part of these capabilities, a protocol is being released which can be used by various existing VM customization products to integrate with the Microsoft Azure VM ecosystem. This document discusses the requirements to participate in the Microsoft Azure VM ecosystem and provides a guide for integrating VM customization products with Microsoft Azure. ## Terminology | Terminology | Description | |:---------------:|:----- | | WALA | The Microsoft Azure component that runs inside the Linux VM and is responsible for managing the extension handlers. 
You can get the source code of WALA from https://github.com/Azure/WALinuxAgent. | | Handlers | Partner authored component to deliver software and configuration to the customer VM. This component needs to implement handler configuration and status contracts and be provided as a handler package. Generally a handler will consist of an Azure interoperability wrapper around an existing VM customization product. In the overview documents handlers are more broadly referred to as `extensions`. The term `handler` and `extension` are used somewhat interchangeably. | | Extension Pack | Specific job, workload, or script to be executed by the extension handler. | | Handler identity | An identifier used to uniquely define the handler. This identity is a tuple of , and | | Handler Manifest | A JSON based manifest that defines various properties needed by WALA to manage the handler. | ## Requirements To participate in the Microsoft Azure ecosystem, any VM customization product needs to create a handler that implements the WALA defined protocol to integrate with the Azure ecosystem. The basic requirements for creating a handler that implements the Azure protocol are: 1. Handler Packaging – The Handler should be packaged as a zip file. This zip package should contain all the binaries related of the handler and HandlerManifest. This package needs to be registered with the Azure image repository. Azure image repository is responsible for managing all versions of all the handlers that are registered with the Azure ecosystem. 2. Handler Environment - Handler needs the capability to read the environment file in the format that WALA defines. The environment file defines the locations of various files and folder that the handler needs to use for reading configuration and writing back heartbeat and status. 3. Handler Configuration – Various extension packs that the handler needs to manage are passed to the handler in form of configuration settings. 
For example if a script is needed by the handler to install an extension, that script is passed to it via the handler configuration file. The handler should have the ability to read this file in the format defined by the Azure Agent and should be able to execute its contents and report the status of that execution with a frequency that complies with the Azure Agent protocol. 4. Handler heartbeat and status – The handler is supposed to report the status of the most recently executed configuration with a frequency that complies with the Azure Agent protocol. In addition to status, if the handler opts into reporting heartbeat it needs to report the heartbeat for the complete lifetime of the handler on the VM with a frequency that complies with the Azure Agent protocol. ## Architecture Overview The below diagram gives an overview of how the handlers are supposed to interact with the Azure ecosystem. ![Architecture Overview](./architecture.jpg) ================================================ FILE: docs/sample-extension.md ================================================ # Sample Extension In this page, we offer a sample extension using [Utils](./utils.md). After this section, you can get the following directory: ``` SampleExtension/ ├── disable.py ├── enable.py ├── HandlerManifest.json ├── install.py ├── references ├── uninstall.py └── update.py ``` ## HandlerManifest.json ``` [{ "name": "SampleExtension", "version": 1.0, "handlerManifest": { "installCommand": "./install.py", "uninstallCommand": "./uninstall.py", "updateCommand": "./update.py", "enableCommand": "./enable.py", "disableCommand": "./disable.py", "rebootAfterInstall": false, "reportHeartbeat": false } }] ``` ## enable.py 1. Get the paramter `name` in the public settings. 2. Log the `name` into `extension.log`. ## references This file is used to package the extension using [create_zip.sh](https://github.com/Azure/azure-linux-extensions/blob/master/script/create_zip.sh). You can put `Utils` in `references`. 
Then `create_zip.sh` will put the directory `SampleExtension` and `Utils` into `SampleExtension-1.0.zip`. ================================================ FILE: docs/test.md ================================================ # Test ## Test Matrix You should test your extension in the distros which you want to support. Here is the distro list: * Ubuntu 12.04 and higher * CentOS 6.5 and higher * Oracle Linux 6.4.0.0.0 and higher * openSUSE 13.1 and higher * SUSE Linux Enterprise Server 11 SP3 and higher * FreeBSD * CoreOS You can choose some or all of them to support. ## ASM or ARM It's important to understand that Azure currently has two deployment models: Resource Manager, and classic. Make sure you understand [deployment models and tools](https://azure.microsoft.com/en-us/documentation/articles/azure-classic-rm/) before working with any Azure resource. ## Azure Templates If you decide to support the scenario of deploying your extension using ARM Templates, you need to test it. ## Continuous Integration There are many tools to do the CI work, for e.g. Jenkins, Concourse and so on. ================================================ FILE: docs/utils.md ================================================ # Utils You can write an extension from scratch using your favorite language following [Design Details](./design-details.md). The utils we offer are optional. They are written in Python, and they can accelerate your development. Without them, you need to handle the protocol between WALA and extensions by yourself. ## HandlerUtils [HandlerUtils.py](https://github.com/Azure/azure-linux-extensions/blob/master/Utils/HandlerUtil.py) handles the protocol between WALA and extensions, status and heartbeat reporting, and the logging. * Get your settings * `get_public_settings()` method returns the public settings * `get_protected_settings()` method returns the protected settings which have been decrypted. 
* Status reporting * `do_status_report` method reports the status, but does not exit. * `do_exit` method reports the status and exits. * Logging * HandlerUtils.py will put the logs into the log file `extension.log` which is located in `logFolder` of `handlerEnvironment.json`. * The methods `log` and `error` can be used. ## WAAgentUtil WAAgentUtil.py helps to load the source of [WALA](https://github.com/Azure/WALinuxAgent). You can use the functions in WALA, for e.g. GetFileContents. ================================================ FILE: go.mod ================================================ module github.com/ChrisCoe/azure-linux-extensions go 1.21 ================================================ FILE: go.sum ================================================ ================================================ FILE: registration-scripts/api/add-extension.sh ================================================ #!/bin/bash original_dir=`pwd` script=`dirname $0` cd $script source params export script=`pwd` cd $original_dir echo $1 curl -v -X 'POST' -H "$VERSION" -H 'Content-Type: application/xml' -E $CERT -d@$1 $ENDPOINT/$SUBSCRIPTION/services/extensions ================================================ FILE: registration-scripts/api/check-request-status.sh ================================================ #!/bin/bash original_dir=`pwd` script=`dirname $0` cd $script source params export script=`pwd` cd $original_dir echo $1 curl -v -X 'GET' --keepalive-time 30 --user-agent ' Microsoft.WindowsAzure.Management.Compute.ComputeManagementClient/0.9.0.0 WindowsAzurePowershell/v0.8.0' -H 'x-ms-version: 2014-06-01' -H 'Content-Type: application/xml' --insecure -E $CERT --data-binary @$1 $ENDPOINT/$SUBSCRIPTION/operations/$1 ================================================ FILE: registration-scripts/api/del-extension.sh ================================================ #!/bin/bash original_dir=`pwd` script=`dirname $0` cd $script source params export script=`pwd` cd $original_dir curl -v -X 
'DELETE' -H "$VERSION" -H 'Content-Type: application/xml' -E $CERT $ENDPOINT/$SUBSCRIPTION/services/extensions/$1/$2/$3 ================================================ FILE: registration-scripts/api/get-extension.sh ================================================ #!/bin/bash original_dir=`pwd` script=`dirname $0` cd $script source params export script=`pwd` cd $original_dir curl -v -X 'GET' -H "$VERSION" -H 'Content-Type: application/xml' -E $CERT $ENDPOINT/$SUBSCRIPTION/services/resourceextensions/$1/$2 ================================================ FILE: registration-scripts/api/get-subscription.sh ================================================ #!/bin/bash original_dir=`pwd` script=`dirname $0` cd $script source params export script=`pwd` cd $original_dir echo "GET $ENDPOINT/$SUBSCRIPTION" curl -v -X 'GET' --keepalive-time 30 --user-agent ' Microsoft.WindowsAzure.Management.Compute.ComputeManagementClient/0.9.0.0 WindowsAzurePowershell/v0.8.0' -H 'x-ms-version: 2014-06-01' -H 'Content-Type: application/xml' --insecure -E $CERT $ENDPOINT/$SUBSCRIPTION ================================================ FILE: registration-scripts/api/list-extension.sh ================================================ #!/bin/bash original_dir=`pwd` script=`dirname $0` cd $script source params export script=`pwd` cd $original_dir curl -v -X 'GET' -H 'x-ms-version: 2014-06-01' -H 'Content-Type: application/xml' -E $CERT $ENDPOINT/$SUBSCRIPTION/services/publisherextensions ================================================ FILE: registration-scripts/api/params ================================================ AZURE_PRODUCTION=1 #MOONCAKE_PRODUCTION=1 if [ $MOONCAKE_PRODUCTION -eq 1 ] ; then export ENDPOINT="https://management.core.chinacloudapi.cn" export SUBSCRIPTION="REPLACE-ME" export VERSION='x-ms-version: 2014-06-01' export CERT="REPLACE-ME" export CONN_STR="REPLACE-ME" elif [ $AZURE_PRODUCTION -eq 1 ] ; then export ENDPOINT="https://management.core.windows.net" export 
SUBSCRIPTION="REPLACE-ME" export VERSION='x-ms-version: 2014-06-01' export CERT="REPLACE-ME" export CONN_STR="REPLACE-ME" fi echo ENDPOINT: $ENDPOINT >&2 echo SUBSCRIPTION: $SUBSCRIPTION >&2 echo CERT: $CERT >&2 echo VERSION: $VERSION >&2 ================================================ FILE: registration-scripts/api/update-extension.sh ================================================ #!/bin/bash original_dir=`pwd` script=`dirname $0` cd $script source params export script=`pwd` cd $original_dir curl -v -X 'PUT' -H "$VERSION" -H 'Content-Type: application/xml' -E $CERT -d@$1 $ENDPOINT/$SUBSCRIPTION/services/extensions?action=update ================================================ FILE: registration-scripts/bin/add.sh ================================================ #!/bin/bash original_dir=`pwd` script=$(dirname $0) root=$script/.. cd $root root=`pwd` cd $original_dir echo "Add extension: $1" $root/api/add-extension.sh 2>/tmp/restoutput $1 | sed -e 's/>\n" tail /tmp/restoutput echo "====================" echo "More info is saved in /tmp/restoutput" ================================================ FILE: registration-scripts/bin/blob/list.sh ================================================ #!/bin/bash original_dir=`pwd` script=`dirname $0` cd $script/../../api source params export script=`pwd` cd $original_dir azure storage blob list -c $CONN_STR extensions ================================================ FILE: registration-scripts/bin/blob/upload.sh ================================================ #!/bin/bash original_dir=`pwd` script=`dirname $0` cd $script/../../api source params export script=`pwd` cd $original_dir zip_file=$(readlink -f $1) if [ ! -f $zip_file ] ; then echo "File not found: $zip_file" exit 1 fi file_name=$(basename $zip_file) echo "Uploading $zip_file to azure..." 
azure storage blob upload -c $CONN_STR $zip_file extensions $file_name ================================================ FILE: registration-scripts/bin/check.sh ================================================ #!/bin/bash original_dir=`pwd` script=$(dirname $0) root=$script/.. cd $root root=`pwd` cd $original_dir echo "Check Request: $1" $root/api/check-request-status.sh 2>>/tmp/restoutput $1 | sed -e 's/>\n/tmp/restoutput $1 $2 $3| sed -e 's/>\n" tail /tmp/restoutput echo "====================" echo "More info is saved in /tmp/restoutput" ================================================ FILE: registration-scripts/bin/get.sh ================================================ #!/bin/bash original_dir=`pwd` script=$(dirname $0) root=$script/.. cd $root root=`pwd` cd $original_dir echo "Get extension: $1 $2" $root/api/get-extension.sh 2>/tmp/restoutput $1 $2 | sed -e 's/>\n/tmp/restoutput | sed -e 's/>\n/<\/ExtensionImage>\n/g' ================================================ FILE: registration-scripts/bin/subscription.sh ================================================ #!/bin/bash original_dir=`pwd` script=$(dirname $0) root=$script/.. cd $root root=`pwd` cd $original_dir echo "Get subscription" $root/api/get-subscription.sh 2>>/tmp/restoutput | sed -e 's/>\n/tmp/restoutput $1 | sed -e 's/>\n" tail /tmp/restoutput echo "====================" echo "More info is saved in /tmp/restoutput" ================================================ FILE: registration-scripts/create_zip.sh ================================================ #!/bin/bash # # This script is used to set up a test env for extensions # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # if [ $# != 2 ] ; then echo "" echo " Usage: $0 " echo " Example: $0 SampleExtension 1.0.0.0" echo "" exit 1 fi if [ ! -d $1 ] ; then echo "" echo " Error: Couldn't find dir: $1" echo "" exit 1 fi curr_dir=`pwd` ext_dir=$1 cd $ext_dir ext_dir=`pwd` cd $curr_dir script=$(dirname $0) root=$script/.. cd $root root=`pwd` util_dir=$root/Utils build_dir=$root/build if [ ! -d $build_dir ] ; then mkdir $build_dir fi ext_name=`echo $1 | sed 's/\/$//'` ext_version=$2 ext_full_name=$ext_name-$ext_version tmp_dir=$build_dir/$ext_full_name echo "Create zip for $ext_name version $ext_version" echo "Creat tmp dir: $tmp_dir" mkdir $tmp_dir echo "Copy files..." cp -r $ext_dir/* $tmp_dir rm $tmp_dir/references echo "Copy dependecies..." cat $ext_dir/references cat $ext_dir/references | xargs cp -r -t $tmp_dir echo "Switch to tmp dir..." cd $tmp_dir echo "Remove test dir..." rm -r test rm -r */test echo "Remove *.pyc..." find . -name "*.pyc" | xargs rm -f echo "Create zip..." zip -r $build_dir/$ext_full_name.zip . echo "Delete tmp dir..." rm $tmp_dir -r echo "Done!" 
================================================ FILE: registration-scripts/mooncake/sample-extension-1.0.xml ================================================ Microsoft.Loves.Linux SampleExtension 1.0.0.0 VmRole Storage blob location of the Zip file Microsoft loves Linux false https://github.com/Azure/azure-linux-extensions/blob/1.0/LICENSE-2_0.txt https://github.com/Azure/azure-linux-extensions/blob/1.0/LICENSE-2_0.txt https://github.com/Azure/azure-linux-extensions true Linux Microsoft ================================================ FILE: registration-scripts/public/sample-extension-1.0.xml ================================================ Microsoft.Loves.Linux SampleExtension 1.0.0.0 VmRole Storage blob location of the Zip file Microsoft loves Linux false https://github.com/Azure/azure-linux-extensions/blob/1.0/LICENSE-2_0.txt https://github.com/Azure/azure-linux-extensions/blob/1.0/LICENSE-2_0.txt https://github.com/Azure/azure-linux-extensions true Linux Microsoft ================================================ FILE: script/0.settings ================================================ {"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"TEST","protectedSettings":"MIIByAYJKoZIhvcNAQcDoIIBuTCCAbUCAQAxggFxMIIBbQIBADBVMEExPzA9BgoJkiaJk/IsZAEZFi9XaW5kb3dzIEF6dXJlIFNlcnZpY2UgTWFuYWdlbWVudCBmb3IgRXh0ZW5zaW9ucwIQJ1fD4ZQMF7RKAOgzHVJRRDANBgkqhkiG9w0BAQEFAASCAQBrnH4vyuPreCPD53g4e/ixZ7F9+iHzG3Vp4R7LnZoFLVejLcPfxQ1yhaDtXiIAXs19LfnwukbSe2gxpEIkNqohSh4EvRn2RI2ss4Lmmp69qnccr3g8/uHdgYBKUxyZbG+Ul2tjzcu173uOKpr6fSrAGKyGX0KqPCBFMD7vxhem3sd/9oQwfsxXUvkl3zkFioOP5oor6BKvfMQ8kxRv0UfvXF0mqDzXLF8/vQ6kexqglAH+L8L5dcXFF1+D/WyNUkZJOr4ax4BMgtrV/HGoWoNkjxmFrRcsiEpJ2JGCPduAuWUYHrLjV59Jjf30pszN2D/K1naYwNDY79zRDm/8CTJEMDsGCSqGSIb3DQEHATAUBggqhkiG9w0DBwQIUStUI4paw9uAGHvktyCyAIwMBP/AB5iOs34BuT5vXdGH7g=="}}]} ================================================ FILE: script/HandlerEnvironment.json ================================================ [{ "name": "VMAccess", "seqNo": 
"1", "version": 1.0, "handlerEnvironment": { "logFolder": "/var/log/azure/VMAccess/1.0", "configFolder": "/var/lib/waagent/VMAccess-1.0/config", "statusFolder": "/var/lib/waagent/VMAccess-1.0/status", "heartbeatFile": "/var/lib/waagent/VMAccess-1.0/heartbeat.log" } }] ================================================ FILE: script/create_zip.sh ================================================ #!/bin/bash # # This script is used to set up a test env for extensions # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # if [ ! $1 ] || [ ! $2 ] || [ ! $3 ] ; then echo "" echo " Usage: create_zip.sh " echo "" exit 1 fi if [ ! -d $1 ] ; then echo "" echo " Error: Couldn't find dir: $1>" echo "" exit 1 fi curr_dir=`pwd` ext_dir=$1 ext_name=$2 ext_version=$3 cd $ext_dir ext_dir=`pwd` cd $curr_dir script=$(dirname $0) root=$script/.. cd $root root=`pwd` echo $ext_name echo $ext_version util_dir=$root/Utils build_dir=$root/build if [ ! $ext_name ] ; then echo "" echo " Error: Couldn't detect extention name: $ext_name" echo "" exit 1 fi if [ ! $ext_version ] ; then echo "" echo " Error: Couldn't detect extention version: $ext_version" echo "" exit 1 fi if [ ! -d $build_dir ] ; then mkdir $build_dir fi ext_full_name=$ext_name-$ext_version tmp_dir=$build_dir/$ext_full_name echo "Create zip for $ext_name version $ext_version" echo "Creat tmp dir: $tmp_dir" mkdir $tmp_dir echo "Copy files..." 
cp -r $ext_dir/* $tmp_dir rm $tmp_dir/references echo "Copy dependecies..." cat $ext_dir/references cat $ext_dir/references | xargs cp -r -t $tmp_dir echo "Switch to tmp dir..." cd $tmp_dir echo "Remove test dir..." rm -r test rm -r */test rm *.pyc echo "Create zip..." zip -r $build_dir/$ext_full_name.zip . echo "Delete tmp dir..." rm $tmp_dir -r echo "Done!" ================================================ FILE: script/mkstub.sh ================================================ #!/bin/bash # # This script is used to create stub for unit test # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # if [ ! $1 ] ; then echo "" echo " Usage: mkstub.sh " echo "" exit 1 fi if [ ! -d $1 ] ; then echo "" echo " Error: Couldn't find dir: $1>" echo "" exit 1 fi ext_dir=$1 ext_meta=$ext_dir/HandlerManifest.json if [ ! -f $ext_meta ] ; then echo "" echo " Error: Couldn't find \"HandlerManifest.json\" file under $ext_dir" echo "" exit 1 fi cur_dir=`pwd` script=$(dirname $0) root=$script/.. 
cd $root root=`pwd` waagent_path='/usr/sbin/waagent' waagent_lib_dir='/var/lib/waagent' ext_log_dir='/var/log/azure' ext_name=`grep 'name' $ext_meta | sed 's/[\"| |,]//g' |gawk -F ':' '{print $2}'` ext_version=`grep 'version' $ext_meta | sed 's/[\"| |,]//g' |gawk -F ':' '{print $2}'` ext_full_name=$ext_name-$ext_version ext_dir=$waagent_lib_dir/$ext_full_name ext_status_dir=$ext_dir/status ext_config_dir=$ext_dir/config ext_env_json=$ext_dir/HandlerEnvironment.json test_cert_file=$waagent_lib_dir/TEST.crt test_pk_file=$waagent_lib_dir/TEST.prv ovf_env_file=$waagent_lib_dir/ovf-env.xml if [ ! -f $waagent_path ] ; then echo "Download latest waagent code" wget https://raw.githubusercontent.com/Azure/WALinuxAgent/2.0/waagent -O $waagent_path chmod +x $waagent_path fi if [ ! -d $waagent_lib_dir ] ; then echo "Create lib dir" mkdir $waagent_lib_dir fi if [ ! -d $ext_dir ] ; then echo "Create extension dir" mkdir $ext_dir fi if [ ! -d $ext_config_dir ] ; then echo "Create extension config dir" mkdir $ext_config_dir fi if [ ! -d $ext_status_dir ] ; then echo "Create extension status dir" mkdir $ext_status_dir fi if [ ! -f $ext_env_json ] ; then echo "Create HandlerEnvironment.json file" cp $script/HandlerEnvironment.json $ext_env_json fi if [ ! -f $test_cert_file ] ; then echo "Create test cert file" cp $script/test.crt $test_cert_file fi if [ ! -f $test_pk_file ] ; then echo "Create test pk file" cp $script/test.prv $test_pk_file fi if [ ! -f $ovf_env_file ] ; then echo "Create ovf-env.xml file" cp $script/ovf-env.xml $ovf_env_file fi if [ ! -f $ext_config_dir/0.settings ] ; then echo "Create 0.settings" cp $script/0.settings $ext_config_dir/0.settings fi if [ ! -d $ext_log_dir ] ; then echo "Create ext log dir" mkdir $ext_log_dir fi if [ ! -d $ext_log_dir/$ext_name ] ; then echo "Create ext log dir for $ext_name" mkdir $ext_log_dir/$ext_name fi if [ ! 
-d $ext_log_dir/$ext_name/$ext_version ] ; then echo "Create ext log dir for $ext_name $ext_version" mkdir $ext_log_dir/$ext_name/$ext_version fi echo "Change permission of waagent lib dir" chmod -R 600 $waagent_lib_dir ================================================ FILE: script/ovf-env.xml ================================================  1.0 LinuxProvisioningConfiguration test-ext azureuser User@123 false test /home/azureuser/.ssh/authorized_keys 1.0 kms.core.windows.net true Win7_Win8_IaaS_rd_art_stable_140703-0050_GuestAgentPackage.zip ================================================ FILE: script/set_env.sh ================================================ #!/bin/bash # # This script is used to set up a test env for extensions # # Copyright 2014 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # script=$(dirname $0) root=$script/.. cd $root root=`pwd` lib_path="." echo "\$PYTHONPATH=$PYTHONPATH" if [ ! `echo $PYTHONPATH | grep $root` ] ; then lib_path=$lib_path:$root fi if [ $lib_path != "." ] ; then echo "echo \"export PYTHONPATH=\$PYTHONPATH:$lib_path\" >> /etc/bash.bashrc" echo "export PYTHONPATH=\$PYTHONPATH:$lib_path" >> /etc/bash.bashrc echo "Enviroment variable PYTHONPATH has been set." echo "Run \"bash\" to reload bash." else echo "Your enviroment is cool. No action required." 
fi ================================================ FILE: script/test.crt ================================================ Bag Attributes: subject=/C=ab/ST=ab/L=ab/O=ab/OU=ab/CN=ab/emailAddress=ab issuer=/C=ab/ST=ab/L=ab/O=ab/OU=ab/CN=ab/emailAddress=ab -----BEGIN CERTIFICATE----- MIICOTCCAaICCQD7F0nb+GtpcTANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJh YjELMAkGA1UECAwCYWIxCzAJBgNVBAcMAmFiMQswCQYDVQQKDAJhYjELMAkGA1UE CwwCYWIxCzAJBgNVBAMMAmFiMREwDwYJKoZIhvcNAQkBFgJhYjAeFw0xNDA4MDUw ODIwNDZaFw0xNTA4MDUwODIwNDZaMGExCzAJBgNVBAYTAmFiMQswCQYDVQQIDAJh YjELMAkGA1UEBwwCYWIxCzAJBgNVBAoMAmFiMQswCQYDVQQLDAJhYjELMAkGA1UE AwwCYWIxETAPBgkqhkiG9w0BCQEWAmFiMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB iQKBgQC4Vugyj4uAKGYHW/D1eAg1DmLAv01e+9I0zIi8HzJxP87MXmS8EdG5SEzR N6tfQQie76JBSTYI4ngTaVCKx5dVT93LiWxLV193Q3vs/HtwwH1fLq0rAKUhREQ6 +CsRGNyeVfJkNsxAvNvQkectnYuOtcDxX5n/25eWAofobxVbSQIDAQABMA0GCSqG SIb3DQEBCwUAA4GBAF20gkq/DeUSXkZA+jjmmbCPioB3KL63GpoTXfP65d6yU4xZ TlMoLkqGKe3WoXmhjaTOssulgDAGA24IeWy/u7luH+oHdZEmEufFhj4M7tQ1pAhN CT8JCL2dI3F76HD6ZutTOkwRar3PYk5q7RsSJdAemtnwVpgp+RBMtbmct7MQ -----END CERTIFICATE----- ================================================ FILE: script/test.prv ================================================ -----BEGIN RSA PRIVATE KEY----- MIICXAIBAAKBgQC4Vugyj4uAKGYHW/D1eAg1DmLAv01e+9I0zIi8HzJxP87MXmS8 EdG5SEzRN6tfQQie76JBSTYI4ngTaVCKx5dVT93LiWxLV193Q3vs/HtwwH1fLq0r AKUhREQ6+CsRGNyeVfJkNsxAvNvQkectnYuOtcDxX5n/25eWAofobxVbSQIDAQAB AoGAIakE506c238E+m0Id9o+LWn+EFIeT6zN+oQqp6dOr61GFr1ZyZm7YQjZtg5j RZZ7e4Iob6Fts3ufD3RYl67QbBzRwsKwI7sAmzdCmqkopY2H6xv421cEGjkqZIJV 2Xyp9Idji6GfUB6+t1SZDOssbZx3SUkyim0hixK2HCJT4u0CQQDw6rNLZwEmwuhY z1jSERyeTtIcRJ47+Y79tX2xmkyKxZ2Kf28V3Fw/6biCIlmuvxHNhlLijimOME7/ rkqDiscnAkEAw+FpkM96xLlDCqNL2AcNxVnmNyO0Boxw0AKrogfcnDh6S3rD5tZQ IdcIAsEYNjhEJ+/hVCByIUArC885PTzQDwJBAMaDfm3ZWHeKD05uvG+MLhq8NCGa 4Q/mWU7xZ7sau4t1vpTK4MwQoesAOUrx5xg41QCXeGC6Z7+ESvQft8Kgbe0CQAkS OExPf3T6y2MDuvBvKzEXf7TP/3dKK7NGXGJtkMbfSrKSJd5b0GwwxBs0jAV+x5E9 
56Z4tjBaA2RRnWn7lfsCQA5SWuDMtlOzyWir09fparnnRL1JFvOwDAHTE0iwS8dO UFHIIw4nqqUYuHb+r/eyRzVtokJ9bSPZOjtTWSVL4W4= -----END RSA PRIVATE KEY----- ================================================ FILE: ui-extension-packages/microsoft.custom-script-linux/Artifacts/CreateUiDefinition.json ================================================ { "handler": "Microsoft.ClassicCompute.VmExtension", "version": "0.0.1-preview", "parameters": { "elements": [ { "name": "fileUris", "type": "Microsoft.Common.FileUpload", "label": "Script files", "toolTip": "The script files that will be downloaded to the virtual machine.", "constraints": { "required": false }, "options": { "multiple": true, "uploadMode": "url" } }, { "name": "commandToExecute", "type": "Microsoft.Common.TextBox", "label": "Command", "defaultValue": "sh script.sh", "toolTip": "The command to execute, for example: sh script.sh", "constraints": { "required": true } } ], "outputs": { "vmName": "[vmName()]", "location": "[location()]", "fileUris": "[elements('fileUris')]", "commandToExecute": "[elements('commandToExecute')]" } } } ================================================ FILE: ui-extension-packages/microsoft.custom-script-linux/Artifacts/MainTemplate.json ================================================ { "$schema": "http://schema.management.azure.com/schemas/2015-01-01/deploymentTemplate.json#", "contentVersion": "1.0.0.0", "parameters": { "vmName": { "type": "string" }, "location": { "type": "string" }, "fileUris": { "type": "array" }, "commandToExecute": { "type": "string" } }, "resources": [ { "name": "[concat(parameters('vmName'),'/CustomScriptForLinux')]", "type": "Microsoft.ClassicCompute/virtualMachines/extensions", "location": "[parameters('location')]", "apiVersion": "2015-06-01", "properties": { "publisher": "Microsoft.OSTCExtensions", "extension": "CustomScriptForLinux", "version": "1.*", "parameters": { "public": { "fileUris": "[parameters('fileUris')]", "commandToExecute": 
"[parameters('commandToExecute')]" } } } } ] } ================================================ FILE: ui-extension-packages/microsoft.custom-script-linux/Manifest.json ================================================ { "$schema": "https://gallery.azure.com/schemas/2015-04-01/manifest.json#", "name": "custom-script-linux", "publisher": "microsoft", "version": "1.0.0", "displayName": "ms-resource:displayName", "publisherDisplayName": "ms-resource:publisherDisplayName", "publisherLegalName": "ms-resource:publisherDisplayName", "summary": "ms-resource:summary", "longSummary": "ms-resource:summary", "description": "ms-resource:description", "uiDefinition": { "path": "UiDefinition.json" }, "artifacts": [ { "name": "MainTemplate", "type": "Template", "path": "Artifacts\\MainTemplate.json", "isDefault": true }, { "name": "CreateUiDefinition", "type": "Custom", "path": "Artifacts\\CreateUiDefinition.json", "isDefault": false }, ], "icons": { "small": "Icons\\Small.png", "medium": "Icons\\Medium.png", "large": "Icons\\Large.png", "wide": "Icons\\Wide.png" }, "links": [ { "displayName": "ms-resource:link1", "uri": "https://github.com/Azure/azure-linux-extensions/tree/master/CustomScript" } ], "categories": [ "classicCompute-vmextension-linux" ] } ================================================ FILE: ui-extension-packages/microsoft.custom-script-linux/Strings/resources.resjson ================================================ { "displayName": "Custom Script For Linux", "publisherDisplayName": "Microsoft Corp.", "summary": "Custom Script extension for Linux", "description": "

CustomScript Extension is a tool to execute your VM customization tasks after VM provisioning. When this Extension is added to a Virtual Machine, it can download customers’ scripts from Azure storage or public storage, and execute the scripts on the VM. CustomScript Extension tasks can also be automated using the Azure PowerShell cmdlets and Azure Cross-Platform Command-Line Interface (xPlat CLI).

Legal Terms

By clicking the Create button, I acknowledge that I am getting this software from Microsoft Corp. and that the legal terms of Microsoft Corp. apply to it. Microsoft does not provide rights for third-party software. Also see the privacy statement from Microsoft Corp.

", "link1": "Documentation" } ================================================ FILE: ui-extension-packages/microsoft.custom-script-linux/UiDefinition.json ================================================ { "$schema": "https://gallery.azure.com/schemas/2015-02-12/uiDefinition.json#", "createDefinition": { "createBlade": { "name": "AddVmExtension", "extension": "Microsoft_Azure_Classic_Compute" } } } ================================================ FILE: ui-extension-packages/microsoft.custom-script-linux-arm/Artifacts/CreateUiDefinition.json ================================================ { "handler": "Microsoft.Compute.VmExtension", "version": "0.0.1-preview", "parameters": { "elements": [ { "name": "fileUris", "type": "Microsoft.Common.FileUpload", "label": "Script files", "toolTip": "The script files that will be downloaded to the virtual machine.", "constraints": { "required": false }, "options": { "multiple": true, "uploadMode": "url" } }, { "name": "commandToExecute", "type": "Microsoft.Common.TextBox", "label": "Command", "defaultValue": "sh script.sh", "toolTip": "The command to execute, for example: sh script.sh", "constraints": { "required": true } } ], "outputs": { "vmName": "[vmName()]", "location": "[location()]", "fileUris": "[elements('fileUris')]", "commandToExecute": "[elements('commandToExecute')]" } } } ================================================ FILE: ui-extension-packages/microsoft.custom-script-linux-arm/Artifacts/MainTemplate.json ================================================ { "$schema": "http://schema.management.azure.com/schemas/2015-01-01/deploymentTemplate.json#", "contentVersion": "1.0.0.0", "parameters": { "vmName": { "type": "string" }, "location": { "type": "string" }, "fileUris": { "type": "array" }, "commandToExecute": { "type": "string" } }, "resources": [ { "name": "[concat(parameters('vmName'),'/CustomScriptForLinux')]", "type": "Microsoft.Compute/virtualMachines/extensions", "location": "[parameters('location')]", 
"apiVersion": "2015-06-15", "properties": { "publisher": "Microsoft.OSTCExtensions", "type": "CustomScriptForLinux", "typeHandlerVersion": "1.4", "autoUpgradeMinorVersion": true, "settings": { "fileUris": "[parameters('fileUris')]", "commandToExecute": "[parameters('commandToExecute')]" } } } ] } ================================================ FILE: ui-extension-packages/microsoft.custom-script-linux-arm/Manifest.json ================================================ { "$schema": "https://gallery.azure.com/schemas/2015-04-01/manifest.json#", "name": "custom-script-linux-arm", "publisher": "microsoft", "version": "1.0.0", "displayName": "ms-resource:displayName", "publisherDisplayName": "ms-resource:publisherDisplayName", "publisherLegalName": "ms-resource:publisherDisplayName", "summary": "ms-resource:summary", "longSummary": "ms-resource:summary", "description": "ms-resource:description", "uiDefinition": { "path": "UiDefinition.json" }, "artifacts": [ { "name": "MainTemplate", "type": "Template", "path": "Artifacts\\MainTemplate.json", "isDefault": true }, { "name": "CreateUiDefinition", "type": "Custom", "path": "Artifacts\\CreateUiDefinition.json", "isDefault": false }, ], "icons": { "small": "Icons\\Small.png", "medium": "Icons\\Medium.png", "large": "Icons\\Large.png", "wide": "Icons\\Wide.png" }, "links": [ { "displayName": "ms-resource:link1", "uri": "https://github.com/Azure/azure-linux-extensions/tree/master/CustomScript" } ], "categories": [ "compute-vmextension-linux" ] } ================================================ FILE: ui-extension-packages/microsoft.custom-script-linux-arm/Strings/resources.resjson ================================================ { "displayName": "Custom Script For Linux", "publisherDisplayName": "Microsoft Corp.", "summary": "Custom Script extension for Linux", "description": "

CustomScript Extension is a tool to execute your VM customization tasks after VM provisioning. When this Extension is added to a Virtual Machine, it can download customers’ scripts from Azure storage or public storage, and execute the scripts on the VM. CustomScript Extension tasks can also be automated using the Azure PowerShell cmdlets and Azure Cross-Platform Command-Line Interface (xPlat CLI).

Legal Terms

By clicking the Create button, I acknowledge that I am getting this software from Microsoft Corp. and that the legal terms of Microsoft Corp. apply to it. Microsoft does not provide rights for third-party software. Also see the privacy statement from Microsoft Corp.

", "link1": "Documentation" } ================================================ FILE: ui-extension-packages/microsoft.custom-script-linux-arm/UiDefinition.json ================================================ { "$schema": "https://gallery.azure.com/schemas/2015-02-12/uiDefinition.json#", "createDefinition": { "createBlade": { "name": "AddVmExtension", "extension": "Microsoft_Azure_Compute" } } }